From 5cc5ed2c5cd6efcc474aaf97963d8994037f9b44 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Sep 2024 11:44:51 -0700 Subject: [PATCH] Disable a shard_map test case that fails on TPU v5e. PiperOrigin-RevId: 672618556 --- tests/shard_map_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 2df477454..e9c23b3e5 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1544,6 +1544,9 @@ class ShardMapTest(jtu.JaxTestCase): self.assertEqual(e2.primitive.name, 'pbroadcast') def test_check_rep_false_grads(self): + if jtu.is_device_tpu(5, 'e'): + self.skipTest('TODO(b/307508823): Test currently fails on TPU v5e') + # This test is redundant with the systematic tests below, but it serves as a # direct regression test for a bug. mesh = jtu.create_mesh((4,), ('heads',))