mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Re-enable disabled pjit tests due to MSAN failure.
PiperOrigin-RevId: 613266308
This commit is contained in:
parent
026b2d207f
commit
fc8dc8364e
@ -3898,8 +3898,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertIn('unspecified_dims=[0,1]', lowered_text)
|
||||
|
||||
def test_jit_partially_specified_shardings(self):
|
||||
if jtu.is_device_tpu(version=5, variant="e"):
|
||||
self.skipTest('Remove this once b/328054509 is fixed')
|
||||
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user