Re-enable skipped test

Fixes #12927

PiperOrigin-RevId: 484304818
This commit is contained in:
Skye Wanderman-Milne 2022-10-27 11:25:12 -07:00 committed by jax authors
parent 9f80402845
commit 4b1fd63263

View File

@ -1815,7 +1815,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_unspecified_out_axis_resources(self):
# TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
if xc._version < 102: # Remove when jaxlib 0.3.23 is released
if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
jtu.device_under_test() == 'tpu'):
self.skipTest('Does not work with the cloud TPU SE runtime.')
@ -1852,7 +1852,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
s2_shape, s3_shape, s4_shape):
# TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
if xc._version < 102: # Remove when jaxlib 0.3.23 is released
if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
jtu.device_under_test() == 'tpu'):
self.skipTest('Does not work with the cloud TPU SE runtime.')