Skip flaky memories tests on GPU backend.

PiperOrigin-RevId: 658177202
This commit is contained in:
Kanglan Tang 2024-07-31 16:12:14 -07:00 committed by jax authors
parent b677a712ab
commit a7e071ec42

View File

@ -493,6 +493,8 @@ class DevicePutTest(jtu.JaxTestCase):
)
def test_parameter_and_output_streaming_with_scalar(self):
if jtu.test_device_matches(["gpu"]):
self.skipTest("This test is flaky on GPU backend.")
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
self.skipTest("This test requires an xla_version >= 2.")
@ -560,6 +562,8 @@ class DevicePutTest(jtu.JaxTestCase):
self.assertEqual(out_hbm.sharding, out_s)
def test_output_streaming(self):
if jtu.test_device_matches(["gpu"]):
self.skipTest("This test is flaky on GPU backend.")
mesh = jtu.create_global_mesh((1, 1), ("x", "y"))
np_inp = np.arange(16.0).reshape(8, 2)
s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")