Skip remote_transfer_test on cloud TPU.

The necessary API isn't yet implemented for Cloud TPU.

PiperOrigin-RevId: 442058546
This commit is contained in:
Peter Hawkins 2022-04-15 11:24:13 -07:00 committed by jax authors
parent a4b8a443be
commit 7ffdac0746

View File

@ -36,6 +36,8 @@ class RemoteTransferTest(jtu.JaxTestCase):
if not hasattr(dev_a.client, "make_cross_host_receive_buffers"):
# TODO(jheek) remove this once a new version of JAX lib is released
raise unittest.SkipTest("jax-lib doesn't include cross host APIs")
if "libtpu" in jax.local_devices()[0].client.platform_version:
raise unittest.SkipTest("Test does not yet work on cloud TPU")
send_buf = jax.device_put(np.ones((32,)), dev_a)
shapes = [send_buf.xla_shape()]
(tag, recv_buf), = dev_b.client.make_cross_host_receive_buffers(