mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

Instead, we skip tests that the PJRT C API doesn't support. We had this tag for feature development so it was easy to broadly disable, but now we don't expect to need to do that.
76 lines
3.5 KiB
YAML
76 lines
3.5 KiB
YAML
# Nightly workflow that exercises JAX on Cloud TPU runners.
name: Cloud TPU nightly

on:
  schedule:
    # GitHub evaluates cron expressions in UTC: 14:00 UTC is 7am PDT / 6am PST.
    - cron: "0 14 * * *"
  workflow_dispatch: # allows triggering the workflow run manually

# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read

jobs:
  # Runs the test suite on self-hosted TPU runners, once per combination of
  # the matrix values below (jaxlib-version x tpu-type x pjrt).
  cloud-tpu-test:
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        # Versions are quoted so YAML keeps them as strings ("3.10", not 3.1).
        python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available.
        jaxlib-version: ["latest", "nightly"]
        tpu-type: ["v3-8", "v4-8"]
        pjrt: ["true", "false"]
    name: "TPU test (${{ matrix.jaxlib-version }}, pjrt=${{ matrix.pjrt }}, ${{ matrix.tpu-type }})"
    # The tpu-type label selects a self-hosted runner of the matching TPU kind.
    runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install JAX test requirements
        run: |
          pip install -r build/test-requirements.txt
      - name: Install JAX
        run: |
          # Start from a clean slate so a previous run's packages on the
          # runner cannot leak into this one.
          pip uninstall -y jax jaxlib libtpu-nightly
          if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
            # Install the checked-out jax with its "tpu" extra. The argument is
            # quoted so the shell cannot treat "[tpu]" as a glob pattern.
            pip install ".[tpu]" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            # Install the checked-out jax, then layer pre-release jaxlib and
            # libtpu-nightly builds on top of it.
            pip install .
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
            pip install --pre libtpu-nightly \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
            pip install requests

          else
            # Guard against a matrix value this script does not know about.
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi

          # Log the versions that actually got installed.
          python -c 'import jax; print("jax version:", jax.__version__)'
          python -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'
      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
          # Set from the matrix so each configuration runs with and without it.
          JAX_USE_PJRT_C_API_ON_TPU: ${{ matrix.pjrt }}
        run: |
          # Run single-accelerator tests in parallel
          JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
            --maxfail=20 -m "not multiaccelerator" tests examples
          # Run multi-accelerator across all chips
          python -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
      - name: Send chat on failure
        # Don't notify when testing the workflow from a branch.
        if: ${{ (failure() || cancelled()) && github.ref_name == 'main' }}
        run: |
          # Post a message with a link to this run to the build chat webhook.
          curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
          --header 'Content-Type: application/json' \
          --data-raw "{
          'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu-type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
          }"