From fb32841b1becef235b160254fd7b908dddceadd8 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 11 Oct 2024 09:43:46 -0700 Subject: [PATCH] #sdy add JAX Shardy support for memories. PiperOrigin-RevId: 684867097 --- jax/_src/dispatch.py | 10 ++++++++-- tests/BUILD | 4 ++++ tests/memories_test.py | 14 +++++++++++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 7874a79cf..97680bd0f 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -588,8 +588,14 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): if (isinstance(device, (Sharding, TransferToMemoryKind)) and device.memory_kind is not None): if isinstance(device, Sharding): - x = mlir.wrap_with_sharding_op( - ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto()) + if config.use_shardy_partitioner.value: + x = mlir.wrap_with_sharding_op( + ctx, x, out_aval, + device._to_sdy_sharding(aval.ndim)) + else: + x = mlir.wrap_with_sharding_op( + ctx, x, out_aval, + device._to_xla_hlo_sharding(aval.ndim).to_proto()) x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval) return x return x diff --git a/tests/BUILD b/tests/BUILD index 1bf103875..615437ce4 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -225,6 +225,10 @@ jax_multiplatform_test( "tpu_v4_2x2", "tpu_v5p_2x2", "tpu_v5e_4x2", + "cpu_shardy", + "gpu_2gpu_shardy", + "tpu_v3_2x2_shardy", + "tpu_v5e_4x2_shardy", ], shard_count = { "tpu": 5, diff --git a/tests/memories_test.py b/tests/memories_test.py index 4fc91cf17..781172f88 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -442,7 +442,7 @@ class DevicePutTest(jtu.JaxTestCase): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") _, s_host, np_inp, inp_host = _create_inputs( - (0,), P("x"), mem_kind="pinned_host") + (0,), P(), mem_kind="pinned_host") s_dev = s_host.with_memory_kind('device') @functools.partial(jax.jit, out_shardings=s_host) @@ -730,6 +730,9 @@ class ComputeOffload(jtu.BufferDonationTestCase): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") + if config.use_shardy_partitioner.value: + self.skipTest("XLA failure due to b/370786664 and b/366411266. " + "Enable when fixed.") mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) @@ -737,8 +740,8 @@ class ComputeOffload(jtu.BufferDonationTestCase): @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding)) def init(): - tpu_array = jax.random.normal(jax.random.key(42), (16,16)) - cpu_array = jax.random.normal(jax.random.key(42), (16,16)) + tpu_array = jax.random.normal(jax.random.key(42), (16, 16)) + cpu_array = jax.random.normal(jax.random.key(42), (16, 16)) return tpu_array, cpu_array tpu_array, cpu_array = init() @@ -1245,6 +1248,8 @@ class ComputeOffload(jtu.BufferDonationTestCase): self.assertArraysEqual(out2, np_inp @ np_inp.T) def test_jit_compilation_cache_hit(self): + if config.use_shardy_partitioner.value: + self.skipTest("Shardy doesn't support GSPMDSharding") mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) inp2 = jax.device_put( np_inp, GSPMDSharding(tuple(mesh.devices.flat), @@ -1396,6 +1401,9 @@ class ComputeOffload(jtu.BufferDonationTestCase): self.assertArraysAllClose(out, expected_out, rtol=1e-3) def test_mem_kind_donation_pinned_host(self): + if config.use_shardy_partitioner.value: + self.skipTest("XLA failure due to b/370786664 and b/366411266. " + "Enable when fixed.") mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device')