mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
parent
9f80402845
commit
4b1fd63263
@ -1815,7 +1815,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
@jax_array(True)
|
||||
def test_unspecified_out_axis_resources(self):
|
||||
# TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
|
||||
if xc._version < 102: # Remove when jaxlib 0.3.23 is released
|
||||
if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
|
||||
jtu.device_under_test() == 'tpu'):
|
||||
self.skipTest('Does not work with the cloud TPU SE runtime.')
|
||||
@ -1852,7 +1852,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
@jax_array(True)
|
||||
def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
|
||||
s2_shape, s3_shape, s4_shape):
|
||||
# TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
|
||||
if xc._version < 102: # Remove when jaxlib 0.3.23 is released
|
||||
if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
|
||||
jtu.device_under_test() == 'tpu'):
|
||||
self.skipTest('Does not work with the cloud TPU SE runtime.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user