add autodiff rules for jax.lax.ragged_all_to_all collective

Also update the ragged_all_to_all docstring. Pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; I'll add ragged_all_to_all to the shard_map tutorial in the future.

PiperOrigin-RevId: 735957604
This commit is contained in:
Matthew Johnson 2025-03-11 18:21:19 -07:00 committed by jax authors
parent 3a26804c68
commit 66a6eb299e
3 changed files with 229 additions and 59 deletions

View File

@ -8197,7 +8197,7 @@ _zeros: Callable = partial(full_like, fill_value=0)
def _zero(x):
  """Build a shape-``()`` array filled with 0 from ``x`` via ``full_like``.

  The result's sharding is taken from the abstract value of ``x`` with its
  spec replaced by an empty ``P()``.

  Args:
    x: value whose abstract value (``core.get_aval(x)``) supplies the sharding
      and which is passed to ``full_like`` as the template.

  Returns:
    The result of ``full_like(x, shape=(), fill_value=0, sharding=...)``.
  """
  x_aval = core.get_aval(x)
  # The ``sharding=`` keyword must be passed exactly once; repeating a keyword
  # argument in a call is a SyntaxError.
  return full_like(x, shape=(), fill_value=0,
                   sharding=x_aval.sharding.with_spec(P()))

# ``full_like`` pre-bound with a fill value of 1.
_ones: Callable = partial(full_like, fill_value=1)

View File

