Disable all_gather_test on non-v5e TPUs

Also consolidate logic for selectively enabling TPU tests on TPU versions

PiperOrigin-RevId: 588597889
This commit is contained in:
Sharad Vikram 2023-12-06 17:46:46 -08:00 committed by jax authors
parent 7fd6c8d860
commit 045a9ef1ef
3 changed files with 16 additions and 5 deletions

View File

@ -337,9 +337,17 @@ def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
return True
return pjrt_c_api_versions >= (major_version, minor_version)
def is_device_tpu_v4():
return jax.devices()[0].device_kind == "TPU v4"
def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
if device_under_test() != "tpu":
return False
if version is None:
return True
device_kind = jax.devices()[0].device_kind
expected_version = f"v{version}{variant}"
# Special case v5e until the name is updated in device_kind
if expected_version == "v5e":
return "v5 lite" in device_kind
return expected_version in device_kind
def _get_device_tags():
"""returns a set of tags defined for the device under test"""

View File

@ -244,7 +244,7 @@ class CacheKeyTest(jtu.JaxTestCase):
self.assertEqual(include_metadata, key1 != key2)
def test_xla_flags(self):
if jtu.is_device_tpu_v4():
if jtu.is_device_tpu(version=4):
raise unittest.SkipTest("TODO(b/240151176)")
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
@ -290,7 +290,7 @@ class CacheKeyTest(jtu.JaxTestCase):
sys.argv = orig_argv
def test_libtpu_init_args(self):
if jtu.is_device_tpu_v4():
if jtu.is_device_tpu(version=4):
raise unittest.SkipTest("TODO(b/240151176)")
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()

View File

@ -80,6 +80,9 @@ class AllGatherTest(jtu.JaxTestCase):
super().setUp()
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Need TPU devices")
if not jtu.is_device_tpu(version=5, variant="e"):
# TODO(sharadmv,apaszke): expand support to more versions
self.skipTest("Currently only supported on TPU v5e")
@hp.given(hps.booleans(), _array_shapes(), _array_dtypes())
def test_all_gather_1d_mesh(self, is_vmem, shape, dtype):