From a8df383ccf8fe5d8b56a278b87d772b0ac34ed02 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Thu, 30 Jan 2025 12:04:28 -0800
Subject: [PATCH] Fix `lax.ragged_all_to_all` degenerate case

In the singleton-group case, unlike regular all_to_all, the ragged op
becomes a generic equivalent of DynamicUpdateSlice, except that the update
size is not statically known. This operation can't be expressed with
standard HLO instructions -- the backend will handle this case separately.

Also added a small improvement to the error messages.

PiperOrigin-RevId: 721473063
---
 jax/_src/lax/parallel.py        |   2 -
 tests/BUILD                     |   4 ++
 tests/ragged_collective_test.py | 100 ++++++++++++++++++++++++++++++--
 3 files changed, 98 insertions(+), 8 deletions(-)

diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 7128d48e8..d8e7431e9 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -1159,8 +1159,6 @@ def _ragged_all_to_all_lowering(
   split_count = len(replica_groups[0])
   if not all(split_count == len(g) for g in replica_groups):
     raise ValueError('Replica groups must be equally sized')
-  if len(replica_groups[0]) == 1:
-    return [operand]
 
   ragged_all_to_all_attrs = {
       "replica_groups": _replica_groups_hlo(replica_groups)
diff --git a/tests/BUILD b/tests/BUILD
index 6b90309bd..6aca3372c 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -1390,6 +1390,10 @@ jax_multiplatform_test(
     enable_configs = [
         "gpu_p100x2_shardy",
     ],
+    shard_count = {
+        "gpu": 10,
+        "tpu": 10,
+    },
     tags = [
         "multiaccelerator",
     ],
diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py
index 2a184afc1..2f19764f1 100644
--- a/tests/ragged_collective_test.py
+++ b/tests/ragged_collective_test.py
@@ -45,10 +45,10 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
   )
   def test_ragged_all_to_all(self, axis_name, mesh_axes):
     device_type = jax.devices()[0].platform
-    if device_type == 'tpu' and jtu.get_tpu_version() == 3:
+    if device_type == 'tpu' and jtu.get_tpu_version() < 4:
       raise unittest.SkipTest(
-          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
-          ' TPU v3'
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
+          f' v{jtu.get_tpu_version()}'
       )
     mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
     operand = jax.device_put(
@@ -132,10 +132,10 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
   )
   def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes):
     device_type = jax.devices()[0].platform
-    if device_type == 'tpu' and jtu.get_tpu_version() == 3:
+    if device_type == 'tpu' and jtu.get_tpu_version() < 4:
       raise unittest.SkipTest(
-          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
-          ' TPU v3'
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
+          f' v{jtu.get_tpu_version()}'
       )
     mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
     operand = jax.device_put(
@@ -219,6 +219,94 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
             [10, 30, 0, 0],
             [20, 20, 40, 0]], dtype=jnp.int32)
     )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
+      ),
+  )
+  def test_ragged_all_to_all_degenerate_groups(self, axis_name, mesh_axes):
+    device_type = jax.devices()[0].platform
+    if device_type == 'tpu':
+      raise unittest.SkipTest(
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` with singleton group is'
+          ' not supported by TPU'
+      )
+    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
+    operand = jax.device_put(
+        jnp.array([[1, 0, 0, 0], [2, 3, 4, 0]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output = jax.device_put(
+        jnp.zeros((2, 4), dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    input_offsets = jax.device_put(
+        jnp.array([[0], [0]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    send_sizes = jax.device_put(
+        jnp.array([[1], [3]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output_offsets = jax.device_put(
+        jnp.array([[2], [1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    recv_sizes = jax.device_put(
+        jnp.array([[1], [3]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    axis_index_groups = ((0,), (1,))
+
+    @jax.jit
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+        ),
+        out_specs=P(axis_name),
+        check_rep=False,
+    )
+    def fwd(
+        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
+    ):
+      operand = operand.reshape(operand.shape[1:])
+      output = output.reshape(output.shape[1:])
+      input_offsets = input_offsets.reshape(input_offsets.shape[1:])
+      send_sizes = send_sizes.reshape(send_sizes.shape[1:])
+      output_offsets = output_offsets.reshape(output_offsets.shape[1:])
+      recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
+      return lax.ragged_all_to_all(
+          operand,
+          output,
+          input_offsets,
+          send_sizes,
+          output_offsets,
+          recv_sizes,
+          axis_name=axis_name,
+          axis_index_groups=axis_index_groups,
+      )
+
+    mlir_module = fwd.lower(
+        operand, output, input_offsets, send_sizes, output_offsets,
+        recv_sizes).as_text()
+    self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
+    self.assertIn("replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>",
+                  mlir_module)
+
+    c = fwd(
+        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
+    ).reshape((2, 4))
+    self.assertAllClose(
+        c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32)
+    )
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
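
Note on the degenerate case described in the commit message: with singleton
axis_index_groups, every shard is the only member of its group, so
ragged_all_to_all reduces to copying each of the shard's own chunks
(send_sizes[i] elements starting at input_offsets[i]) into its own output
buffer at output_offsets[i]. Because the chunk sizes are runtime values, this
is a dynamic-update-slice whose update size is itself dynamic, which standard
HLO cannot express; hence the lowering now emits the ragged_all_to_all custom
call instead of returning the operand unchanged. The snippet below is a plain
NumPy sketch of that semantics, checked against the per-shard values used in
the new test; the helper name `ragged_self_copy` is made up for illustration
and is not part of JAX's API.

    # Reference semantics of ragged_all_to_all when each group holds a single
    # device: every chunk is copied from `operand` into this shard's own
    # `output` buffer.
    import numpy as np

    def ragged_self_copy(operand, output, input_offsets, send_sizes,
                         output_offsets):
      out = np.array(output, copy=True)
      for i_off, size, o_off in zip(input_offsets, send_sizes, output_offsets):
        # A dynamic-update-slice whose update *size* is data-dependent: `size`
        # is a runtime value, so this copy cannot be lowered to a plain
        # stablehlo.dynamic_update_slice.
        out[o_off:o_off + size] = operand[i_off:i_off + size]
      return out

    # Shard 0 and shard 1 from the test above, each alone in its group.
    shard0 = ragged_self_copy(np.array([1, 0, 0, 0]), np.zeros(4, np.int32),
                              input_offsets=[0], send_sizes=[1],
                              output_offsets=[2])
    shard1 = ragged_self_copy(np.array([2, 3, 4, 0]), np.zeros(4, np.int32),
                              input_offsets=[0], send_sizes=[3],
                              output_offsets=[1])
    assert (shard0 == np.array([0, 0, 1, 0])).all()
    assert (shard1 == np.array([0, 2, 3, 4])).all()

This mirrors the expected result in the new test,
jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]]), one row per shard.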