rocm_jax/.github/workflows/cloud-tpu-ci-nightly.yml
Skye Wanderman-Milne 120125f3dd Make pytest-xdist work on TPU and update Cloud TPU CI.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
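As an illustration of the marking described above, a test module could guard the pytest import so the file still runs under plain unittest. This is a minimal sketch, not the code from this commit: the class `AllDevicesTest` and its test are hypothetical, and only the marker name `multiaccelerator` comes from the workflow below.

```python
import unittest

# pytest is optional: the tests should still run under plain unittest.
try:
    import pytest
except ImportError:
    pytest = None

# Module-level marker: pytest applies every mark listed in `pytestmark` to all
# tests collected from this file. Without pytest the list is simply empty.
pytestmark = [pytest.mark.multiaccelerator] if pytest is not None else []


class AllDevicesTest(unittest.TestCase):  # hypothetical test case
    def test_placeholder(self):
        # A real test would exercise every local accelerator here.
        self.assertTrue(True)
```

With a marker like this in place, `-m "not multiaccelerator"` and `-m "multiaccelerator"` select the two groups of tests. The marker name would normally also be registered under `markers` in the pytest configuration to avoid unknown-mark warnings.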

By running each single-device test on a single TPU chip, several tests
can run in parallel, and the test suite goes from 1h 45m to 35m (both
timings include slow tests).

I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still took 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, whereas pytest reuses the processes it starts
initially, and starting and stopping the TPU runtime takes a few
seconds each time, so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
2022-11-18 22:05:13 +00:00
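
The commit message above describes running each single-device test on its own TPU chip. One way this could work is to map each pytest-xdist worker to a chip index before the TPU runtime starts. The sketch below is an assumption about the mechanism, not the code from this commit: `tpu_env_for_worker` is a hypothetical helper, and the variable names `TPU_VISIBLE_CHIPS` and `ALLOW_MULTIPLE_LIBTPU_LOAD` are assumptions that should be checked against the libtpu version in use. `PYTEST_XDIST_WORKER` is set by pytest-xdist (`gw0`, `gw1`, ...), and `JAX_ENABLE_TPU_XDIST` is the switch used in the workflow below.

```python
import os


def tpu_env_for_worker(environ=None):
    """Return extra environment variables for one pytest-xdist worker.

    Hypothetical helper: gives worker `gwN` exclusive use of TPU chip N.
    """
    environ = os.environ if environ is None else environ
    if environ.get("JAX_ENABLE_TPU_XDIST", "").lower() not in ("1", "true"):
        return {}  # feature not requested
    worker = environ.get("PYTEST_XDIST_WORKER", "")
    if not worker.startswith("gw") or not worker[2:].isdigit():
        return {}  # not running under xdist (e.g. plain `pytest`)
    chip = int(worker[2:])
    return {
        # Assumed runtime settings; verify the names for your libtpu version.
        "TPU_VISIBLE_CHIPS": str(chip),
        "ALLOW_MULTIPLE_LIBTPU_LOAD": "true",
    }


# In a conftest.py this would run before any test initializes the TPU backend:
# os.environ.update(tpu_env_for_worker())
```

With `-n=4` on a four-chip host, workers `gw0` to `gw3` would each see one chip, which matches the worker count used in the "Run tests" step.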

name: Cloud TPU nightly
on:
  schedule:
    - cron: "0 14 * * *" # daily at 14:00 UTC (6am PST / 7am PDT)
  workflow_dispatch: # allows triggering the workflow run manually
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read
jobs:
  cloud-tpu-test:
    runs-on: [self-hosted, tpu, v4-8]
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available.
        jaxlib-version: [latest, nightly]
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install JAX test requirements
        run: |
          pip install -r build/test-requirements.txt
      - name: Install JAX
        run: |
          pip uninstall -y jax jaxlib libtpu-nightly
          if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
            pip install .[tpu] \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            pip install .
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
            pip install libtpu-nightly \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          else
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi
          python3 -c 'import jax; print("jax version:", jax.__version__)'
          python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python3 -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'
      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
        run: |
          # Run single-accelerator tests in parallel
          JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
            -m "not multiaccelerator" tests examples
          # Run multi-accelerator across all chips
          python -m pytest -m "multiaccelerator" --tb=short tests
      - name: Send chat on failure
        # Don't notify when testing the workflow from a branch.
        if: ${{ failure() && github.ref_name == 'main' }}
        run: |
          curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
          --header 'Content-Type: application/json' \
          --data-raw "{
          'text': '\"$GITHUB_WORKFLOW\" job failed: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
          }"
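
The final step assembles its request body by hand inside a shell string, using single quotes around the key and value, which is not strictly valid JSON even if the webhook accepts it. As a sketch of the message that step sends, the same body could be built with a JSON encoder. The helper name `failure_payload` is hypothetical, and the only thing assumed about the webhook is that it takes an object with a `text` field, as the step implies. The `GITHUB_*` names are the default environment variables the step already uses.

```python
import json


def failure_payload(env):
    """Build the JSON body for the failure notification (hypothetical helper)."""
    run_url = "{}/{}/actions/runs/{}".format(
        env["GITHUB_SERVER_URL"],
        env["GITHUB_REPOSITORY"],
        env["GITHUB_RUN_ID"],
    )
    # Same wording as the curl command: the workflow name in double quotes,
    # followed by a link to the failing run.
    text = '"{}" job failed: {}'.format(env["GITHUB_WORKFLOW"], run_url)
    # json.dumps escapes the embedded double quotes correctly.
    return json.dumps({"text": text})
```

Inside a workflow step, the function could be called with `os.environ`, and the resulting string passed to the HTTP client as the request body.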