Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Enable the passed tests for memories and layout
This commit is contained in:
parent 85c30d2a86
commit c774d7b29e
@@ -41,11 +41,13 @@ def tearDownModule():
 class LayoutTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(['tpu']):
-      self.skipTest("Layouts do not work on CPU and GPU backends yet.")
+    if not jtu.test_device_matches(['tpu', 'gpu']):
+      self.skipTest("Layouts do not work on CPU backend yet.")
     super().setUp()
 
   def test_auto_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape1 = (128, 128)
     shape2 = (128, 128)
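The hunk above captures the whole pattern of this commit: the class-level setUp gate is relaxed from TPU-only to TPU-and-GPU, and tests that still fail on GPU get an explicit per-test skip. A minimal, hypothetical test class (not part of the diff) showing that skip pattern with jtu.test_device_matches:

from absl.testing import absltest
from jax._src import test_util as jtu


class ExampleLayoutTest(jtu.JaxTestCase):

  def setUp(self):
    # Class-wide gate: run only where layouts are supported (TPU or GPU here).
    if not jtu.test_device_matches(['tpu', 'gpu']):
      self.skipTest("Layouts do not work on CPU backend yet.")
    super().setUp()

  def test_gpu_not_ready_yet(self):
    # Per-test gate for a case that still fails on the GPU backend.
    if jtu.test_device_matches(["gpu"]):
      self.skipTest("This test does not work on GPU backend.")
    # ... TPU-only assertions would go here ...


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())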
@@ -111,6 +113,8 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)
 
   def test_default_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (4, 4, 2)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -150,6 +154,8 @@ class LayoutTest(jtu.JaxTestCase):
                        out_shardings=DLL.AUTO).lower(sds).compile()
 
   def test_in_layouts_out_layouts(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (8, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
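The context line ending in out_shardings=DLL.AUTO).lower(sds).compile() is part of the AOT path these tests exercise: lower against an abstract shape and let the compiler choose layouts. A rough sketch of that flow, assuming the DLL alias for jax.experimental.layout.DeviceLocalLayout used by the test file and that jax.jit accepts DLL.AUTO as that line suggests:

import numpy as np
import jax
from jax.experimental import layout

DLL = layout.DeviceLocalLayout  # assumed alias, mirroring the test file

# Lower against an abstract shape/dtype; the compiler is free to pick the
# output layout because out_shardings is DLL.AUTO.
sds = jax.ShapeDtypeStruct((128, 128), np.float32)
compiled = jax.jit(lambda x: x.T, out_shardings=DLL.AUTO).lower(sds).compile()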
@@ -174,6 +180,8 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
 
   def test_sharding_and_layouts(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (4, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -236,6 +244,8 @@ class LayoutTest(jtu.JaxTestCase):
     compiled(*arrs)
 
   def test_aot_layout_mismatch(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (256, 4, 2)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
@@ -405,6 +415,8 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, inp.T)
 
   def test_device_put_user_concrete_layout(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     if xla_extension_version < 274:
       self.skipTest('Requires xla_extension_version >= 274')
 
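For context on test_device_put_user_concrete_layout: it covers device_put with a user-chosen concrete layout, which the hunk above gates on xla_extension_version >= 274. A heavily hedged sketch of that idea; the Layout wrapper and the major_to_minor constructor argument are assumptions about jax.experimental.layout at this commit, not taken from the diff:

import numpy as np
import jax
from jax.sharding import SingleDeviceSharding
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL

x = np.arange(16).reshape(8, 2)
s = SingleDeviceSharding(jax.devices()[0])
# Hypothetical concrete layout request; argument name/value are assumed.
custom = Layout(DLL(major_to_minor=(0, 1)), s)
y = jax.device_put(x, custom)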
@@ -190,8 +190,8 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
 class DevicePutTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Memories do not work on CPU and GPU backends yet.")
+    if not jtu.test_device_matches(["tpu", "gpu"]):
+      self.skipTest("Memories do not work on CPU backend yet.")
     super().setUp()
 
   def _check_device_put_addressable_shards(
@@ -215,6 +215,8 @@ class DevicePutTest(jtu.JaxTestCase):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_host_to_hbm(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
     np_inp = np.arange(16).reshape(8, 2)
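For orientation (an illustrative sketch, not the test body): the host-to-HBM transfer exercised above amounts to placing an array under a sharding with a host memory_kind and then device_put-ting it to the same sharding with memory_kind "device".

import jax
import numpy as np
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_global_mesh((2, 2), ("x", "y"))  # needs >= 4 devices
np_inp = np.arange(16).reshape(8, 2)

s_host = NamedSharding(mesh, P("y"), memory_kind="pinned_host")
s_hbm = s_host.with_memory_kind("device")

inp_host = jax.device_put(np_inp, s_host)  # placed in pinned host memory
out_hbm = jax.device_put(inp_host, s_hbm)  # copied into device (HBM) memory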
@@ -229,6 +231,8 @@ class DevicePutTest(jtu.JaxTestCase):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_hbm_to_host(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
     inp = jnp.arange(16).reshape(8, 2)
@@ -246,6 +250,8 @@ class DevicePutTest(jtu.JaxTestCase):
   def test_device_put_different_device_and_memory_host_to_hbm(
       self, host_memory_kind: str
   ):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if jax.device_count() < 3:
       raise unittest.SkipTest("Test requires >=3 devices")
 
@@ -266,6 +272,8 @@ class DevicePutTest(jtu.JaxTestCase):
   def test_device_put_different_device_and_memory_hbm_to_host(
       self, host_memory_kind: str
   ):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if jax.device_count() < 3:
       raise unittest.SkipTest("Test requires >=3 devices")
 
@@ -285,6 +293,8 @@ class DevicePutTest(jtu.JaxTestCase):
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_on_different_device_with_the_same_memory_kind(
       self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     if len(jax.devices()) < 2:
       raise unittest.SkipTest("Test requires >=2 devices.")
 
@@ -331,6 +341,8 @@ class DevicePutTest(jtu.JaxTestCase):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_numpy_array(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     np_inp = np.arange(16).reshape(8, 2)
     s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
@@ -345,6 +357,8 @@ class DevicePutTest(jtu.JaxTestCase):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_numpy_scalar(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     np_inp = np.float32(8)
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -358,6 +372,8 @@ class DevicePutTest(jtu.JaxTestCase):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_python_scalar(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     py_scalar = float(8)
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -372,6 +388,8 @@ class DevicePutTest(jtu.JaxTestCase):
 
   @parameterized.parameters("unpinned_host", "pinned_host")
   def test_device_put_python_int(self, host_memory_kind: str):
+    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
+      self.skipTest("unpinned_host does not work on GPU backend.")
     py_inp = 8
     s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
     s_host = s_hbm.with_memory_kind(host_memory_kind)
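The three scalar tests above (numpy scalar, Python float, Python int) follow the same shape. A condensed, illustrative sketch of the underlying API, using the shardings and values visible in the diff:

import jax
from jax.sharding import SingleDeviceSharding

s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
s_host = s_hbm.with_memory_kind("pinned_host")

on_device = jax.device_put(8, s_hbm)         # Python int into device memory
on_host = jax.device_put(float(8), s_host)   # Python float into pinned host memory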
@@ -399,6 +417,8 @@ class DevicePutTest(jtu.JaxTestCase):
         out, np_inp * np_inp, s_dev, "device")
 
   def test_parameter_streaming(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     _, s_host, np_inp, inp_host = _create_inputs(
         (8, 2), P("x", "y"), mem_kind="pinned_host")
     s_dev = s_host.with_memory_kind('device')
@@ -422,6 +442,8 @@ class DevicePutTest(jtu.JaxTestCase):
         out2, np_inp * np_inp * 2, s_host, 'pinned_host')
 
   def test_parameter_streaming_with_scalar_and_constant(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
     scalar_inp = 1
     s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")
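A rough sketch of the "parameter streaming" pattern the two tests above target: a parameter parked in pinned host memory is moved to device memory inside the jitted computation via jax.device_put. The function, shapes, and the doubling here are illustrative and assume host-memory support on the backend; they are not the test bodies.

import functools
import jax
import numpy as np
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_global_mesh((2, 2), ("x", "y"))  # needs >= 4 devices
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
s_dev = s_host.with_memory_kind("device")

np_inp = np.arange(16.0).reshape(8, 2)
w_host = jax.device_put(np_inp, s_host)  # parameter parked in host memory

@functools.partial(jax.jit, in_shardings=s_host, out_shardings=s_dev)
def f(w):
  w_dev = jax.device_put(w, s_dev)  # stream host -> device inside the jit
  return w_dev * 2

out = f(w_host)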
@@ -569,6 +591,8 @@ class DevicePutTest(jtu.JaxTestCase):
         out_host, np_inp * 2, s_host, 'pinned_host')
 
   def test_output_streaming_inside_scan(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("This test does not work on GPU backend.")
     if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
       self.skipTest("This test requires an xla_version >= 2.")
     mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))