Add support for axis_name and axis_index_groups to lax.ragged_all_to_all

PiperOrigin-RevId: 720738861
Gunhyun Park authored 2025-01-28 16:01:30 -08:00, committed by jax authors
parent 9cbff64251
commit 809e1133c8
4 changed files with 324 additions and 77 deletions


@@ -457,7 +457,9 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
return tree_util.tree_map(bind, x)
def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
def ragged_all_to_all(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
axis_name, axis_index_groups = None):
"""Ragged version of :func:`all_to_all`.
For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
@@ -528,14 +530,26 @@ def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets
send_sizes: array of ragged send sizes.
output_offsets: array of ragged offsets in the target replica output.
recv_sizes: array of ragged output receive sizes.
axis_name: hashable Python object used to name a pmapped axis (see the
:func:`jax.pmap` documentation for more details).
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
first two and last two replicas). Groups must cover all axis indices
exactly once, and all groups must be the same size. Otherwise, the
behavior is undefined.
Returns:
array with shape equal to ``output``.
"""
return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
output_offsets, recv_sizes)
ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
output_offsets, recv_sizes,
axis_name=axis_name,
axis_index_groups=axis_index_groups)
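
As an illustration of the new keyword argument (not part of this commit): a minimal sketch of calling lax.ragged_all_to_all with axis_name from inside shard_map, assuming at least two devices on a backend that implements the ragged_all_to_all custom call (GPU or TPU, per the test target added below). The values mirror the two-device case in the new ragged_collective_test.py.

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], ('x',))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x'),) * 6, out_specs=P('x'),
         check_rep=False)
def exchange(operand, output, input_offsets, send_sizes, output_offsets,
             recv_sizes):
  # Per shard, entry j of each offset/size array describes the chunk sent to
  # (or received from) member j of the group along axis 'x'.
  return lax.ragged_all_to_all(operand, output, input_offsets, send_sizes,
                               output_offsets, recv_sizes, axis_name='x')

operand = jnp.array([1, 2, 2, 3, 4, 0], dtype=jnp.int32)   # shards: [1, 2, 2] and [3, 4, 0]
output = jnp.zeros(8, dtype=jnp.int32)                     # four output slots per shard
input_offsets = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
send_sizes = jnp.array([1, 2, 1, 1], dtype=jnp.int32)
output_offsets = jnp.array([0, 0, 1, 2], dtype=jnp.int32)  # offsets in the receiver's output
recv_sizes = jnp.array([1, 1, 2, 1], dtype=jnp.int32)

result = exchange(operand, output, input_offsets, send_sizes, output_offsets,
                  recv_sizes)
# Per-shard results [1, 3, 0, 0] and [2, 2, 4, 0], concatenated by out_specs.
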
def axis_index(axis_name):
@@ -1134,29 +1148,44 @@ batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective
batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name')
def _ragged_all_to_all_lowering(ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
N = input_offsets.type.shape[0]
backend_config = ir.DictAttr.get({
'replica_groups': ir.DenseIntElementsAttr.get(
np.arange(0, N, 1, dtype=np.int64), shape=[1, N]
)
})
def _ragged_all_to_all_lowering(
ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
*, axis_name, axis_index_groups
):
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
axis_index_groups)
# Assumes all groups are the same size
split_count = len(replica_groups[0])
if not all(split_count == len(g) for g in replica_groups):
raise ValueError('Replica groups must be equally sized')
if len(replica_groups[0]) == 1:
return [operand]
ragged_all_to_all_attrs = {
"replica_groups": _replica_groups_hlo(replica_groups)
}
is_spmd = isinstance(
ctx.module_context.axis_context, (SPMDAxisContext, ShardingContext))
if is_spmd:
ragged_all_to_all_attrs['channel_id'] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), ctx.module_context.new_channel()
)
return hlo.CustomCallOp(
result=[output.type],
inputs=[operand, output, input_offsets, send_sizes, output_offsets,
recv_sizes],
call_target_name=ir.StringAttr.get('ragged_all_to_all'),
backend_config=backend_config,
backend_config=ir.DictAttr.get(ragged_all_to_all_attrs),
api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4),
).results
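
For intuition (again not part of the diff): the replica_groups attribute emitted by this lowering corresponds to the axis indices partitioned by axis_index_groups, or to a single group spanning the whole axis when no groups are given. A simplified, hypothetical helper sketch, ignoring the full axis environment that _replica_groups handles:

def replica_groups_for(axis_size, axis_index_groups=None):
  # One group covering the whole named axis, or the user-supplied partition.
  if axis_index_groups is None:
    return [list(range(axis_size))]
  return [list(g) for g in axis_index_groups]

assert replica_groups_for(2) == [[0, 1]]                            # dense<[[0, 1]]> in the test below
assert replica_groups_for(4, ((0, 1), (2, 3))) == [[0, 1], [2, 3]]  # dense<[[0, 1], [2, 3]]>
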
@ragged_all_to_all_p.def_abstract_eval
def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
if operand.shape[1:] != output.shape[1:]:
raise ValueError(
"ragged_all_to_all input and output shapes must be equal, except for"
" the outermost dimension."
)
def _ragged_all_to_all_effectful_abstract_eval(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
axis_name, axis_index_groups
):
del operand, axis_index_groups
if not dtypes.issubdtype(input_offsets.dtype, np.integer):
raise ValueError("ragged_all_to_all input_offsets must be integer type.")
if not dtypes.issubdtype(send_sizes.dtype, np.integer):
@@ -1185,15 +1214,16 @@ def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes,
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
" size, but got shape {}".format(recv_sizes.shape)
)
return output.update(
shape=list(output.shape),
dtype=output.dtype,
weak_type=output.weak_type,
)
ragged_all_to_all_p.def_impl(partial(dispatch.apply_primitive, ragged_all_to_all_p))
_check_axis_names(axis_name)
out_aval = output.update(shape=output.shape, weak_type=False)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')
def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
"""Gather values of x across all replicas.


@@ -1377,6 +1377,27 @@ jax_multiplatform_test(
},
)
jax_multiplatform_test(
name = "ragged_collective_test",
srcs = ["ragged_collective_test.py"],
disable_configs = [
"tpu_pjrt_c_api",
],
enable_backends = [
"gpu",
"tpu",
],
enable_configs = [
"gpu_p100x2_shardy",
],
tags = [
"multiaccelerator",
],
deps = [
"//jax:experimental",
],
)
jax_multiplatform_test(
name = "shard_map_test",
srcs = ["shard_map_test.py"],


@@ -4720,23 +4720,6 @@ class CompositeTest(jtu.JaxTestCase):
class RaggedTest(jtu.JaxTestCase):
def testRaggedAllToAll(self):
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
mlir_module = jax.jit(lax.ragged_all_to_all).lower(
operand, output, input_offsets, send_sizes, output_offsets,
recv_sizes).as_text()
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
self.assertIn(
"backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
" tensor<1x3xi64>}}",
mlir_module,
)
def testRaggedAllToAllErrors(self):
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
@@ -4744,117 +4727,111 @@ class RaggedTest(jtu.JaxTestCase):
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all input and output shapes must be equal, except for"
" the outermost dimension.",
):
jax.jit(lax.ragged_all_to_all).lower(
operand,
jnp.array([[0.0], [0.0], [0.0], [0.0], [0.0]], dtype=jnp.float32),
input_offsets, send_sizes, output_offsets, recv_sizes)
axis_name = "x"
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all input_offsets must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32),
send_sizes, output_offsets, recv_sizes)
send_sizes, output_offsets, recv_sizes, axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all send_sizes must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets,
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets,
recv_sizes)
recv_sizes, axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all output_offsets must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets, send_sizes,
jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes,
axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError, "ragged_all_to_all recv_sizes must be integer type."
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets, send_sizes, output_offsets,
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
" size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes,
output_offsets, recv_sizes)
output_offsets, recv_sizes, axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
" size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, jnp.array([], dtype=jnp.int32), send_sizes,
output_offsets, recv_sizes)
output_offsets, recv_sizes, axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
" size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets,
jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes,
axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
" size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets, jnp.array([], dtype=jnp.int32),
output_offsets, recv_sizes)
output_offsets, recv_sizes, axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all output_offsets must be rank 1 with positive"
" dimension size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets, send_sizes,
jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes,
axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all output_offsets must be rank 1 with positive"
" dimension size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets, send_sizes,
jnp.array([], dtype=jnp.int32), recv_sizes)
jnp.array([], dtype=jnp.int32), recv_sizes, axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
" size, but got shape (1, 3)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets, send_sizes, output_offsets,
jnp.array([[1, 2, 3]], dtype=jnp.int32))
jnp.array([[1, 2, 3]], dtype=jnp.int32), axis_name=axis_name)
with self.assertRaisesWithLiteralMatch(
ValueError,
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
" size, but got shape (0,)",
):
jax.jit(lax.ragged_all_to_all).lower(
jax.jit(lax.ragged_all_to_all, static_argnames='axis_name').lower(
operand, output, input_offsets, send_sizes, output_offsets,
jnp.array([], dtype=jnp.int32))
jnp.array([], dtype=jnp.int32), axis_name=axis_name)
@jtu.sample_product(
[


@@ -0,0 +1,219 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.ad_checkpoint
from jax import lax
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
config.parse_flags_with_absl()
class RaggedCollectiveTest(jtu.JaxTestCase):
@parameterized.named_parameters(
dict(
testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
),
)
def test_ragged_all_to_all(self, axis_name, mesh_axes):
device_type = jax.devices()[0].platform
if device_type == 'tpu' and jtu.get_tpu_version() == 3:
raise unittest.SkipTest(
'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
' TPU v3'
)
mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
operand = jax.device_put(
jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
output = jax.device_put(
jnp.zeros((2, 4), dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
input_offsets = jax.device_put(
jnp.array([[0, 1], [0, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
send_sizes = jax.device_put(
jnp.array([[1, 2], [1, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
output_offsets = jax.device_put(
jnp.array([[0, 0], [1, 2]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
recv_sizes = jax.device_put(
jnp.array([[1, 1], [2, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
@jax.jit
@partial(
shard_map,
mesh=mesh,
in_specs=(
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
),
out_specs=P(axis_name),
check_rep=False,
)
def fwd(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
):
operand = operand.reshape(operand.shape[1:])
output = output.reshape(output.shape[1:])
input_offsets = input_offsets.reshape(input_offsets.shape[1:])
send_sizes = send_sizes.reshape(send_sizes.shape[1:])
output_offsets = output_offsets.reshape(output_offsets.shape[1:])
recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
return lax.ragged_all_to_all(
operand,
output,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=axis_name,
)
mlir_module = fwd.lower(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
).as_text()
self.assertIn('stablehlo.custom_call @ragged_all_to_all', mlir_module)
self.assertIn(
'replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>', mlir_module
)
c = fwd(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
).reshape((2, 4))
self.assertAllClose(
c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32)
)
@parameterized.named_parameters(
dict(
testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)
),
)
def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes):
device_type = jax.devices()[0].platform
if device_type == 'tpu' and jtu.get_tpu_version() == 3:
raise unittest.SkipTest(
'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by'
' TPU v3'
)
mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
operand = jax.device_put(
jnp.array([[1, 2, 2], [3, 4, 0],
[10, 20, 20], [30, 40, 0]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
output = jax.device_put(
jnp.zeros((4, 4), dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
input_offsets = jax.device_put(
jnp.array([[0, 1], [0, 1],
[0, 1], [0, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
send_sizes = jax.device_put(
jnp.array([[1, 2], [1, 1],
[1, 2], [1, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
output_offsets = jax.device_put(
jnp.array([[0, 0], [1, 2],
[0, 0], [1, 2]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
recv_sizes = jax.device_put(
jnp.array([[1, 1], [2, 1],
[1, 1], [2, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
axis_index_groups = ((0, 1), (2, 3))
@jax.jit
@partial(
shard_map,
mesh=mesh,
in_specs=(
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
),
out_specs=P(axis_name),
check_rep=False,
)
def fwd(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
):
operand = operand.reshape(operand.shape[1:])
output = output.reshape(output.shape[1:])
input_offsets = input_offsets.reshape(input_offsets.shape[1:])
send_sizes = send_sizes.reshape(send_sizes.shape[1:])
output_offsets = output_offsets.reshape(output_offsets.shape[1:])
recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
return lax.ragged_all_to_all(
operand,
output,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=axis_name,
axis_index_groups=axis_index_groups,
)
mlir_module = fwd.lower(
operand, output, input_offsets, send_sizes, output_offsets,
recv_sizes).as_text()
self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
self.assertIn("replica_groups = dense<[[0, 1], [2, 3]]> :"
" tensor<2x2xi64>", mlir_module)
c = fwd(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
).reshape((4, 4))
self.assertAllClose(
c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0],
[10, 30, 0, 0], [20, 20, 40, 0]], dtype=jnp.int32)
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())