#sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
parent 59ae2af699
commit fb32841b1b
@@ -588,8 +588,14 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
     if (isinstance(device, (Sharding, TransferToMemoryKind)) and
         device.memory_kind is not None):
       if isinstance(device, Sharding):
-        x = mlir.wrap_with_sharding_op(
-            ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
+        if config.use_shardy_partitioner.value:
+          x = mlir.wrap_with_sharding_op(
+              ctx, x, out_aval,
+              device._to_sdy_sharding(aval.ndim))
+        else:
+          x = mlir.wrap_with_sharding_op(
+              ctx, x, out_aval,
+              device._to_xla_hlo_sharding(aval.ndim).to_proto())
       x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
       return x
     return x
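For context, a minimal sketch of what this branch enables (the mesh construction is illustrative, and pinned-host memory requires platform support): with the Shardy partitioner turned on, a device_put onto a sharding that carries an explicit memory_kind is now lowered with the sharding's _to_sdy_sharding annotation instead of a GSPMD HloSharding proto.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Enable the Shardy partitioner so the new lowering branch is taken.
jax.config.update('jax_use_shardy_partitioner', True)

mesh = jax.make_mesh((jax.device_count(),), ('x',))
# A sharding pinned to host memory; its memory_kind triggers the path above.
s_host = NamedSharding(mesh, P('x'), memory_kind='pinned_host')

x = jnp.arange(8.0)
y = jax.device_put(x, s_host)  # lowered with an sdy sharding + memory kind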
@@ -225,6 +225,10 @@ jax_multiplatform_test(
         "tpu_v4_2x2",
         "tpu_v5p_2x2",
         "tpu_v5e_4x2",
+        "cpu_shardy",
+        "gpu_2gpu_shardy",
+        "tpu_v3_2x2_shardy",
+        "tpu_v5e_4x2_shardy",
     ],
     shard_count = {
         "tpu": 5,
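The four new *_shardy configs rerun this test target with the Shardy partitioner enabled. As a hedged sketch of the equivalent local run (the env-var spelling follows JAX's usual JAX_<FLAG_NAME> convention; the exact BUILD wiring is not shown in this hunk):

# e.g.  JAX_USE_SHARDY_PARTITIONER=1 pytest tests/memories_test.py
# or, programmatically:
import jax
jax.config.update('jax_use_shardy_partitioner', True)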
@@ -442,7 +442,7 @@ class DevicePutTest(jtu.JaxTestCase):
     if jtu.test_device_matches(["gpu"]):
       self.skipTest("This test does not work on GPU backend.")
     _, s_host, np_inp, inp_host = _create_inputs(
-        (0,), P("x"), mem_kind="pinned_host")
+        (0,), P(), mem_kind="pinned_host")
     s_dev = s_host.with_memory_kind('device')
 
     @functools.partial(jax.jit, out_shardings=s_host)
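The spec change from P("x") to P() makes the pinned-host input replicated instead of partitioned along 'x' (the array created here has a zero-sized leading dimension, so there is nothing meaningful to shard). A hedged sketch of the resulting setup, with jax.make_mesh standing in for the test's _create_inputs helper:

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ('x',))
# Replicated pinned-host sharding, matching the updated spec P().
s_host = NamedSharding(mesh, P(), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')  # same layout, device memory

inp_host = jax.device_put(jnp.zeros((0,)), s_host)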
@@ -730,6 +730,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
   def test_compute_no_inputs_host_replicated(self):
     if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3:
       self.skipTest("This test requires an xla_version >= 3.")
+    if config.use_shardy_partitioner.value:
+      self.skipTest("XLA failure due to b/370786664 and b/366411266. "
+                    "Enable when fixed.")
     mesh = jtu.create_mesh((4,), ('data'))
 
     tpu_sharding = NamedSharding(mesh, P('data'))
@@ -737,8 +740,8 @@ class ComputeOffload(jtu.BufferDonationTestCase):
 
     @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
     def init():
-      tpu_array = jax.random.normal(jax.random.key(42), (16,16))
-      cpu_array = jax.random.normal(jax.random.key(42), (16,16))
+      tpu_array = jax.random.normal(jax.random.key(42), (16, 16))
+      cpu_array = jax.random.normal(jax.random.key(42), (16, 16))
       return tpu_array, cpu_array
 
     tpu_array, cpu_array = init()
@@ -1245,6 +1248,8 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertArraysEqual(out2, np_inp @ np_inp.T)
 
   def test_jit_compilation_cache_hit(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest("Shardy doesn't support GSPMDSharding")
     mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
     inp2 = jax.device_put(
         np_inp, GSPMDSharding(tuple(mesh.devices.flat),
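The new skip reflects that GSPMDSharding wraps a raw XLA HLO sharding, which the Shardy partitioner does not support (per the skip message). A hedged sketch of how such a sharding is built, mirroring the call in the hunk (GSPMDSharding is an internal API):

import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src.sharding_impls import GSPMDSharding  # internal, as in the test

mesh = jax.make_mesh((jax.device_count(),), ('x',))
s = NamedSharding(mesh, P('x'))
# Same placement as s, but expressed as a raw HloSharding over explicit devices.
gs = GSPMDSharding(tuple(mesh.devices.flat), s._to_xla_hlo_sharding(1))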
@@ -1396,6 +1401,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertArraysAllClose(out, expected_out, rtol=1e-3)
 
   def test_mem_kind_donation_pinned_host(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest("XLA failure due to b/370786664 and b/366411266. "
+                    "Enable when fixed.")
     mesh = jtu.create_mesh((2,), "x")
     s = NamedSharding(mesh, P(), memory_kind='pinned_host')
     s_dev = s.with_memory_kind('device')
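For reference, a hedged sketch of the donate-into-pinned-host pattern the test's name points at (the test body continues beyond this hunk, so this is illustrative rather than the test itself):

import functools
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))  # assumes at least two devices
s = NamedSharding(mesh, P(), memory_kind='pinned_host')

# Donating a pinned-host input lets XLA reuse its buffer for the output.
@functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
def double(x):
  return x * 2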