@ -22,6 +22,7 @@ from functools import partial
import itertools
import math
import jax
from jax import tree_util
from jax._src import core
from jax._src import dispatch
@ -459,78 +460,135 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
def ragged_all_to_all(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
axis_name, axis_index_groups = None):
"""Ragged version of :func:`all_to_all`.
"""Ragged version of :func:`all_to_all` collective.
For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
and the outermost (ragged) dimension. ``axis_index_groups`` is default to all
replicas (e.g. there is only one group and covers all axis indices).
We say data are "ragged" when they can be represented as a list of arrays
whose shapes differ only in the size of the leading axis. For example, these
data are ragged, comprising four component arrays::
Ragged arrays are defined by a set of three arrays:
* ``data``: the ``data`` array is "ragged" along its outermost dimension,
along which each indexed element has variable size.
* ``offsets``: the ``offsets`` array indexes the outermost dimension of the
``data`` array, and represents the starting offset of each ragged element of
the ``data`` array.
* ``sizes``: the ``sizes`` array represents the size of each ragged element of
the ``data`` array, where the size is specified in units of sub-elements. A
sub-element is defined as the suffix of the ``data`` array shape obtained by
removing the outermost "ragged" dimension.
The ``offsets`` and ``sizes`` arrays must have the same size.
ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)]
# Example ragged tensor
data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
offsets: [3] = {0, 1, 4}
sizes: [3] = {1, 3, 4}
We often instead want a contiguous representation, e.g. for batching. But
because the shapes of the components differ, we can't apply ``jnp.stack`` to
represent these data by a single rectangular array with the leading axis
indexing the component arrays. So instead of stacking, we concatenate along
the leading axis and keep track of offsets and sizes.
# Index 'data' at 'offsets'[0], 'sizes'[0]'
{a,b,c}
That is, we can represent ragged data contiguously using a triple of dense
arrays ``(data, offsets, sizes)``:
* ``data``: the concatenated component arrays,
* ``offsets``: 1D array of indices into the leading axis of ``data``
indicating where the data for each component array begins,
* ``sizes``: 1D array of sizes of the leading axis of each component array.
We refer to this triple as a ragged array. (Offsets can't be computed from
sizes in general to allow for internal padding.)
# Index 'data' at 'offsets'[1], 'sizes'[1]'
{d,e,f},{g,h,i},{j,k,l}
For example::
data: f32[8,3] = jnp.array([
[a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x],
])
offsets: i32[3] = jnp.array([0, 1, 4])
sizes: i32[3] = jnp.array([1, 3, 4])
# Index 'data' at 'offsets'[2], 'sizes'[2]'
{m,n,o},{p,q,r},{s,t,u},{v,w,x}
# To extract the first component array, of type f32[1,3]
data[offsets[0]:offsets[0]+sizes[0]]
# To extract the second component array, of type f32[3,3]
data[offsets[1]:offsets[1]+sizes[1]]
``output_offsets`` must be sharded in a way that each replica has offsets in
the target replica output perspective.
# To extract the third component array, of type f32[4,3]
data[offsets[2]:offsets[2]+sizes[2]]
For the i-th output offset, the current replica will send the update
`operand[input_offsets[i]:input_offsets[i]+send_sizes[i]]` to the `i`-th
replica, where it will be written to
`output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in the `i`-th
replica's ``output``.
The ``ragged_all_to_all`` collective operation communicates slices of ragged
arrays between devices. Each caller is both a sender and a receiver. The
``input_offsets`` and ``send_sizes`` arguments indicate the slices of the
caller's ``operand`` to be sent. Received results are returned in an array
that has the same value as the argument ``output`` except with received values
written at some slices. The ``output_offsets`` argument does *not* indicate
the offsets at which all the received results are written; instead,
``output_offsets`` indicates the offsets at which the *sent* slices are
written on their corresponding receivers. The sizes of received slices are
indicated by ``recv_sizes``. See below for details.
For example, if we have 2 replicas:
The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
``recv_sizes`` must all be the same length, and that length must be divisible
by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and
``recv_sizes`` must satisfy::
replica 0:
operand: [1, 2, 2]
output: [0, 0, 0, 0]
input_offsets: [0, 1]
send_sizes: [1, 2]
output_offsets: [0, 0]
recv_sizes: [1, 1]
jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True))
replica 1:
operand: [3, 4, 0]
output: [0, 0, 0, 0]
input_offsets: [0, 1]
send_sizes: [1, 1]
output_offsets: [1, 2]
recv_sizes: [2, 1]
Specifically, given a call::
replica 0's result will be: [1, 3, 0, 0]
replica 1's result will be: [2, 2, 4, 0]
result = ragged_all_to_all(operand, output, input_offsets, send_sizes,
output_offsets, recv_sizes, axis_name=axis_name)
the caller sends data like::
assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes)
N = len(input_offsets)
slices_per_device, leftover = divmod(N, lax.axis_size(axis_name))
assert not leftover
for i in range(N):
dst_idx = i // slices_per_device
SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]],
axis_name=axis_name, to_axis_index=dst_idx)
and receives data in ``result`` like::
result = output
output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
for i in range(N):
src_idx = i // slices_per_device
result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i]
].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx))
where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local
``output_offsets`` does not indicate the offsets at which its local ``result``
is updated; instead, it indicates where the corresponding sent slices are
written on their destination instances. To compute the local offsets at which
received data are written, we apply an ``all_to_all`` on ``output_offsets``.
For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with
these arguments in each mapped function instance::
axis index 0:
operand = [1, 2, 2]
output = [0, 0, 0, 0]
input_offsets = [0, 1]
send_sizes = [1, 2]
output_offsets = [0, 0]
recv_sizes = [1, 1]
axis index 1:
operand = [3, 4, 0]
output = [0, 0, 0, 0]
input_offsets = [0, 1]
send_sizes = [1, 1]
output_offsets = [1, 2]
recv_sizes = [2, 1]
then::
axis index 0:
result = [1, 3, 0, 0]
axis index 1:
result = [2, 2, 4, 0]
Args:
operand: array with ragged dimension along its outermost dimension.
output: array of ragged output data.
input_offsets: array of ragged input offsets.
send_sizes: array of ragged input send sizes.
output_offsets: array of ragged offsets in the target replica output.
recv_sizes: array of ragged output receive sizes.
axis_name: hashable Python object used to name a pmapped axis (see the
:func:`jax.pmap` documentation for more details).
operand: data array of shape (N, A, B, ...) representing concatenated
(possibly padded) ragged data to be sent.
output: data array of shape (M, A, B, ...) to update with received data.
input_offsets: 1D integer array of shape (K,) representing the offsets of
leading-axis slices into ``operand`` to be sent.
send_sizes: 1D integer array of shape (K,) representing the sizes of
leading-axis slices into ``operand`` to be sent.
output_offsets: 1D integer array of shape (K,) representing where the
corresponding sent data is written on each corresponding receiver.
recv_sizes: 1D integer array of shape (K,) representing sizes of
leading-axis slices into ``output`` to update with received data.
axis_name: name of the mapped axis over which to perform the communication.
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
first two and last two replicas). Groups must cover all axis indices
@ -538,7 +596,10 @@ def ragged_all_to_all(
behavior is undefined.
Returns:
array with shape equal to ``output``.
Array of shape (M, A, B, ...) with the same value as the ``output`` except
with received data written into slices starting at
``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size
``recv_sizes``.
"""
if not isinstance(axis_name, (tuple, list)):
@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval(
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
def _ragged_all_to_all_jvp(primals, tangents, **params):
  """JVP rule for ``ragged_all_to_all_p``.

  Only the first two inputs (``operand`` and ``output``) carry tangents; the
  remaining primal inputs (the offsets and sizes) are reused unchanged when
  the primitive is applied to the tangents.

  Args:
    primals: primal inputs; the first two are ``operand`` and ``output``, the
      rest are forwarded to the primitive as-is.
    tangents: tangent inputs; only the first two are read.
    **params: primitive parameters, forwarded to every ``bind`` call.

  Returns:
    A pair ``(result, result_dot)``. ``result_dot`` is a symbolic ``ad.Zero``
    when both the ``operand`` and ``output`` tangents are symbolic zeros.
  """
  operand, output, *index_args = primals
  operand_dot, output_dot = tangents[:2]

  result = ragged_all_to_all_p.bind(operand, output, *index_args, **params)

  # Nothing to propagate when neither data input has a tangent.
  if type(operand_dot) is ad.Zero and type(output_dot) is ad.Zero:
    return result, ad.Zero.from_primal_value(result)

  # At least one tangent is non-symbolic: materialize both and apply the same
  # primitive, with the same offsets/sizes, to the tangents.
  dense_operand_dot = ad.instantiate_zeros(operand_dot)
  dense_output_dot = ad.instantiate_zeros(output_dot)
  result_dot = ragged_all_to_all_p.bind(
      dense_operand_dot, dense_output_dot, *index_args, **params)
  return result, result_dot
def _ragged_all_to_all_transpose(
    t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
    *, axis_name, axis_index_groups):
  """Transpose rule for ``ragged_all_to_all_p`` in ``operand`` and ``output``.

  Args:
    t: cotangent of the primitive's result, or a symbolic ``ad.Zero``.
    operand: the ``operand`` input (read only for its ``aval``).
    output: the ``output`` input (read only for its ``aval``).
    input_offsets: offsets of the slices sent in the forward pass.
    send_sizes: sizes of the slices sent in the forward pass.
    output_offsets: offsets at which the sent slices were written on their
      receivers in the forward pass.
    recv_sizes: sizes of the slices received in the forward pass.
    axis_name: mapped axis name(s), forwarded to the collectives.
    axis_index_groups: forwarded to the transposed ``ragged_all_to_all``.

  Returns:
    A list of six cotangents: ``[operand_t, output_t, None, None, None, None]``.
    The offsets and sizes get no cotangent.
  """
  if type(t) is ad.Zero:
    operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None
  else:
    zero = ad.zeros_like_aval(operand.aval)
    # Exchange the offsets so each instance sees them from the other side:
    # ``output_offsets_`` are the local offsets at which received slices were
    # written in the forward pass.
    # NOTE(review): ``axis_index_groups`` is not forwarded to these
    # ``all_to_all`` calls -- confirm this is intended when groups are used.
    output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
    input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True)
    # Run the collective in reverse: the slices of ``t`` that were received in
    # the forward pass are sent back and written, on top of zeros, at the
    # offsets they were originally read from.
    operand_t = ragged_all_to_all_p.bind(
        t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes,
        axis_name=axis_name, axis_index_groups=axis_index_groups)
    # Mark the leading-axis rows of the result that were overwritten by
    # received data: set 1 at each slice start, add -1 at each slice end, and
    # take the cumulative sum.
    mask = jax.numpy.cumsum(
        jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)
        .at[output_offsets_ + recv_sizes].add(-1))
    # The mask indexes only the leading axis. Give it singleton trailing axes
    # so it broadcasts against ``t`` of shape (M, A, B, ...) rather than
    # being aligned with the last axis.
    mask = mask.reshape((t.shape[0],) + (1,) * (t.ndim - 1))
    # Overwritten rows do not depend on ``output``, so their cotangent is zero.
    output_t = jax.numpy.where(mask, 0, t)
  return [operand_t, output_t] + [None] * 4
# Primitive object for the ragged all-to-all collective, followed by the
# registration of its rules.
ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
# Abstract evaluation returns both the output aval and the set of effects.
ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
# Autodiff rules: forward-mode (JVP) and transpose.
ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp
ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
# NOTE(review): presumably lets batching skip this primitive based on the
# names found in its 'axis_name' param -- confirm against `_names_in_param`.
batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')

View File

@ -125,6 +125,80 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32)
)
@parameterized.named_parameters(
    dict(
        testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
    ),
)
def test_ragged_all_to_all_grad(self, axis_name, mesh_axes):
  """Checks first-order gradients of ``lax.ragged_all_to_all``.

  Gradients are checked with respect to ``operand`` and ``output`` only; the
  offsets and sizes are closed over as constants.
  """
  platform = jax.devices()[0].platform
  if platform == 'tpu' and jtu.get_tpu_version() < 4:
    raise unittest.SkipTest(
        'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
        f' v{jtu.get_tpu_version()}'
    )

  mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
  # Every argument has one row per device along ``axis_name``.
  row_spec = P(axis_name, None)
  row_sharding = jax.sharding.NamedSharding(mesh, row_spec)

  def place(array):
    return jax.device_put(array, row_sharding)

  operand = place(jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32))
  output = place(jnp.zeros((2, 4), dtype=jnp.float32))
  input_offsets = place(jnp.array([[0, 1], [0, 1]], dtype=jnp.int32))
  send_sizes = place(jnp.array([[1, 2], [1, 1]], dtype=jnp.int32))
  output_offsets = place(jnp.array([[0, 0], [1, 2]], dtype=jnp.int32))
  recv_sizes = place(jnp.array([[1, 1], [2, 1]], dtype=jnp.int32))

  @partial(
      shard_map,
      mesh=mesh,
      in_specs=(row_spec,) * 6,
      out_specs=P(axis_name),
      check_rep=False,
  )
  def fwd(
      operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
  ):
    # Drop the leading per-device axis from each local shard.
    local_args = [
        shard.reshape(shard.shape[1:])
        for shard in (
            operand, output, input_offsets, send_sizes, output_offsets,
            recv_sizes,
        )
    ]
    return lax.ragged_all_to_all(*local_args, axis_name=axis_name)

  def data_only(op, out):
    return fwd(op, out, input_offsets, send_sizes, output_offsets, recv_sizes)

  jtu.check_grads(data_only, (operand, output), order=1)
@parameterized.named_parameters(
dict(
testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)