# Nightly JAX test run on Cloud TPU.
#
# Each matrix job checks out the repository, installs it against either the
# released or the nightly jaxlib/libtpu wheels (selected by `jaxlib-version`),
# runs the test suite, and posts a chat message when a run on `main` fails.
name: Cloud TPU nightly

on:
  schedule:
    # GitHub evaluates cron schedules in UTC.
    - cron: "0 14 * * *"  # daily at 14:00 UTC (6am PST / 7am PDT)
  workflow_dispatch:  # allows triggering the workflow run manually

# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read

jobs:
  cloud-tpu-test:
    # Runs on a self-hosted runner carrying the `tpu` and `v4-8` labels.
    runs-on: [self-hosted, tpu, v4-8]
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Quoted so that YAML does not read 3.10 as the float 3.1.
        python-version: ["3.10"]  # TODO(jakevdp): update to 3.11 when available.
        jaxlib-version: [latest, nightly]
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # ratchet:actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # ratchet:actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install JAX test requirements
        run: |
          pip install -r build/test-requirements.txt
      - name: Install JAX
        run: |
          # Start from a clean slate so wheels from an earlier run on this
          # runner cannot leak into the job.
          pip uninstall -y jax jaxlib libtpu-nightly
          if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
            # Checked-out jax plus its `tpu` extra. The requirement is quoted
            # because `.[tpu]` is otherwise a shell glob pattern.
            pip install ".[tpu]" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            # Checked-out jax, then pre-release jaxlib and libtpu-nightly on top.
            pip install .
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
            pip install libtpu-nightly \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          else
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi
          # Record the versions that actually got installed in the job log.
          python3 -c 'import jax; print("jax version:", jax.__version__)'
          python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          python3 -c 'import jax; print("libtpu version:", jax.lib.xla_bridge.get_backend().platform_version)'
      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
        run: |
          # Run single-accelerator tests in parallel
          JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
            -m "not multiaccelerator" tests examples
          # Run multi-accelerator across all chips
          python -m pytest -m "multiaccelerator" --tb=short tests
      - name: Send chat on failure
        # Don't notify when testing the workflow from a branch.
        if: ${{ failure() && github.ref_name == 'main' }}
        env:
          # Hand the webhook URL to the script through the environment rather
          # than templating the secret into the script text.
          BUILD_CHAT_WEBHOOK: ${{ secrets.BUILD_CHAT_WEBHOOK }}
        run: |
          # The payload is strict JSON: double-quoted key and string, with the
          # quotes around the workflow name escaped for both bash and JSON.
          curl --location --request POST "$BUILD_CHAT_WEBHOOK" \
            --header 'Content-Type: application/json' \
            --data-raw "{\"text\": \"\\\"$GITHUB_WORKFLOW\\\" job failed: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID\"}"