Ignore Pallas TPU tests when testing with the oldest supported libtpu

I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows

PiperOrigin-RevId: 736094492
This commit is contained in:
Nitin Srinivasan 2025-03-12 05:19:55 -07:00 committed by jax authors
parent 61ba2b2603
commit a6ab6bbc20
2 changed files with 12 additions and 1 deletions

View File

@ -135,6 +135,8 @@ jobs:
echo "Using oldest supported libtpu"
$JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV
else
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
exit 1

View File

@ -53,10 +53,19 @@ export JAX_SKIP_SLOW_TESTS=true
echo "Running TPU tests..."
if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
# We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic
# TPU does not guarantee anything about forward compatibility (unless
# jax.export is used) and the 12 week compatibility window accumulates way
# too many failures.
IGNORE_FLAGS=""
if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then
IGNORE_FLAGS="--ignore=tests/pallas"
fi
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
--maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \