diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9878ddc3c..99760099d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -16,6 +16,7 @@ from __future__ import annotations import builtins from collections.abc import Callable, Sequence +import dataclasses import enum import functools from functools import partial @@ -2227,10 +2228,122 @@ def ragged_dot( Results: (m, n) shaped array with preferred_element_type element type. """ - return ragged_dot_p.bind(lhs, rhs, group_sizes, - precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type, - group_offset=group_offset) + return ragged_dot_general( + lhs, + rhs, + group_sizes, + ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS, + precision=canonicalize_precision(precision), + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) + + +@dataclasses.dataclass(frozen=True) +class RaggedDotDimensionNumbers(): + """Describes ragged, group, and dot dimensions for ragged dot general. + + Args: + dot_dimension_numbers: a tuple of tuples of sequences of ints of the form + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + rhs_batch_dims))`. + lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged + dimensions. + rhs_group_dimensions: a sequence of ints indicating the 'rhs' group + dimensions. + """ + dot_dimension_numbers: DotDimensionNumbers + lhs_ragged_dimensions: Sequence[int] + rhs_group_dimensions: Sequence[int] + + def __init__( + self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions + ): + super().__setattr__( + 'dot_dimension_numbers', + tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers), + ) + super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions)) + super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions)) + + +def _from_maybe_ragged( + dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers, +) -> DotDimensionNumbers: + return ( + dot_dimension_numbers.dot_dimension_numbers + if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers) + else dot_dimension_numbers + ) + + +# RaggedDotDimensionNumbers that specify the simple case (i.e., lax.ragged_dot.) +_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], +) + + +def ragged_dot_general( + lhs: Array, + rhs: Array, + group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, + group_offset: Array | None = None, +) -> Array: + """Ragged matrix multiplication. + + Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and + a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and + ``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally, + ``lhs`` is required to have one ragged dimension, and ``rhs`` may have at + most one group dimension. + + Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has + three modes, depending on the kind of the lhs ragged dimension: + 1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`. + Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``, + and `x...` are the lhs non-contracting dims outer to the ragged dim. + 2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`. 
+ Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and + ``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim. + 3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`. + Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and + ``rhs``, and `x...` are the lhs batch dims outer to the ragged dim. + If ``group_sizes`` is passed in with shape `[g]`, it is broadcasted according + to the rules above. + + Args: + lhs: an array + rhs: an array + group_sizes: an array with integer element type + ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to + specify the dot dimension numbers, lhs ragged dimension, and rhs group + dimension. + precision: Optional. Consistent with precision argument for + :func:`jax.lax.dot`. + preferred_element_type: Optional. Consistent with the preferred_element_type + argument for :func:`jax.lax.dot`. + group_offset: Optional. (1,) shaped array that indicates the group in + group_sizes to start computing from. If not specified, defaults to [0]. + + Results: + An array whose shape is the same as that produced by `dot_general`, with an + extra leading dimension of size `g` in the case where the lhs ragged + dimension is a contracting dimension. + """ + return ragged_dot_general_p.bind( + lhs, + rhs, + group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=canonicalize_precision(precision), + preferred_element_type=preferred_element_type, + group_offset=group_offset, + ) def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None @@ -4593,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, out_sharding): if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): msg = ("dot_general requires lhs dimension numbers to be nonnegative and " @@ -4654,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers) def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) batch_shape = tuple(lhs_shape[i] for i in lhs_batch) lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch))) lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch) - rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch))) - rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch) + rhs_group = () + if isinstance(dimension_numbers, RaggedDotDimensionNumbers): + rhs_group = tuple(dimension_numbers.rhs_group_dimensions) + rhs_contract_or_batch_or_group = tuple( + sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group) + ) + rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group) return batch_shape + lhs_tensored_shape + rhs_tensored_shape @@ -4723,7 +4841,7 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, - out_sharding): + out_sharding, name: str = 
'lax.dot_general'): if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError del dimension_numbers # unused @@ -4744,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, result_dtype = rhs.dtype else: if lhs.dtype != rhs.dtype: - raise TypeError( - f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}") + raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}') result_dtype = lhs.dtype has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)) return _maybe_upcast(result_dtype, preferred_element_type, @@ -4884,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): # explicitly present dimensions that this dot_general is zipping together. lbd, rbd = batch_dims assert lbd is not None or rbd is not None - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers) + is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers) def bump_dims(dims, b): return tuple(np.add(dims, np.greater_equal(dims, b))) @@ -4908,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): elif (type(rbd) is int and lbd is None): # The right vmapped dimension becomes an additional tensor dimension in the # batched dot_general. - rhs_tensor = [d for d in range(rhs_ndim) - if d not in rhs_batch and d not in rhs_contract] + rhs_tensor = list( + remaining( + range(rhs_ndim), + rhs_batch, + rhs_contract, + dimension_numbers.rhs_group_dimensions if is_ragged_dot else [], + ) + ) result_batch_dim = (lhs_ndim - len(lhs_contract) + int(sum(np.less(rhs_tensor, rbd)))) rhs_batch = bump_dims(rhs_batch, rbd) @@ -4919,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers): assert False new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) + if is_ragged_dot: + new_dimension_numbers = RaggedDotDimensionNumbers( + dot_dimension_numbers=new_dimension_numbers, + lhs_ragged_dimensions=bump_dims( + dimension_numbers.lhs_ragged_dimensions, lbd + ), + rhs_group_dimensions=bump_dims( + dimension_numbers.rhs_group_dimensions, rbd + ), + ) return new_dimension_numbers, result_batch_dim def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *, @@ -5010,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims): lbd, rbd = batch_dims return (lbd, rbd) -# DotDimensionNumbers used in the dot_general call for ragged_dot(). -_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( - ([2, 0], [1, 0]), - ([], []), -) -_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = ( - ([3, 1], [2, 1]), - ([0], [0]), -) ad.defbilinear(dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs) @@ -5186,58 +5311,181 @@ for platform in ["cpu", "tpu"]: platform=platform) -def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape: - if len(lhs.shape) == 3: - # Batched case - b, m, k = lhs.shape - b2, group_count, rk, n = rhs.shape - b3 = group_sizes.shape[0] - if b != b2: - raise TypeError( - f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and' - f' {b2}.' - ) - if b3 != b: - raise TypeError( - 'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got' - f' {b3} and {b}.' - ) - if k != rk: - raise TypeError( - f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and' - f' {rk}.' 
- ) - num_groups = group_sizes.shape[1] - if group_count != num_groups: - raise TypeError( - 'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got' - f' {group_count} and {num_groups}.' - ) - return (b, m, n) +class RaggedDotMode(enum.Enum): + RAGGED_NONCONTRACTING = 1 # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n] + RAGGED_CONTRACTING = 2 # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n] + RAGGED_BATCH = 3 # [b,m,k], [b,k,n], [g] -> [b,m,n] - m, k = lhs.shape - group_count, rk, n = rhs.shape - if k != rk: - raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.") - num_groups = group_sizes.shape[0] - if group_count != num_groups: - raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.") - return (m, n) -def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, - **_) -> np.dtype: +def _ragged_dot_mode_and_dim( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> tuple[RaggedDotMode, int]: + assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1 + lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0] + (lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers + lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch) + if lhs_ragged_dim in lhs_noncontracting: + mode = RaggedDotMode.RAGGED_NONCONTRACTING + elif lhs_ragged_dim in lhs_contracting: + mode = RaggedDotMode.RAGGED_CONTRACTING + elif lhs_ragged_dim in lhs_batch: + mode = RaggedDotMode.RAGGED_BATCH + else: + raise TypeError( + f'lhs_ragged_dim {lhs_ragged_dim} not found in ' + f'lhs_noncontracting {lhs_noncontracting}, ' + f'lhs_contracting {lhs_contracting}, or ' + f'lhs_batch {lhs_batch}.' + ) + return mode, lhs_ragged_dim + + +def _ragged_dot_mode( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> RaggedDotMode: + return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0] + + +def _is_ragged_contracting( + lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers +) -> bool: + return ( + _ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers) + == RaggedDotMode.RAGGED_CONTRACTING + ) + + +def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract): + batch, contract = map(list, (batch, contract)) + noncontract = remaining(range(rank), contract, batch) + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + return batch + noncontract[: noncontract.index(ragged_dim)] + case RaggedDotMode.RAGGED_CONTRACTING: + return batch + contract[: contract.index(ragged_dim)] + case RaggedDotMode.RAGGED_BATCH: + return batch[: batch.index(ragged_dim)] + + +def _ragged_dot_general_shape_rule( + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: DTypeLike | None, + **_, +): + def _check_in_range(dim, rank, dim_name, arg_name): + if dim < 0 or dim >= rank: + raise TypeError( + f'ragged_dot_general requires {dim_name} numbers to be nonnegative ' + f'and less than the number of axes of the {arg_name} value, ' + f'got {dim} for {arg_name} of rank {rank}.' + ) + + # Validate the lhs ragged dimension, and find out which mode we're in. + if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1: + raise TypeError( + 'ragged_dot_general expects exactly one lhs ragged dimension.' 
+ ) + lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0] + _check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs') + mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers) + + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + + # Validate the shape of group_sizes, if it is something other than [g]. + if group_sizes.ndim == 0: + raise TypeError('expected rank of group_sizes to be >=1.') + if group_sizes.ndim != 1: + # Construct the expected shape [b...,x...,g] of group_sizes. + prefix_dims = _ragged_dot_prefix_dims( + mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting + ) + expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims) + expected_gs_shape += (group_sizes.shape[-1],) + # TODO(pravnar): Permit other broadcastable shapes. + if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape): + raise TypeError( + 'expected group_sizes to have shape ' + f'{expected_gs_shape}, got {group_sizes.shape}.' + ) + num_groups = group_sizes.shape[-1] + + # Validate properties of the rhs group dimension(s). + rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions + match mode: + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + if len(rhs_group_dims) != 0: + raise TypeError( + 'ragged_dot_general requires zero group dimensions in the rhs ' + 'when lhs ragged dimension is contracting or batch.' + ) + case RaggedDotMode.RAGGED_NONCONTRACTING: + if len(rhs_group_dims) != 1: + raise TypeError( + 'ragged_dot_general requires exactly one rhs group dimension ' + 'when lhs ragged dimension is noncontracting.' + ) + rhs_group_dim = rhs_group_dims[0] + _check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs') + if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting: + raise TypeError( + 'ragged_dot_general requires rhs group dimension numbers to be ' + 'distinct from contracting and batch dimensions.' + ) + if rhs.shape[rhs_group_dim] != num_groups: + raise TypeError( + 'expected rhs group dimension size to be ' + f'{num_groups}, got {rhs.shape[rhs_group_dim]}.' + ) + + out_shape = _dot_general_shape_rule( + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + out_sharding=None, + ) + if mode == RaggedDotMode.RAGGED_CONTRACTING: + out_shape = (num_groups,) + out_shape + return out_shape + + +def _ragged_dot_general_dtype_rule( + lhs: Array, + rhs: Array, + group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, + precision, + preferred_element_type: DTypeLike | None, + **_, +) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): - raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") - # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. + raise TypeError( + 'ragged_dot_general requires that ' + 'group_sizes.dtype is subtype of np.integer.' + ) + # defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl. 
return _dot_general_dtype_rule( - lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, - precision=precision, preferred_element_type=preferred_element_type, - out_sharding=None) + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + out_sharding=None, + name='lax.ragged_dot_general', + ) -def _ragged_dot_jvp_rule( - primals, tangents, precision, preferred_element_type, group_offset +def _ragged_dot_general_jvp_rule( + primals, tangents, ragged_dot_dimension_numbers, + precision, preferred_element_type, group_offset ): # note - we could ostensibly just get this by passing on the # value to ragged_dot below, but, this feels cleaner. @@ -5247,20 +5495,22 @@ def _ragged_dot_jvp_rule( dx, dy, _ = tangents # no tan on the gs # primal - primal_out = ragged_dot( + primal_out = ragged_dot_general( x, y, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) # tangent dx_out = ( - ragged_dot( + ragged_dot_general( dx, y, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) @@ -5268,10 +5518,11 @@ def _ragged_dot_jvp_rule( else _zeros(primal_out) ) dy_out = ( - ragged_dot( + ragged_dot_general( x, dy, gs, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) @@ -5283,58 +5534,111 @@ def _ragged_dot_jvp_rule( return primal_out, tangent_out -def _ragged_to_dense(x, y, group_sizes): - from jax._src.lax import control_flow # avoid circular imports - shape = (y.shape[0], x.shape[0], x.shape[1]) - x = broadcast_in_dim(x, shape, [1, 2]) - iota = broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = control_flow.cumsum(group_sizes) - group_starts = concatenate( - [_zeros(group_sizes)[:1], group_ends[:-1]], - dimension=0, - ) - group_ends = broadcast_in_dim(group_ends, shape, (0,)) - group_starts = broadcast_in_dim(group_starts, shape, (0,)) - mask = bitwise_and(group_starts <= iota, iota < group_ends) - x = select(mask, x, _zeros(x)) - return x - - -def _ragged_dot_transpose_rule( - ct, *operands, precision, preferred_element_type, group_offset +def _ragged_dot_general_transpose_rule( + ct, + x, + y, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: DTypeLike | None, + group_offset: Array | None, ): - x, y, gs = operands if group_offset is not None: raise NotImplementedError('Unimplemented group_offset support.') - if ad.is_undefined_primal(y): - grad_x = None - else: - y_t = _matrix_transpose(y) - grad_x = ragged_dot( - ct, - y_t, - gs, - precision=precision, - preferred_element_type=preferred_element_type, - ) + (x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers + x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x) + y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y) + x_kept = remaining(range(x_ndim), x_contract, x_batch) + y_group = ragged_dot_dimension_numbers.rhs_group_dimensions + y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group) + mode, lhs_ragged_dim = _ragged_dot_mode_and_dim( + x_ndim, ragged_dot_dimension_numbers + ) - if ad.is_undefined_primal(x): - grad_y = None - else: - y = y.aval if ad.is_undefined_primal(y) else y - x_dense = _ragged_to_dense(x, y, group_sizes=gs) - ct_dense = _ragged_to_dense(ct, y, 
group_sizes=gs) - dimension_numbers = (([1], [1]), ([0], [0])) - grad_y = dot_general( - x_dense, - ct_dense, - dimension_numbers, - precision=precision, - preferred_element_type=preferred_element_type, - ) + unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError( + f'Unimplemented {fn_name} for ragged dot general in mode ' + f'{ragged_dot_mode.name}.' + ) - return grad_x, grad_y, None + # This is a hack to ensure we continue to emit the `_matrix_transpose` for the + # grad_x case. This isn't strictly necessary since we have dot_dim_nums. + # TODO(pravnar): Remove this once we no longer care to emit the transpose. + _is_basic_ragged_dot = ( + x_ndim == 2 + and y_ndim == 3 + and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS + ) + + def grad_x_dims(): + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) + dims = ( + ragged_dot_dimension_numbers + if _is_basic_ragged_dot + else RaggedDotDimensionNumbers( + dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)), + lhs_ragged_dimensions=[ + len(x_batch) + x_kept.index(lhs_ragged_dim) + ], + rhs_group_dimensions=y_group, + ) + ) + x_contract_sorted_by_y = list( + np.take(x_contract, np.argsort(y_contract)) + ) + unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + raise unimplemented('grad_x_dims', mode) + return dims, unsorted_axes + + def grad_y_dims(): + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept) + dims = RaggedDotDimensionNumbers( + dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)), + lhs_ragged_dimensions=[lhs_ragged_dim], + rhs_group_dimensions=[], + ) + y_contract_sorted_by_x = list( + np.take(y_contract, np.argsort(x_contract)) + ) + unsorted_axes = ( + list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept + ) + case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: + raise unimplemented('grad_y_dims', mode) + return dims, unsorted_axes + + def _ragged_dot_grad(lhs, rhs, dims_fn, aval): + dims, unsorted_axes = dims_fn() + ragged_dot_general_out = ragged_dot_general( + lhs, rhs, group_sizes, dims, precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset) + result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes))) + if result.dtype != aval.dtype: + result = _convert_element_type(result, aval.dtype, aval.weak_type) + return result + + x_bar = ( + None + if ad.is_undefined_primal(y) + else _ragged_dot_grad(ct, + _matrix_transpose(y) if _is_basic_ragged_dot else y, + grad_x_dims, + x.aval) + ) + y_bar = ( + None + if ad.is_undefined_primal(x) + else _ragged_dot_grad(x, ct, grad_y_dims, y.aval) + ) + return x_bar, y_bar, None def _ragged_dot_batch_unpack_args(batched_args): @@ -5349,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims): return (lbd, rbd) -def _ragged_dot_invoke_prim( +def _ragged_dot_general_invoke_prim( group_sizes, lhs, rhs, - new_dimension_numbers, + new_ragged_dot_dimension_numbers, precision, preferred_element_type, out_sharding, ): del out_sharding - return ragged_dot( + return ragged_dot_general( lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, ) -def _ragged_dot_batch_rule( +def _ragged_dot_general_batch_rule( axis_data, batched_args, batch_dims, *, + 
ragged_dot_dimension_numbers, precision, preferred_element_type: DTypeLike | None, **_, ): - invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2]) - - return _dot_batch_rule( + invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2]) + batched_out, result_batch_dim = _dot_batch_rule( _ragged_dot_batch_unpack_args, _ragged_dot_batch_unpack_dims, invoke, axis_data, batched_args, batch_dims, - dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + dimension_numbers=ragged_dot_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=None, ) + if _is_ragged_contracting(batched_args[0].ndim - 1, + ragged_dot_dimension_numbers): + result_batch_dim += 1 + return batched_out, result_batch_dim -ragged_dot_p = standard_primitive(_ragged_dot_shape_rule, - _ragged_dot_dtype_rule, 'ragged_dot') -ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p)) -ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule -ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule -batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule -batching.skippable_batchers[ragged_dot_p] = lambda _: () +ragged_dot_general_p = standard_primitive( + _ragged_dot_general_shape_rule, + _ragged_dot_general_dtype_rule, + 'ragged_dot_general', +) +ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule +ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule +batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule +batching.skippable_batchers[ragged_dot_general_p] = lambda _: () -def _ragged_dot_impl( + +def _ragged_dot_general_impl( lhs: Array, rhs: Array, group_sizes: Array, + ragged_dot_dimension_numbers: RaggedDotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, group_offset: Array | None = None, @@ -5412,24 +5725,100 @@ def _ragged_dot_impl( if group_offset is not None: raise NotImplementedError("Unimplemented group_offset support.") - if len(lhs.shape) == 3: - ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS - ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0)) - else: - ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS - ragged_to_dense = _ragged_to_dense + def ragged_to_dense(x: Array, gs: Array, *, dim: int): + from jax._src.lax import control_flow # avoid circular imports + assert gs.ndim == 1 + shape = gs.shape + x.shape + x = broadcast_in_dim(x, shape, list(range(1, len(shape)))) + iota = broadcasted_iota(gs.dtype, shape, dim+1) + group_ends = control_flow.cumsum(gs) + group_starts = concatenate( + [_zeros(gs)[:1], group_ends[:-1]], + dimension=0, + ) + group_ends = broadcast_in_dim(group_ends, shape, (0,)) + group_starts = broadcast_in_dim(group_starts, shape, (0,)) + mask = bitwise_and(group_starts <= iota, iota < group_ends) + x = select(mask, x, _zeros(x)) + return x - lhs = ragged_to_dense(lhs, rhs, group_sizes) + def batched_ragged_to_dense(dim, *x_in_axes: int): + if not x_in_axes: + return partial(ragged_to_dense, dim=dim) + x_axis, *rest = x_in_axes + decr = lambda d: d - 1 if d >= x_axis else d + return api.vmap( + batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]), + in_axes=(x_axis, 0), + ) - return dot_general( - lhs, - rhs, - dimension_numbers=ragged_dot_dims, + incr = lambda dims: [d + 1 for d in dims] + + # Expand the ragged `dim` of `x`, given its batching `axes`. + # The group axis from `gs` becomes the outermost axis of the result. 
+ # Some examples: + # x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k] + # x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k] + def expand(x, dim, gs, *axes): + expanded = batched_ragged_to_dense(dim, *axes)(x, gs) + unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes)) + return transpose(expanded, np.argsort(unsorted_dims)) + + mode, lhs_ragged_dim = _ragged_dot_mode_and_dim( + lhs.ndim, ragged_dot_dimension_numbers + ) + (l_contract, r_contract), (l_batch, r_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + l_prefix = _ragged_dot_prefix_dims( + mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract + ) + + _dot_general = partial( + dot_general, precision=precision, preferred_element_type=preferred_element_type, ) + # TODO(pravnar): Permit other broadcastable shapes. + if group_sizes.ndim == 1: + group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix]) -mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False)) + match mode: + case RaggedDotMode.RAGGED_NONCONTRACTING: + rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions + assert len(rhs_group_dims) == 1 + return _dot_general( + expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix), + rhs, + dimension_numbers=( + (incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]), + (incr(l_batch), r_batch), + ), + ) + case RaggedDotMode.RAGGED_CONTRACTING: + rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)] + r_prefix = _ragged_dot_prefix_dims( + mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract + ) + return _dot_general( + expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix), + expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix), + dimension_numbers=( + (incr(l_contract), incr(r_contract)), + ([0] + incr(l_batch), [0] + incr(r_batch)), + ), + ) + case RaggedDotMode.RAGGED_BATCH: + return _dot_general( + lhs, + rhs, + dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, + ) + + +mlir.register_lowering(ragged_dot_general_p, + mlir.lower_fun(_ragged_dot_general_impl, + multiple_results=False)) def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index af5ec987e..1809f211f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1541,6 +1541,7 @@ tf_not_yet_impl = [ "assert_consumed_value", "consume", "ragged_dot", + "ragged_dot_general", "cholesky_update", "symmetric_product", "from_edtype", diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index a26d15c14..4e376fb66 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -17,6 +17,7 @@ from jax._src.lax.lax import ( DotDimensionNumbers as DotDimensionNumbers, + RaggedDotDimensionNumbers as RaggedDotDimensionNumbers, Precision as Precision, PrecisionLike as PrecisionLike, DotAlgorithm as DotAlgorithm, @@ -158,6 +159,7 @@ from jax._src.lax.lax import ( pow as pow, pow_p as pow_p, ragged_dot as ragged_dot, + ragged_dot_general as ragged_dot_general, real as real, real_p as real_p, reciprocal as reciprocal, diff --git a/tests/lax_test.py b/tests/lax_test.py index 8497bf389..ad6b2a0bc 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4820,5 +4820,228 @@ class RaggedTest(jtu.JaxTestCase): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @parameterized.parameters( + { + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + "group_sizes_shape": group_sizes_shape, + 
"ragged_dot_dimension_numbers": ragged_dot_dimension_numbers, + "err_msg": err_msg, + } + for lhs_shape, rhs_shape, group_sizes_shape, ragged_dot_dimension_numbers, err_msg in [ + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0, 1], + rhs_group_dimensions=[0], + ), + "ragged_dot_general expects exactly one lhs ragged dimension", + ), + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[2], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires lhs ragged dimension numbers to " + "be nonnegative and less than the number of axes of the lhs" + ), + ), + ( + [11, 5], + [3, 5, 7], + [2, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + r"expected group_sizes to have shape \(3,\), got \(2, 3\)", + ), + ( + [19, 17, 11, 5], + [3, 19, 5, 7], + [19, 11, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([3], [2]), ([0], [1])), + lhs_ragged_dimensions=[2], + rhs_group_dimensions=[0], + ), + ( + r"expected group_sizes to have shape \(19, 17, 3\), " + r"got \(19, 11, 3\)" + ), + ), + ( + [19, 11, 17, 5], + [19, 17, 5, 7], + [19, 11, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2, 3], [1, 2]), ([0], [0])), + lhs_ragged_dimensions=[3], + rhs_group_dimensions=[], + ), + ( + r"expected group_sizes to have shape \(19, 17, 3\), " + r"got \(19, 11, 3\)" + ), + ), + ( + [17, 19, 11, 5], + [17, 19, 5, 7], + [19, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([3], [2]), ([0, 1], [0, 1])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[], + ), + ( + r"expected group_sizes to have shape \(17, 3\), " + r"got \(19, 3\)" + ), + ), + ( + [19, 11, 5], + [19, 5, 7], + [19, 3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2], [1]), ([0], [0])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires rhs group dimension numbers to " + "be distinct from contracting and batch dimensions" + ), + ), + ( + [11, 3], + [3, 3, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[1], + ), + ( + "ragged_dot_general requires rhs group dimension numbers to " + "be distinct from contracting and batch dimensions" + ), + ), + ( + [11, 5], + [3, 5, 7], + [2], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + "expected rhs group dimension size to be 2, got 3", + ), + ( + [2, 11, 5], + [3, 2, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([2], [2]), ([0], [1])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires zero group dimensions in " + "the rhs when lhs ragged dimension is contracting or batch" + ), + ), + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[0], + ), + ( + "ragged_dot_general requires zero group dimensions in " + "the rhs when lhs ragged dimension is contracting or batch" + ), + ), + ( + [11, 5], + [5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [0]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[], + ), + ( 
+ "ragged_dot_general requires exactly one rhs group dimension " + "when lhs ragged dimension is noncontracting" + ), + ), + ] + ) + def test_ragged_dot_general_shape_inference_failure( + self, lhs_shape, rhs_shape, group_sizes_shape, + ragged_dot_dimension_numbers, err_msg): + lhs = jnp.ones(lhs_shape, dtype=jnp.float32) + rhs = jnp.ones(rhs_shape, dtype=jnp.float32) + group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) + with self.assertRaisesRegex(TypeError, err_msg): + lax.ragged_dot_general(lhs, rhs, group_sizes, + ragged_dot_dimension_numbers) + + @parameterized.parameters( + { + "lhs_shape": lhs_shape, + "rhs_shape": rhs_shape, + "group_sizes_shape": group_sizes_shape, + "ragged_dnums": ragged_dnums, + "out_shape": out_shape, + } + for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [ + ( + [11, 5], + [3, 5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [1]), ([], [])), + lhs_ragged_dimensions=[0], + rhs_group_dimensions=[0], + ), + (11, 7), + ), + ( + [11, 5], + [5, 7], + [3], + lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(([1], [0]), ([], [])), + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[], + ), + (3, 11, 7), + ), + ] + ) + def test_ragged_dot_general_shape_inference_success( + self, lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape): + lhs = jnp.ones(lhs_shape, dtype=jnp.float32) + rhs = jnp.ones(rhs_shape, dtype=jnp.float32) + group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) + self.assertEqual( + lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, + out_shape, + ) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())