diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index e6cbbf245..12706426b 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -8197,7 +8197,7 @@ _zeros: Callable = partial(full_like, fill_value=0)
 
 def _zero(x):
   x_aval = core.get_aval(x)
   return full_like(x, shape=(), fill_value=0,
-                    sharding=x_aval.sharding.with_spec(P()))
+                   sharding=x_aval.sharding.with_spec(P()))
 
 _ones: Callable = partial(full_like, fill_value=1)
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index b556042fe..764e4dcbe 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -22,6 +22,7 @@ from functools import partial
 import itertools
 import math
 
+import jax
 from jax import tree_util
 from jax._src import core
 from jax._src import dispatch
@@ -459,78 +460,135 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
 def ragged_all_to_all(
     operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
     axis_name, axis_index_groups = None):
-  """Ragged version of :func:`all_to_all`.
+  """Ragged version of :func:`all_to_all` collective.
 
-  For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
-  and the outermost (ragged) dimension. ``axis_index_groups`` is default to all
-  replicas (e.g. there is only one group and covers all axis indices).
+  We say data are "ragged" when they can be represented as a list of arrays
+  whose shapes differ only in the size of the leading axis. For example, these
+  data are ragged, comprising four component arrays::
 
-  Ragged arrays are defined by a set of three arrays:
-  * ``data``: the ``data`` array is "ragged" along its outermost dimension,
-    along which each indexed element has variable size.
-  * ``offsets``: the ``offsets`` array indexes the outermost dimension of the
-    ``data`` array, and represents the starting offset of each ragged element of
-    the ``data`` array.
-  * ``sizes``: the ``sizes`` array represents the size of each ragged element of
-    the ``data`` array, where the size is specified in units of sub-elements. A
-    sub-element is defined as the suffix of the ``data`` array shape obtained by
-    removing the outermost "ragged" dimension.
-  The ``offsets`` and ``sizes`` arrays must have the same size.
+    ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)]
 
-  # Example ragged tensor
-  data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
-  offsets: [3] = {0, 1, 4}
-  sizes: [3] = {1, 3, 4}
+  We often instead want a contiguous representation, e.g. for batching. But
+  because the shapes of the components differ, we can't apply ``jnp.stack`` to
+  represent these data by a single rectangular array with the leading axis
+  indexing the component arrays. So instead of stacking, we concatenate along
+  the leading axis and keep track of offsets and sizes.
 
-  # Index 'data' at 'offsets'[0], 'sizes'[0]'
-  {a,b,c}
+  That is, we can represent ragged data contiguously using a triple of dense
+  arrays ``(data, offsets, sizes)``:
+  * ``data``: the concatenated component arrays,
+  * ``offsets``: 1D array of indices into the leading axis of ``data``
+    indicating where the data for each component array begins,
+  * ``sizes``: 1D array of sizes of the leading axis of each component array.
+  We refer to this triple as a ragged array. (Offsets can't be computed from
+  sizes in general to allow for internal padding.)
 
-  # Index 'data' at 'offsets'[1], 'sizes'[1]'
-  {d,e,f},{g,h,i},{j,k,l}
+  For example::
+
+    data: f32[8,3] = jnp.array([
+        [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x],
+    ])
+    offsets: i32[3] = jnp.array([0, 1, 4])
+    sizes: i32[3] = jnp.array([1, 3, 4])
 
-  # Index 'data' at 'offsets'[2], 'sizes'[2]'
-  {m,n,o},{p,q,r},{s,t,u},{v,w,x}
+    # To extract the first component array, of type f32[1,3]
+    data[offsets[0]:offsets[0]+sizes[0]]
+
+    # To extract the second component array, of type f32[3,3]
+    data[offsets[1]:offsets[1]+sizes[1]]
 
-  ``output_offsets`` must be sharded in a way that each replica has offsets in
-  the target replica output perspective.
+    # To extract the third component array, of type f32[4,3]
+    data[offsets[2]:offsets[2]+sizes[2]]
 
-  For i-th output offset, the current replica will send
-  `operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th
-  replica that will be written to
-  `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th
-  replica ``output``.
+  The ``ragged_all_to_all`` collective operation communicates slices of ragged
+  arrays between devices. Each caller is both a sender and a receiver. The
+  ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the
+  caller's ``operand`` to be sent. Received results are returned in an array
+  that has the same value as the argument ``output`` except with received values
+  written at some slices. The ``output_offsets`` argument does *not* indicate
+  the offsets at which all the received results are written; instead,
+  ``output_offsets`` indicates the offsets at which the *sent* slices are
+  written on their corresponding receivers. The sizes of received slices are
+  indicated by ``recv_sizes``. See below for details.
 
-  For example, if we have 2 replicas:
+  The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
+  ``recv_sizes`` must all be the same length, and that length must be divisible
+  by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and
+  ``recv_sizes`` must satisfy::
 
-  replica 0:
-    operand: [1, 2, 2]
-    output: [0, 0, 0, 0]
-    input_offsets: [0, 1]
-    send_sizes: [1, 2]
-    output_offsets: [0, 0]
-    recv_sizes: [1, 1]
+    jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True))
 
-  replica 1:
-    operand: [3, 4, 0]
-    output: [0, 0, 0, 0]
-    input_offsets: [0, 1]
-    send_sizes: [1, 1]
-    output_offsets: [1, 2]
-    recv_sizes: [2, 1]
+  Specifically, given a call::
 
-  replica 0's result will be: [1, 3, 0, 0]
-  replica 1's result will be: [2, 2, 4, 0]
+    result = ragged_all_to_all(operand, output, input_offsets, send_sizes,
+                               output_offsets, recv_sizes, axis_name)
+
+  the caller sends data like::
+
+    assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes)
+    N = len(input_offsets)
+    slices_per_device, leftover = divmod(N, lax.axis_size(axis_name))
+    assert not leftover
+
+    for i in range(N):
+      dst_idx = i // slices_per_device
+      SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]],
+           axis_name=axis_name, to_axis_index=dst_idx)
+
+  and receives data in ``result`` like::
+
+    result = output
+    output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
+    for i in range(N):
+      src_idx = i // slices_per_device
+      result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i]
+                         ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx))
+
+  where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local
+  ``output_offsets`` does not indicate the offsets at which its local ``result``
+  is updated; instead, it indicates where the corresponding sent slices are
+  written on their destination instances. To compute the local offsets at which
+  received data are written, we apply an ``all_to_all`` on ``output_offsets``.
+
+  For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with
+  these arguments in each mapped function instance::
+
+    axis index 0:
+      operand = [1, 2, 2]
+      output = [0, 0, 0, 0]
+      input_offsets = [0, 1]
+      send_sizes = [1, 2]
+      output_offsets = [0, 0]
+      recv_sizes = [1, 1]
+
+    axis index 1:
+      operand = [3, 4, 0]
+      output = [0, 0, 0, 0]
+      input_offsets = [0, 1]
+      send_sizes = [1, 1]
+      output_offsets = [1, 2]
+      recv_sizes = [2, 1]
+
+  then::
+
+    axis index 0:
+      result = [1, 3, 0, 0]
+
+    axis index 1:
+      result = [2, 2, 4, 0]
 
   Args:
-    operand: array with ragged dimension along its outermost dimension.
-    output: array of ragged input offsets.
-    input_offsets: array of ragged input send sizes.
-    send_sizes: array of ragged output data.
-    output_offsets: array of ragged offsets in the target replica output.
-    recv_sizes: array of ragged output receive sizes.
-    axis_name: hashable Python object used to name a pmapped axis (see the
-      :func:`jax.pmap` documentation for more details).
+    operand: data array of shape (N, A, B, ...) representing concatenated
+      (possibly padded) ragged data to be sent.
+    output: data array of shape (M, A, B, ...) to update with received data.
+    input_offsets: 1D integer array of shape (K,) representing the offsets of
+      leading-axis slices into ``operand`` to be sent.
+    send_sizes: 1D integer array of shape (K,) representing the sizes of
+      leading-axis slices into ``operand`` to be sent.
+    output_offsets: 1D integer array of shape (K,) representing where the
+      corresponding sent data is written on each corresponding receiver.
+    recv_sizes: 1D integer array of shape (K,) representing sizes of
+      leading-axis slices into ``output`` to update with received data.
+    axis_name: name of the mapped axis over which to perform the communication.
     axis_index_groups: optional list of lists containing axis indices (e.g. for
       an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
       first two and last two replicas). Groups must cover all axis indices
@@ -538,7 +596,10 @@ def ragged_all_to_all(
       behavior is undefined.
 
   Returns:
-    array with shape equal to ``output``.
+    Array of shape (M, A, B, ...) with the same value as ``output`` except
+    with received data written into slices starting at
+    ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size
+    ``recv_sizes``.
""" if not isinstance(axis_name, (tuple, list)): @@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval( effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects +def _ragged_all_to_all_jvp(primals, tangents, **params): + operand, output, *sizes_and_offsets = primals + operand_dot, output_dot, *_ = tangents + result = ragged_all_to_all_p.bind( + operand, output, *sizes_and_offsets, **params) + if type(operand_dot) is type(output_dot) is ad.Zero: + result_dot = ad.Zero.from_primal_value(result) + else: + operand_dot = ad.instantiate_zeros(operand_dot) + output_dot = ad.instantiate_zeros(output_dot) + result_dot = ragged_all_to_all_p.bind( + operand_dot, output_dot, *sizes_and_offsets, **params) + return result, result_dot + +def _ragged_all_to_all_transpose( + t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, + *, axis_name, axis_index_groups): + if type(t) is ad.Zero: + operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None + output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None + else: + zero = ad.zeros_like_aval(operand.aval) + output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True) + input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True) + operand_t = ragged_all_to_all_p.bind( + t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes, + axis_name=axis_name, axis_index_groups=axis_index_groups) + mask = jax.numpy.cumsum( + jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ + .at[output_offsets_ + recv_sizes].add(-1)) + output_t = jax.numpy.where(mask, 0, t) + return [operand_t, output_t] + [None] * 4 + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) +ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp +ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 48f3d062b..844892adc 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -125,6 +125,80 @@ class RaggedCollectiveTest(jtu.JaxTestCase): c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32) ) + @parameterized.named_parameters( + dict( + testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2) + ), + ) + def test_ragged_all_to_all_grad(self, axis_name, mesh_axes): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + operand = jax.device_put( + jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + output = jax.device_put( + jnp.zeros((2, 4), dtype=jnp.float32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + input_offsets = jax.device_put( + jnp.array([[0, 1], [0, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + send_sizes = jax.device_put( + jnp.array([[1, 2], [1, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + output_offsets = jax.device_put( + 
+        jnp.array([[0, 0], [1, 2]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    recv_sizes = jax.device_put(
+        jnp.array([[1, 1], [2, 1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+        ),
+        out_specs=P(axis_name),
+        check_rep=False,
+    )
+    def fwd(
+        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
+    ):
+      operand = operand.reshape(operand.shape[1:])
+      output = output.reshape(output.shape[1:])
+      input_offsets = input_offsets.reshape(input_offsets.shape[1:])
+      send_sizes = send_sizes.reshape(send_sizes.shape[1:])
+      output_offsets = output_offsets.reshape(output_offsets.shape[1:])
+      recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
+      return lax.ragged_all_to_all(
+          operand,
+          output,
+          input_offsets,
+          send_sizes,
+          output_offsets,
+          recv_sizes,
+          axis_name=axis_name,
+      )
+
+    args = input_offsets, send_sizes, output_offsets, recv_sizes
+    jtu.check_grads(lambda op, out: fwd(op, out, *args), (operand, output), order=1)
+
   @parameterized.named_parameters(
       dict(
          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)
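
Note (not part of the diff): the SEND/RECEIVE pseudocode added to the docstring
can be sanity-checked against the two-device example it quotes with a small
host-side NumPy sketch. The names below (``_all_to_all_tiled``,
``ragged_all_to_all_reference``) are illustrative only, not JAX APIs; the sketch
assumes each argument is a list with one entry per axis index::

    import numpy as np

    def _all_to_all_tiled(xs):
      # Host-side stand-in for lax.all_to_all(x, axis_name, 0, 0, tiled=True):
      # device j receives block j of every device's array, in device order.
      num = len(xs)
      spd = len(xs[0]) // num  # slices per device
      return [np.concatenate([np.asarray(x)[j*spd:(j+1)*spd] for x in xs])
              for j in range(num)]

    def ragged_all_to_all_reference(operand, output, input_offsets, send_sizes,
                                    output_offsets, recv_sizes):
      # Each argument is a list with one entry per axis index, mirroring what
      # each mapped function instance sees.
      num = len(operand)
      n = len(input_offsets[0])
      spd = n // num
      results = [np.array(o) for o in output]
      recv_offsets = _all_to_all_tiled(output_offsets)
      for src in range(num):
        for i in range(n):
          dst = i // spd
          lo = input_offsets[src][i]
          sent = np.asarray(operand[src])[lo:lo + send_sizes[src][i]]
          # On the receiver, this slice occupies slot src*spd + (i % spd), and
          # its destination offset is the all_to_all'd output_offsets entry.
          j = src * spd + (i % spd)
          start = recv_offsets[dst][j]
          results[dst][start:start + recv_sizes[dst][j]] = sent
      return results

    res = ragged_all_to_all_reference(
        operand=[[1, 2, 2], [3, 4, 0]],
        output=[[0, 0, 0, 0], [0, 0, 0, 0]],
        input_offsets=[[0, 1], [0, 1]],
        send_sizes=[[1, 2], [1, 1]],
        output_offsets=[[0, 0], [1, 2]],
        recv_sizes=[[1, 1], [2, 1]])
    assert [r.tolist() for r in res] == [[1, 3, 0, 0], [2, 2, 4, 0]]

The asserted per-axis-index results match those quoted in the docstring example,
which uses the same operands as ``test_ragged_all_to_all_grad`` above.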