add autodiff rules for jax.lax.ragged_all_to_all collective
Also update the ragged_all_to_all docstring. Pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; I'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
This commit is contained in: parent 3a26804c68, commit 66a6eb299e
@@ -8197,7 +8197,7 @@ _zeros: Callable = partial(full_like, fill_value=0)
 
 def _zero(x):
   x_aval = core.get_aval(x)
   return full_like(x, shape=(), fill_value=0,
-                   sharding=x_aval.sharding.with_spec(P()))
+                   sharding=x_aval.sharding.with_spec(P()))
 
 _ones: Callable = partial(full_like, fill_value=1)
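For context (this example is not part of the diff): `_zero` builds a scalar zero with the input's dtype and a replicated sharding, via JAX-internal aval machinery. A rough public-API approximation, ignoring the sharding argument:

import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
# full_like with shape=() yields a scalar of x's dtype filled with 0; the
# internal helper additionally pins a replicated sharding on the result.
zero = jax.lax.full_like(x, fill_value=0, shape=())
print(zero, zero.dtype)  # 0.0 float32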
@@ -22,6 +22,7 @@ from functools import partial
 import itertools
 import math
 
+import jax
 from jax import tree_util
 from jax._src import core
 from jax._src import dispatch
@@ -459,78 +460,135 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
 
 def ragged_all_to_all(
     operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
     axis_name, axis_index_groups = None):
-  """Ragged version of :func:`all_to_all`.
+  """Ragged version of :func:`all_to_all` collective.
 
-  For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
-  and the outermost (ragged) dimension. ``axis_index_groups`` is default to all
-  replicas (e.g. there is only one group and covers all axis indices).
+  We say data are "ragged" when they can be represented as a list of arrays
+  whose shapes differ only in the size of the leading axis. For example, these
+  data are ragged, comprising four component arrays::
 
-  Ragged arrays are defined by a set of three arrays:
-  * ``data``: the ``data`` array is "ragged" along its outermost dimension,
-    along which each indexed element has variable size.
-  * ``offsets``: the ``offsets`` array indexes the outermost dimension of the
-    ``data`` array, and represents the starting offset of each ragged element of
-    the ``data`` array.
-  * ``sizes``: the ``sizes`` array represents the size of each ragged element of
-    the ``data`` array, where the size is specified in units of sub-elements. A
-    sub-element is defined as the suffix of the ``data`` array shape obtained by
-    removing the outermost "ragged" dimension.
-  The ``offsets`` and ``sizes`` arrays must have the same size.
+    ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)]
 
-  # Example ragged tensor
-  data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
-  offsets: [3] = {0, 1, 4}
-  sizes: [3] = {1, 3, 4}
+  We often instead want a contiguous representation, e.g. for batching. But
+  because the shapes of the components differ, we can't apply ``jnp.stack`` to
+  represent these data by a single rectangular array with the leading axis
+  indexing the component arrays. So instead of stacking, we concatenate along
+  the leading axis and keep track of offsets and sizes.
 
-  # Index 'data' at 'offsets'[0], 'sizes'[0]'
-  {a,b,c}
+  That is, we can represent ragged data contiguously using a triple of dense
+  arrays ``(data, offsets, sizes)``:
+
+  * ``data``: the concatenated component arrays,
+  * ``offsets``: 1D array of indices into the leading axis of ``data``
+    indicating where the data for each component array begins,
+  * ``sizes``: 1D array of sizes of the leading axis of each component array.
+
+  We refer to this triple as a ragged array. (Offsets can't be computed from
+  sizes in general to allow for internal padding.)
 
-  # Index 'data' at 'offsets'[1], 'sizes'[1]'
-  {d,e,f},{g,h,i},{j,k,l}
+  For example::
+
+    data: f32[8,3] = jnp.array([
+        [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x],
+    ])
+    offsets: i32[3] = jnp.array([0, 1, 4])
+    sizes: i32[3] = jnp.array([1, 3, 4])
 
-  # Index 'data' at 'offsets'[2], 'sizes'[2]'
-  {m,n,o},{p,q,r},{s,t,u},{v,w,x}
+    # To extract the first component array, of type f32[1,3]
+    data[offsets[0]:offsets[0]+sizes[0]]
+
+    # To extract the second component array, of type f32[3,3]
+    data[offsets[1]:offsets[1]+sizes[1]]
 
-  ``output_offsets`` must be sharded in a way that each replica has offsets in
-  the target replica output perspective.
+    # To extract the third component array, of type f32[4,3]
+    data[offsets[2]:offsets[2]+sizes[2]]
 
-  For i-th output offset, the current replica will send
-  `operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th
-  replica that will be written to
-  `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th
-  replica ``output``.
+  The ``ragged_all_to_all`` collective operation communicates slices of ragged
+  arrays between devices. Each caller is both a sender and a receiver. The
+  ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the
+  caller's ``operand`` to be sent. Received results are returned in an array
+  that has the same value as the argument ``output`` except with received
+  values written at some slices. The ``output_offsets`` argument does *not*
+  indicate the offsets at which all the received results are written; instead,
+  ``output_offsets`` indicates the offsets at which the *sent* slices are
+  written on their corresponding receivers. The sizes of received slices are
+  indicated by ``recv_sizes``. See below for details.
 
-  For example, if we have 2 replicas:
+  The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
+  ``recv_sizes`` must all be the same length, and that length must be divisible
+  by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and
+  ``recv_sizes`` must satisfy::
 
-  replica 0:
-    operand: [1, 2, 2]
-    output: [0, 0, 0, 0]
-    input_offsets: [0, 1]
-    send_sizes: [1, 2]
-    output_offsets: [0, 0]
-    recv_sizes: [1, 1]
+    jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True))
 
-  replica 1:
-    operand: [3, 4, 0]
-    output: [0, 0, 0, 0]
-    input_offsets: [0, 1]
-    send_sizes: [1, 1]
-    output_offsets: [1, 2]
-    recv_sizes: [2, 1]
+  Specifically, given a call::
 
-  replica 0's result will be: [1, 3, 0, 0]
-  replica 1's result will be: [2, 2, 4, 0]
+    result = ragged_all_to_all(operand, output, input_offsets, send_sizes,
+                               output_offsets, recv_sizes, axis_name=axis_name)
+
+  the caller sends data like::
+
+    assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes)
+    N = len(input_offsets)
+    slices_per_device, leftover = divmod(N, lax.axis_size(axis_name))
+    assert not leftover
+
+    for i in range(N):
+      dst_idx = i // slices_per_device
+      SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]],
+           axis_name=axis_name, to_axis_index=dst_idx)
+
+  and receives data in ``result`` like::
+
+    result = output
+    output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
+    for i in range(N):
+      src_idx = i // slices_per_device
+      result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i]
+                         ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx))
+
+  where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local
+  ``output_offsets`` does not indicate the offsets at which its local ``result``
+  is updated; instead, it indicates where the corresponding sent slices are
+  written on their destination instances. To compute the local offsets at which
+  received data are written, we apply an ``all_to_all`` on ``output_offsets``.
+
+  For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with
+  these arguments in each mapped function instance::
+
+    axis index 0:
+      operand = [1, 2, 2]
+      output = [0, 0, 0, 0]
+      input_offsets = [0, 1]
+      send_sizes = [1, 2]
+      output_offsets = [0, 0]
+      recv_sizes = [1, 1]
+
+    axis index 1:
+      operand = [3, 4, 0]
+      output = [0, 0, 0, 0]
+      input_offsets = [0, 1]
+      send_sizes = [1, 1]
+      output_offsets = [1, 2]
+      recv_sizes = [2, 1]
+
+  then::
+
+    axis index 0:
+      result = [1, 3, 0, 0]
+
+    axis index 1:
+      result = [2, 2, 4, 0]
 
   Args:
-    operand: array with ragged dimension along its outermost dimension.
-    output: array of ragged input offsets.
-    input_offsets: array of ragged input send sizes.
-    send_sizes: array of ragged output data.
-    output_offsets: array of ragged offsets in the target replica output.
-    recv_sizes: array of ragged output receive sizes.
-    axis_name: hashable Python object used to name a pmapped axis (see the
-      :func:`jax.pmap` documentation for more details).
+    operand: data array of shape (N, A, B, ...) representing concatenated
+      (possibly padded) ragged data to be sent.
+    output: data array of shape (M, A, B, ...) to update with received data.
+    input_offsets: 1D integer array of shape (K,) representing the offsets of
+      leading-axis slices into ``operand`` to be sent.
+    send_sizes: 1D integer array of shape (K,) representing the sizes of
+      leading-axis slices into ``operand`` to be sent.
+    output_offsets: 1D integer array of shape (K,) representing where the
+      corresponding sent data is written on each corresponding receiver.
+    recv_sizes: 1D integer array of shape (K,) representing sizes of
+      leading-axis slices into ``output`` to update with received data.
+    axis_name: name of the mapped axis over which to perform the communication.
     axis_index_groups: optional list of lists containing axis indices (e.g. for
       an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
       first two and last two replicas). Groups must cover all axis indices
@@ -538,7 +596,10 @@ def ragged_all_to_all(
       behavior is undefined.
 
   Returns:
-    array with shape equal to ``output``.
+    Array of shape (M, A, B, ...) with the same value as the ``output`` except
+    with received data written into slices starting at
+    ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size
+    ``recv_sizes``.
   """
 
   if not isinstance(axis_name, (tuple, list)):
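The worked example in the docstring can be checked with a single-process reference model. The sketch below is not part of the diff: it simulates the collective over explicit per-replica NumPy arrays, assuming exactly one slice per destination replica, and reproduces the docstring's 2-replica result; the function name is made up for illustration.

import numpy as np

def ragged_all_to_all_reference(operands, outputs, input_offsets, send_sizes,
                                output_offsets, recv_sizes):
  # Simulate: replica `src` sends its `dst`-th local slice to replica `dst`,
  # which writes it starting at the sender's output_offsets[src][dst].
  # recv_sizes is implied by send_sizes (see the constraint in the docstring)
  # and is unused in this simulation.
  results = [np.array(o, copy=True) for o in outputs]
  n = len(operands)  # axis size; here one slice per destination replica
  for src in range(n):
    for dst in range(n):
      start = input_offsets[src][dst]
      size = send_sizes[src][dst]
      chunk = operands[src][start:start + size]
      out_start = output_offsets[src][dst]   # offset in dst's output
      results[dst][out_start:out_start + size] = chunk
  return results

operands       = [np.array([1, 2, 2]), np.array([3, 4, 0])]
outputs        = [np.zeros(4, int), np.zeros(4, int)]
input_offsets  = [[0, 1], [0, 1]]
send_sizes     = [[1, 2], [1, 1]]
output_offsets = [[0, 0], [1, 2]]
recv_sizes     = [[1, 1], [2, 1]]

print(ragged_all_to_all_reference(operands, outputs, input_offsets, send_sizes,
                                  output_offsets, recv_sizes))
# [array([1, 3, 0, 0]), array([2, 2, 4, 0])]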
@@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval(
   effects = {*map(core.NamedAxisEffect, axis_name)}
   return out_aval, effects
 
+def _ragged_all_to_all_jvp(primals, tangents, **params):
+  operand, output, *sizes_and_offsets = primals
+  operand_dot, output_dot, *_ = tangents
+  result = ragged_all_to_all_p.bind(
+      operand, output, *sizes_and_offsets, **params)
+  if type(operand_dot) is type(output_dot) is ad.Zero:
+    result_dot = ad.Zero.from_primal_value(result)
+  else:
+    operand_dot = ad.instantiate_zeros(operand_dot)
+    output_dot = ad.instantiate_zeros(output_dot)
+    result_dot = ragged_all_to_all_p.bind(
+        operand_dot, output_dot, *sizes_and_offsets, **params)
+  return result, result_dot
+
+def _ragged_all_to_all_transpose(
+    t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
+    *, axis_name, axis_index_groups):
+  if type(t) is ad.Zero:
+    operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
+    output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None
+  else:
+    zero = ad.zeros_like_aval(operand.aval)
+    output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
+    input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True)
+    operand_t = ragged_all_to_all_p.bind(
+        t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes,
+        axis_name=axis_name, axis_index_groups=axis_index_groups)
+    mask = jax.numpy.cumsum(
+        jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
+        .at[output_offsets_ + recv_sizes].add(-1))
+    output_t = jax.numpy.where(mask, 0, t)
+  return [operand_t, output_t] + [None] * 4
+
 ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
 ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
+ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp
+ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose
 mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
 batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')
 
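The transpose rule above routes the cotangent `t` back to the senders (note the swapped offset arguments in the `bind` call) and then zeroes out, in the cotangent of `output`, exactly the rows that were overwritten by received data. The mask construction can be seen in isolation in this standalone sketch, which is not part of the diff; the offsets and sizes below are made up for illustration.

import jax.numpy as jnp

output_len = 8
# Local offsets at which data was received, i.e. all_to_all(output_offsets).
output_offsets_ = jnp.array([1, 5])
recv_sizes = jnp.array([2, 1])

# +1 at the start of each received slice, -1 just past its end; the running
# sum is 1 exactly on the received rows and 0 elsewhere.
indicator = (jnp.zeros(output_len, dtype='int32')
             .at[output_offsets_].set(1)
             .at[output_offsets_ + recv_sizes].add(-1))
mask = jnp.cumsum(indicator)
print(mask)  # [0 1 1 0 0 1 0 0]

t = jnp.arange(1., output_len + 1)   # stand-in cotangent of the result
output_t = jnp.where(mask, 0, t)     # cotangent w.r.t. ``output``
print(output_t)                      # rows 1, 2 and 5 zeroed out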
@@ -125,6 +125,80 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
         c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32)
     )
 
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
+      ),
+  )
+  def test_ragged_all_to_all_grad(self, axis_name, mesh_axes):
+    device_type = jax.devices()[0].platform
+    if device_type == 'tpu' and jtu.get_tpu_version() < 4:
+      raise unittest.SkipTest(
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
+          f' v{jtu.get_tpu_version()}'
+      )
+    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
+    operand = jax.device_put(
+        jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output = jax.device_put(
+        jnp.zeros((2, 4), dtype=jnp.float32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    input_offsets = jax.device_put(
+        jnp.array([[0, 1], [0, 1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    send_sizes = jax.device_put(
+        jnp.array([[1, 2], [1, 1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output_offsets = jax.device_put(
+        jnp.array([[0, 0], [1, 2]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    recv_sizes = jax.device_put(
+        jnp.array([[1, 1], [2, 1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+        ),
+        out_specs=P(axis_name),
+        check_rep=False,
+    )
+    def fwd(
+        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
+    ):
+      operand = operand.reshape(operand.shape[1:])
+      output = output.reshape(output.shape[1:])
+      input_offsets = input_offsets.reshape(input_offsets.shape[1:])
+      send_sizes = send_sizes.reshape(send_sizes.shape[1:])
+      output_offsets = output_offsets.reshape(output_offsets.shape[1:])
+      recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
+      return lax.ragged_all_to_all(
+          operand,
+          output,
+          input_offsets,
+          send_sizes,
+          output_offsets,
+          recv_sizes,
+          axis_name=axis_name,
+      )
+
+    args = input_offsets, send_sizes, output_offsets, recv_sizes
+    jtu.check_grads(lambda op, out: fwd(op, out, *args), (operand, output), order=1)
+
   @parameterized.named_parameters(
       dict(
           testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)
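With the JVP and transpose rules registered, reverse-mode autodiff composes with the shard_mapped call exercised by the test above. A rough follow-on sketch, not part of the diff, reusing the test's `fwd`, `args`, `operand`, and `output`; `check_grads(order=1)` verifies these derivatives against numerical differences:

# Differentiate a scalar loss through the collective. ragged_all_to_all is
# linear in (operand, output), so the operand gradient is the cotangent
# routed back to the senders.
loss = lambda op, out: jnp.sum(fwd(op, out, *args) ** 2)
g_operand, g_output = jax.grad(loss, argnums=(0, 1))(operand, output)
print(g_operand.shape, g_output.shape)  # (2, 3) (2, 4)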