mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Skip flaky memories tests on GPU backend.
PiperOrigin-RevId: 658177202
This commit is contained in:
parent
b677a712ab
commit
a7e071ec42
@ -493,6 +493,8 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_parameter_and_output_streaming_with_scalar(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test is flaky on GPU backend.")
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
|
||||
self.skipTest("This test requires an xla_version >= 2.")
|
||||
|
||||
@ -560,6 +562,8 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out_hbm.sharding, out_s)
|
||||
|
||||
def test_output_streaming(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test is flaky on GPU backend.")
|
||||
mesh = jtu.create_global_mesh((1, 1), ("x", "y"))
|
||||
np_inp = np.arange(16.0).reshape(8, 2)
|
||||
s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")
|
||||
|
Loading…
x
Reference in New Issue
Block a user