mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

`profiler_test.py:ProfilerTest.test_remote_profiler` fails with the protobuf upgrade. However, I was seeing mysterious hangs without the upgrade, and in general I think we should be testing with up-to-date deps given that we don't pin. I'm going to continue working on getting the Cloud TPU CI green.
75 lines
3.4 KiB
YAML
75 lines
3.4 KiB
YAML
# Nightly Cloud TPU CI: installs JAX (released or nightly jaxlib/libtpu) on
# self-hosted TPU runners, runs the test suite, and posts a chat message when a
# run on `main` fails or is cancelled.
name: CI - Cloud TPU (nightly)

on:
  schedule:
    # GitHub evaluates cron expressions in UTC.
    - cron: "0 14 * * *"  # daily at 14:00 UTC (6am PST / 7am PDT)
  workflow_dispatch:  # allows triggering the workflow run manually

# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read

jobs:
  cloud-tpu-test:
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Selects the install branch in the "Install JAX" step below.
        jaxlib-version: ["latest", "nightly"]
        # Used verbatim as a runner label in `runs-on`, so every entry must be
        # a TPU type, not a jaxlib version.
        tpu-type: ["v3-8", "v4-8"]
        # Exported to the tests as JAX_USE_PJRT_C_API_ON_TPU.
        pjrt: ["true", "false"]
    name: "TPU test (${{ matrix.jaxlib-version }}, pjrt=${{ matrix.pjrt }}, ${{ matrix.tpu-type }})"
    runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
    timeout-minutes: 60
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8  # ratchet:actions/checkout@v3

      - name: Install JAX test requirements
        run: |
          pip install -U -r build/test-requirements.txt
          pip install -U -r build/collect-profile-requirements.txt

      - name: Install JAX
        run: |
          # Start from a clean slate so the versions printed below are the ones
          # installed by this step.
          pip uninstall -y jax jaxlib libtpu-nightly
          if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
            # Quoted so the shell cannot glob-expand the extras specifier.
            pip install ".[tpu]" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            pip install .
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
            pip install --pre libtpu-nightly \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
            pip install requests

          else
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi

          # Log the resolved versions to make failures easier to triage.
          python3 -c 'import sys; print("python version:", sys.version)'
          python3 -c 'import jax; print("jax version:", jax.__version__)'
          python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python3 -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'

      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
          JAX_USE_PJRT_C_API_ON_TPU: "${{ matrix.pjrt }}"
          PY_COLORS: "1"
        run: |
          # Run single-accelerator tests in parallel
          JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=4 --tb=short \
            --maxfail=20 -m "not multiaccelerator" tests examples
          # Run multi-accelerator across all chips
          python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests

      - name: Send chat on failure
        # Don't notify when testing the workflow from a branch.
        if: ${{ (failure() || cancelled()) && github.ref_name == 'main' }}
        run: |
          curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
          --header 'Content-Type: application/json' \
          --data-raw "{
          'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu-type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
          }"