Add PJRT C API to Cloud TPU test matrix

Also shortens the job names so the full name is visible from the
github UI (this was driving me crazy), and marks a new test that can't
be run on the PJRT C API yet.

Example run: https://github.com/google/jax/actions/runs/4019968334
This commit is contained in:
Skye Wanderman-Milne 2023-01-26 18:40:37 +00:00
parent 1107e79999
commit 93cd07efb8
3 changed files with 11 additions and 4 deletions

View File

@ -13,8 +13,10 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available.
jaxlib-version: ["latest-release", "nightly"]
jaxlib-version: ["latest", "nightly"]
tpu-type: ["v3-8", "v4-8"]
pjrt: ["true", "false"]
name: "TPU test (${{ matrix.jaxlib-version }}, pjrt=${{ matrix.pjrt }}, ${{ matrix.tpu-type }})"
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
steps:
# https://opensource.google/documentation/reference/github/services#actions
@ -31,7 +33,7 @@ jobs:
- name: Install JAX
run: |
pip uninstall -y jax jaxlib libtpu-nightly
if [ "${{ matrix.jaxlib-version }}" == "latest-release" ]; then
if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
@ -55,12 +57,15 @@ jobs:
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
JAX_USE_PJRT_C_API_ON_TPU: ${{ matrix.pjrt }}
EXTRA_TAGS: "${{ matrix.pjrt == 'true' && 'and not pjrt_c_api_unimplemented' || '' }}"
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
--maxfail=20 -m "not multiaccelerator" tests examples
--maxfail=20 -m "not multiaccelerator ${EXTRA_TAGS}" tests examples
# Run multi-accelerator across all chips
python -m pytest -m "multiaccelerator" --tb=short --maxfail=20 tests
python -m pytest --tb=short --maxfail=20 \
-m "multiaccelerator ${EXTRA_TAGS}" tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' }}

View File

@ -1069,6 +1069,7 @@ class PJitTest(jtu.BufferDonationTestCase):
"valid for values of rank at least 4, but was applied to a value of rank 1"):
pjit_f(jnp.array([1, 2, 3]))
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # custom partitoner
@jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU.
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_custom_partitioner(self):

View File

@ -880,6 +880,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
x = np.arange(6, dtype=np.int32).reshape((3, 2))
np.testing.assert_allclose(g(x), x)
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class IOPythonCallbackTest(jtu.JaxTestCase):
def setUp(self):