mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Skip remote_transfer_test because Array does not have the xla_shape method since its deprecated.
PiperOrigin-RevId: 474913967
This commit is contained in:
parent
dce93e45bb
commit
e6bdb00d31
@ -31,7 +31,10 @@ class RemoteTransferTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def test_remote_transfer(self):
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Remote transfer requires at lest 2 devices")
|
||||
raise unittest.SkipTest("Remote transfer requires at least 2 devices")
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest("Array does not have xla_shape method since "
|
||||
"it is deprecated.")
|
||||
dev_a, dev_b = jax.local_devices()[:2]
|
||||
if "libtpu" in jax.local_devices()[0].client.platform_version:
|
||||
raise unittest.SkipTest("Test does not yet work on cloud TPU")
|
||||
|
Loading…
x
Reference in New Issue
Block a user