From 4b1fd63263058f1c48a64fdbf47a9e545e4f336d Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 27 Oct 2022 11:25:12 -0700 Subject: [PATCH] Re-enable skipped test Fixes #12927 PiperOrigin-RevId: 484304818 --- tests/pjit_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9301c108d..84e573340 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1815,10 +1815,10 @@ class ArrayPjitTest(jtu.JaxTestCase): @jax_array(True) def test_unspecified_out_axis_resources(self): - # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend. - if (xla_bridge.get_backend().runtime_type == 'stream_executor' and - jtu.device_under_test() == 'tpu'): - self.skipTest('Does not work with the cloud TPU SE runtime.') + if xc._version < 102: # Remove when jaxlib 0.3.23 is released + if (xla_bridge.get_backend().runtime_type == 'stream_executor' and + jtu.device_under_test() == 'tpu'): + self.skipTest('Does not work with the cloud TPU SE runtime.') def _checks(out, input_data): self.assertIsInstance(out, array.ArrayImpl) @@ -1852,10 +1852,10 @@ class ArrayPjitTest(jtu.JaxTestCase): @jax_array(True) def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, s2_shape, s3_shape, s4_shape): - # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend. - if (xla_bridge.get_backend().runtime_type == 'stream_executor' and - jtu.device_under_test() == 'tpu'): - self.skipTest('Does not work with the cloud TPU SE runtime.') + if xc._version < 102: # Remove when jaxlib 0.3.23 is released + if (xla_bridge.get_backend().runtime_type == 'stream_executor' and + jtu.device_under_test() == 'tpu'): + self.skipTest('Does not work with the cloud TPU SE runtime.') global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2)