From a6ab6bbc20accd61c39f6c02ce160dee49a15d55 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 12 Mar 2025 05:19:55 -0700 Subject: [PATCH] Ignore Pallas TPU tests when testing with the oldest supported libtpu I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows PiperOrigin-RevId: 736094492 --- .github/workflows/pytest_tpu.yml | 2 ++ ci/run_pytest_tpu.sh | 11 ++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 2341bfb79..a105a2feb 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -135,6 +135,8 @@ jobs: echo "Using oldest supported libtpu" $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + + echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV else echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}" exit 1 diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 9b4a3bbfd..5d8aa9ed6 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -53,10 +53,19 @@ export JAX_SKIP_SLOW_TESTS=true echo "Running TPU tests..." if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then + # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic + # TPU does not guarantee anything about forward compatibility (unless + # jax.export is used) and the 12 week compatibility window accumulates way + # too many failures. + IGNORE_FLAGS="" + if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then + IGNORE_FLAGS="--ignore=tests/pallas" + fi + # Run single-accelerator tests in parallel JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ - --maxfail=20 -m "not multiaccelerator" tests examples + --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples # Run Pallas printing tests, which need to run with I/O capturing disabled. TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \