diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 821d61e57..d1338fec0 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -337,9 +337,17 @@ def pjrt_c_api_version_at_least(major_version: int, minor_version: int): return True return pjrt_c_api_versions >= (major_version, minor_version) - -def is_device_tpu_v4(): - return jax.devices()[0].device_kind == "TPU v4" +def is_device_tpu(version: int | None = None, variant: str = "") -> bool: + if device_under_test() != "tpu": + return False + if version is None: + return True + device_kind = jax.devices()[0].device_kind + expected_version = f"v{version}{variant}" + # Special case v5e until the name is updated in device_kind + if expected_version == "v5e": + return "v5 lite" in device_kind + return expected_version in device_kind def _get_device_tags(): """returns a set of tags defined for the device under test""" diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index e66e6c2f8..99cc583ae 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -244,7 +244,7 @@ class CacheKeyTest(jtu.JaxTestCase): self.assertEqual(include_metadata, key1 != key2) def test_xla_flags(self): - if jtu.is_device_tpu_v4(): + if jtu.is_device_tpu(version=4): raise unittest.SkipTest("TODO(b/240151176)") computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() @@ -290,7 +290,7 @@ class CacheKeyTest(jtu.JaxTestCase): sys.argv = orig_argv def test_libtpu_init_args(self): - if jtu.is_device_tpu_v4(): + if jtu.is_device_tpu(version=4): raise unittest.SkipTest("TODO(b/240151176)") computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() diff --git a/tests/pallas/all_gather_test.py b/tests/pallas/all_gather_test.py index 7b72c07d9..bde800ac8 100644 --- a/tests/pallas/all_gather_test.py +++ b/tests/pallas/all_gather_test.py @@ -80,6 +80,9 @@ class AllGatherTest(jtu.JaxTestCase): super().setUp() if not jtu.test_device_matches(["tpu"]): self.skipTest("Need TPU devices") + if not jtu.is_device_tpu(version=5, variant="e"): + # TODO(sharadmv,apaszke): expand support to more versions + self.skipTest("Currently only supported on TPU v5e") @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) def test_all_gather_1d_mesh(self, is_vmem, shape, dtype):