mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Skip remote_transfer_test on cloud TPU.
The necessary API isn't yet implemented for Cloud TPU. PiperOrigin-RevId: 442058546
This commit is contained in:
parent
a4b8a443be
commit
7ffdac0746
@ -36,6 +36,8 @@ class RemoteTransferTest(jtu.JaxTestCase):
|
||||
if not hasattr(dev_a.client, "make_cross_host_receive_buffers"):
|
||||
# TODO(jheek) remove this once a new version of JAX lib is released
|
||||
raise unittest.SkipTest("jax-lib doesn't include cross host APIs")
|
||||
if "libtpu" in jax.local_devices()[0].client.platform_version:
|
||||
raise unittest.SkipTest("Test does not yet work on cloud TPU")
|
||||
send_buf = jax.device_put(np.ones((32,)), dev_a)
|
||||
shapes = [send_buf.xla_shape()]
|
||||
(tag, recv_buf), = dev_b.client.make_cross_host_receive_buffers(
|
||||
|
Loading…
x
Reference in New Issue
Block a user