rocm_jax/.github/workflows/cloud-tpu-ci-nightly.yml
Skye Wanderman-Milne ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00

76 lines
3.5 KiB
YAML

# Workflow display name, triggers, and token permissions.
name: Cloud TPU nightly

on:
  schedule:
    # GitHub evaluates cron expressions in UTC: 14:00 UTC is 6am PST / 7am PDT.
    - cron: "0 14 * * *"
  workflow_dispatch: # allows triggering the workflow run manually

# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read
jobs:
  # Installs JAX and runs its test suite on a self-hosted TPU runner, once for
  # every combination of the matrix values below.
  cloud-tpu-test:
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available.
        jaxlib-version: ["latest", "nightly"]
        tpu-type: ["v3-8", "v4-8"]
        # Kept as strings: the value is exported verbatim as an env var below.
        pjrt: ["true", "false"]
    name: "TPU test (${{ matrix.jaxlib-version }}, pjrt=${{ matrix.pjrt }}, ${{ matrix.tpu-type }})"
    # The runner must carry all three labels, so the TPU type selects the host.
    runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install JAX test requirements
        run: |
          pip install -r build/test-requirements.txt
      - name: Install JAX
        run: |
          # Start from a clean slate so a previous run's packages on this
          # self-hosted runner cannot leak into the versions under test.
          pip uninstall -y jax jaxlib libtpu-nightly
          if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
            # Quoted so the shell never treats [tpu] as a glob pattern.
            pip install ".[tpu]" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            pip install .
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
            pip install --pre libtpu-nightly \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
            pip install requests
          else
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi
          # Log the versions that were actually installed.
          python -c 'import jax; print("jax version:", jax.__version__)'
          python -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'
      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
          JAX_USE_PJRT_C_API_ON_TPU: ${{ matrix.pjrt }}
        run: |
          # Run single-accelerator tests in parallel
          JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
            --maxfail=20 -m "not multiaccelerator" tests examples
          # Run multi-accelerator across all chips
          python -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
      - name: Send chat on failure
        # Don't notify when testing the workflow from a branch.
        if: ${{ (failure() || cancelled()) && github.ref_name == 'main' }}
        run: |
          # The message names every matrix dimension shown in the job name
          # (including pjrt) so the failing job can be identified from chat.
          curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
            --header 'Content-Type: application/json' \
            --data-raw "{
            'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", pjrt=${{ matrix.pjrt }}, TPU type ${{ matrix.tpu-type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
          }"