Skip remote_transfer_test because Array does not have the xla_shape method since its deprecated.

PiperOrigin-RevId: 474913967
This commit is contained in:
Yash Katariya 2022-09-16 15:25:43 -07:00 committed by jax authors
parent dce93e45bb
commit e6bdb00d31

View File

@ -31,7 +31,10 @@ class RemoteTransferTest(jtu.JaxTestCase):
@jtu.skip_on_devices("gpu")
def test_remote_transfer(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Remote transfer requires at lest 2 devices")
raise unittest.SkipTest("Remote transfer requires at least 2 devices")
if config.jax_array:
raise unittest.SkipTest("Array does not have xla_shape method since "
"it is deprecated.")
dev_a, dev_b = jax.local_devices()[:2]
if "libtpu" in jax.local_devices()[0].client.platform_version:
raise unittest.SkipTest("Test does not yet work on cloud TPU")