mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Disable a shard_map test case that fails on TPU v5e.
PiperOrigin-RevId: 672618556
This commit is contained in:
parent
4bdfe09241
commit
5cc5ed2c5c
@ -1544,6 +1544,9 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(e2.primitive.name, 'pbroadcast')
|
||||
|
||||
def test_check_rep_false_grads(self):
|
||||
if jtu.is_device_tpu(5, 'e'):
|
||||
self.skipTest('TODO(b/307508823): Test currently fails on TPU v5e')
|
||||
|
||||
# This test is redundant with the systematic tests below, but it serves as a
|
||||
# direct regression test for a bug.
|
||||
mesh = jtu.create_mesh((4,), ('heads',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user