mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add PJRT C API to Cloud TPU test matrix
Also shortens the job names so the full name is visible from the GitHub UI (this was driving me crazy), and marks a new test that can't be run on the PJRT C API yet. Example run: https://github.com/google/jax/actions/runs/4019968334
This commit is contained in:
parent
1107e79999
commit
93cd07efb8
13
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
13
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
@ -13,8 +13,10 @@ jobs:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
python-version: ["3.10"] # TODO(jakevdp): update to 3.11 when available.
|
||||
jaxlib-version: ["latest-release", "nightly"]
|
||||
jaxlib-version: ["latest", "nightly"]
|
||||
tpu-type: ["v3-8", "v4-8"]
|
||||
pjrt: ["true", "false"]
|
||||
name: "TPU test (${{ matrix.jaxlib-version }}, pjrt=${{ matrix.pjrt }}, ${{ matrix.tpu-type }})"
|
||||
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
|
||||
steps:
|
||||
# https://opensource.google/documentation/reference/github/services#actions
|
||||
@ -31,7 +33,7 @@ jobs:
|
||||
- name: Install JAX
|
||||
run: |
|
||||
pip uninstall -y jax jaxlib libtpu-nightly
|
||||
if [ "${{ matrix.jaxlib-version }}" == "latest-release" ]; then
|
||||
if [ "${{ matrix.jaxlib-version }}" == "latest" ]; then
|
||||
pip install .[tpu] \
|
||||
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
|
||||
@ -55,12 +57,15 @@ jobs:
|
||||
- name: Run tests
|
||||
env:
|
||||
JAX_PLATFORMS: tpu,cpu
|
||||
JAX_USE_PJRT_C_API_ON_TPU: ${{ matrix.pjrt }}
|
||||
EXTRA_TAGS: "${{ matrix.pjrt == 'true' && 'and not pjrt_c_api_unimplemented' || '' }}"
|
||||
run: |
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true python -m pytest -n=4 --tb=short \
|
||||
--maxfail=20 -m "not multiaccelerator" tests examples
|
||||
--maxfail=20 -m "not multiaccelerator ${EXTRA_TAGS}" tests examples
|
||||
# Run multi-accelerator across all chips
|
||||
python -m pytest -m "multiaccelerator" --tb=short --maxfail=20 tests
|
||||
python -m pytest --tb=short --maxfail=20 \
|
||||
-m "multiaccelerator ${EXTRA_TAGS}" tests
|
||||
- name: Send chat on failure
|
||||
# Don't notify when testing the workflow from a branch.
|
||||
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' }}
|
||||
|
@ -1069,6 +1069,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
"valid for values of rank at least 4, but was applied to a value of rank 1"):
|
||||
pjit_f(jnp.array([1, 2, 3]))
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # custom partitioner
|
||||
@jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU.
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_custom_partitioner(self):
|
||||
|
@ -880,6 +880,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
|
||||
x = np.arange(6, dtype=np.int32).reshape((3, 2))
|
||||
np.testing.assert_allclose(g(x), x)
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class IOPythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user