Run v3-8 tests with cores set at 8

This commit is contained in:
Michael Hudgins 2024-05-15 14:37:27 -04:00 committed by GitHub
parent 181da12809
commit 0232cb9f8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,9 +26,9 @@ jobs:
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu: [
{type: "v3-8", core: "4"},
{type: "v4-8", core: "4"},
{type: "v5e-8", core: "8"}
{type: "v3-8", cores: "8"},
{type: "v4-8", cores: "4"},
{type: "v5e-8", cores: "8"}
]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
@ -88,7 +88,7 @@ jobs:
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.core }} --tb=short \
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests