#sdy add JAX Shardy support for memories.

PiperOrigin-RevId: 684867097
Author: Bart Chrzaszcz (2024-10-11 09:43:46 -07:00), committed by jax authors
parent 59ae2af699
commit fb32841b1b
3 changed files with 23 additions and 5 deletions


@@ -588,8 +588,14 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
     if (isinstance(device, (Sharding, TransferToMemoryKind)) and
         device.memory_kind is not None):
       if isinstance(device, Sharding):
-        x = mlir.wrap_with_sharding_op(
-            ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
+        if config.use_shardy_partitioner.value:
+          x = mlir.wrap_with_sharding_op(
+              ctx, x, out_aval,
+              device._to_sdy_sharding(aval.ndim))
+        else:
+          x = mlir.wrap_with_sharding_op(
+              ctx, x, out_aval,
+              device._to_xla_hlo_sharding(aval.ndim).to_proto())
       x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
       return x
     return x
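
For orientation, a minimal sketch of user code that reaches this new branch (the mesh construction and shapes are illustrative, not from this commit): a device_put with a memory-kind sharding staged inside jit goes through this lowering, and the flag selects the sdy path instead of the HloSharding proto path.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_use_shardy_partitioner', True)  # take the _to_sdy_sharding branch
mesh = Mesh(jax.devices()[:1], ('x',))
s_host = NamedSharding(mesh, P(), memory_kind='pinned_host')

@jax.jit
def f(x):
  # A staged device_put with a non-default memory kind lowers through
  # _tpu_gpu_device_put_lowering above.
  return jax.device_put(x, s_host)

f(jnp.arange(8.0))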


@@ -225,6 +225,10 @@ jax_multiplatform_test(
         "tpu_v4_2x2",
         "tpu_v5p_2x2",
         "tpu_v5e_4x2",
+        "cpu_shardy",
+        "gpu_2gpu_shardy",
+        "tpu_v3_2x2_shardy",
+        "tpu_v5e_4x2_shardy",
     ],
     shard_count = {
         "tpu": 5,


@@ -442,7 +442,7 @@ class DevicePutTest(jtu.JaxTestCase):
     if jtu.test_device_matches(["gpu"]):
       self.skipTest("This test does not work on GPU backend.")
     _, s_host, np_inp, inp_host = _create_inputs(
-        (0,), P("x"), mem_kind="pinned_host")
+        (0,), P(), mem_kind="pinned_host")
     s_dev = s_host.with_memory_kind('device')

     @functools.partial(jax.jit, out_shardings=s_host)
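
For reference, a sketch of the shardings this test works with after the change (the _create_inputs internals are assumed, not shown in this diff):

mesh = jtu.create_mesh((2,), 'x')                             # illustrative mesh
s_host = NamedSharding(mesh, P(), memory_kind='pinned_host')  # fully replicated, host memory
s_dev = s_host.with_memory_kind('device')                     # same layout, device memory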
@@ -730,6 +730,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
   def test_compute_no_inputs_host_replicated(self):
     if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3:
       self.skipTest("This test requires an xla_version >= 3.")
+    if config.use_shardy_partitioner.value:
+      self.skipTest("XLA failure due to b/370786664 and b/366411266. "
+                    "Enable when fixed.")
     mesh = jtu.create_mesh((4,), ('data'))
     tpu_sharding = NamedSharding(mesh, P('data'))
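
The skip guards added throughout this file read the same global flag the lowering checks; a minimal sketch of the pattern (import path as used in JAX internals):

from jax._src import config

if config.use_shardy_partitioner.value:  # True when Shardy is enabled
  ...  # skip tests Shardy cannot handle yet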
@@ -737,8 +740,8 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
     def init():
-      tpu_array = jax.random.normal(jax.random.key(42), (16,16))
-      cpu_array = jax.random.normal(jax.random.key(42), (16,16))
+      tpu_array = jax.random.normal(jax.random.key(42), (16, 16))
+      cpu_array = jax.random.normal(jax.random.key(42), (16, 16))
       return tpu_array, cpu_array
     tpu_array, cpu_array = init()
@@ -1245,6 +1248,8 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertArraysEqual(out2, np_inp @ np_inp.T)

   def test_jit_compilation_cache_hit(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest("Shardy doesn't support GSPMDSharding")
     mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
     inp2 = jax.device_put(
         np_inp, GSPMDSharding(tuple(mesh.devices.flat),
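
The skip is needed because GSPMDSharding carries a raw XLA HloSharding, which has no sdy.sharding counterpart under Shardy. A sketch of the two equivalent shardings the cache-hit test compares (the ndim value is assumed):

from jax.sharding import GSPMDSharding, NamedSharding, PartitionSpec as P

named = NamedSharding(mesh, P('x', 'y'))              # has an sdy form
gspmd = GSPMDSharding(tuple(mesh.devices.flat),
                      named._to_xla_hlo_sharding(2))  # opaque HloSharding; GSPMD only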
@@ -1396,6 +1401,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertArraysAllClose(out, expected_out, rtol=1e-3)

   def test_mem_kind_donation_pinned_host(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest("XLA failure due to b/370786664 and b/366411266. "
+                    "Enable when fixed.")
     mesh = jtu.create_mesh((2,), "x")
     s = NamedSharding(mesh, P(), memory_kind='pinned_host')
     s_dev = s.with_memory_kind('device')
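
A hedged sketch of the donation pattern this test exercises (the function body and shapes are guesses, not the test's actual code): a jitted function whose pinned_host argument is donated while its output lands in device memory.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, out_shardings=s_dev, donate_argnums=0)
def f(x):
  return x * 2  # the donated pinned_host input buffer may be reused

out = f(jax.device_put(jnp.arange(16.0), s))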