Disable a shard_map test case that fails on TPU v5e.

PiperOrigin-RevId: 672618556
This commit is contained in:
Peter Hawkins 2024-09-09 11:44:51 -07:00 committed by jax authors
parent 4bdfe09241
commit 5cc5ed2c5c

View File

@ -1544,6 +1544,9 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertEqual(e2.primitive.name, 'pbroadcast')
def test_check_rep_false_grads(self):
if jtu.is_device_tpu(5, 'e'):
self.skipTest('TODO(b/307508823): Test currently fails on TPU v5e')
# This test is redundant with the systematic tests below, but it serves as a
# direct regression test for a bug.
mesh = jtu.create_mesh((4,), ('heads',))