diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 04e9dd23d..5ab815f75 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3898,8 +3898,6 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertIn('unspecified_dims=[0,1]', lowered_text) def test_jit_partially_specified_shardings(self): - if jtu.is_device_tpu(version=5, variant="e"): - self.skipTest('Remove this once b/328054509 is fixed') mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2)