mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Disable all_gather_test on non-v5e TPUs
Also consolidate logic for selectively enabling TPU tests on TPU versions PiperOrigin-RevId: 588597889
This commit is contained in:
parent
7fd6c8d860
commit
045a9ef1ef
@ -337,9 +337,17 @@ def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
|
||||
return True
|
||||
return pjrt_c_api_versions >= (major_version, minor_version)
|
||||
|
||||
|
||||
def is_device_tpu_v4():
|
||||
return jax.devices()[0].device_kind == "TPU v4"
|
||||
def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
|
||||
if device_under_test() != "tpu":
|
||||
return False
|
||||
if version is None:
|
||||
return True
|
||||
device_kind = jax.devices()[0].device_kind
|
||||
expected_version = f"v{version}{variant}"
|
||||
# Special case v5e until the name is updated in device_kind
|
||||
if expected_version == "v5e":
|
||||
return "v5 lite" in device_kind
|
||||
return expected_version in device_kind
|
||||
|
||||
def _get_device_tags():
|
||||
"""returns a set of tags defined for the device under test"""
|
||||
|
@ -244,7 +244,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
self.assertEqual(include_metadata, key1 != key2)
|
||||
|
||||
def test_xla_flags(self):
|
||||
if jtu.is_device_tpu_v4():
|
||||
if jtu.is_device_tpu(version=4):
|
||||
raise unittest.SkipTest("TODO(b/240151176)")
|
||||
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
@ -290,7 +290,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
sys.argv = orig_argv
|
||||
|
||||
def test_libtpu_init_args(self):
|
||||
if jtu.is_device_tpu_v4():
|
||||
if jtu.is_device_tpu(version=4):
|
||||
raise unittest.SkipTest("TODO(b/240151176)")
|
||||
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
|
@ -80,6 +80,9 @@ class AllGatherTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Need TPU devices")
|
||||
if not jtu.is_device_tpu(version=5, variant="e"):
|
||||
# TODO(sharadmv,apaszke): expand support to more versions
|
||||
self.skipTest("Currently only supported on TPU v5e")
|
||||
|
||||
@hp.given(hps.booleans(), _array_shapes(), _array_dtypes())
|
||||
def test_all_gather_1d_mesh(self, is_vmem, shape, dtype):
|
||||
|
Loading…
x
Reference in New Issue
Block a user