Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.

PiperOrigin-RevId: 735471245
Praveen Narayanan 2025-03-10 12:24:38 -07:00 committed by jax authors
parent 64beebbfb0
commit b6d4fe5387
4 changed files with 761 additions and 146 deletions


@ -16,6 +16,7 @@ from __future__ import annotations
import builtins
from collections.abc import Callable, Sequence
import dataclasses
import enum
import functools
from functools import partial
@ -2227,10 +2228,122 @@ def ragged_dot(
Results:
(m, n) shaped array with preferred_element_type element type.
"""
return ragged_dot_p.bind(lhs, rhs, group_sizes,
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type,
group_offset=group_offset)
return ragged_dot_general(
lhs,
rhs,
group_sizes,
ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS,
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type,
group_offset=group_offset,
)
@dataclasses.dataclass(frozen=True)
class RaggedDotDimensionNumbers():
"""Describes ragged, group, and dot dimensions for ragged dot general.
Args:
dot_dimension_numbers: a tuple of tuples of sequences of ints of the form
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
rhs_batch_dims))`.
lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged
dimensions.
rhs_group_dimensions: a sequence of ints indicating the 'rhs' group
dimensions.
"""
dot_dimension_numbers: DotDimensionNumbers
lhs_ragged_dimensions: Sequence[int]
rhs_group_dimensions: Sequence[int]
def __init__(
self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions
):
super().__setattr__(
'dot_dimension_numbers',
tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers),
)
super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions))
super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions))
def _from_maybe_ragged(
dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers,
) -> DotDimensionNumbers:
return (
dot_dimension_numbers.dot_dimension_numbers
if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers)
else dot_dimension_numbers
)
# RaggedDotDimensionNumbers that specify the simple case (i.e., lax.ragged_dot).
_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[0],
)
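# In shape terms, these dimension numbers describe
# lhs: [m, k], rhs: [g, k, n], group_sizes: [g] -> [m, n]:
# the lhs ragged dimension is the noncontracting `m`, and the rhs carries a
# leading group dimension of size `g`.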
def ragged_dot_general(
lhs: Array,
rhs: Array,
group_sizes: Array,
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
group_offset: Array | None = None,
) -> Array:
"""Ragged matrix multiplication.
Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and
a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and
``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally,
``lhs`` is required to have one ragged dimension, and ``rhs`` may have at
most one group dimension.
Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has
three modes, depending on the kind of the lhs ragged dimension:
1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`.
Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``,
and `x...` are the lhs non-contracting dims outer to the ragged dim.
2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`.
Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and
``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim.
3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`.
Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and
``rhs``, and `x...` are the lhs batch dims outer to the ragged dim.
If ``group_sizes`` is passed in with shape `[g]`, it is broadcast according
to the rules above.
Args:
lhs: an array
rhs: an array
group_sizes: an array with integer element type
ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to
specify the dot dimension numbers, lhs ragged dimension, and rhs group
dimension.
precision: Optional. Consistent with precision argument for
:func:`jax.lax.dot`.
preferred_element_type: Optional. Consistent with the preferred_element_type
argument for :func:`jax.lax.dot`.
group_offset: Optional. (1,) shaped array that indicates the group in
group_sizes to start computing from. If not specified, defaults to [0].
Results:
An array whose shape is the same as that produced by `dot_general`, with an
extra leading dimension of size `g` in the case where the lhs ragged
dimension is a contracting dimension.
"""
return ragged_dot_general_p.bind(
lhs,
rhs,
group_sizes,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type,
group_offset=group_offset,
)
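# Illustrative usage (a sketch; shapes follow the shape-inference tests added
# in this change, with jnp standing for jax.numpy):
#
#   Mode 1 (ragged noncontracting), equivalent to lax.ragged_dot:
#     lhs = jnp.zeros((11, 5)); rhs = jnp.zeros((3, 5, 7))
#     group_sizes = jnp.array([4, 4, 3], dtype=jnp.int32)
#     out = ragged_dot_general(lhs, rhs, group_sizes,
#                              _BASIC_RAGGED_DOT_DIMENSION_NUMBERS)
#     # out.shape == (11, 7)
#
#   Mode 2 (ragged contracting): the shared `k` axis is split into groups and
#   the result gains a leading group dimension:
#     dnums = RaggedDotDimensionNumbers(
#         dot_dimension_numbers=(([1], [0]), ([], [])),
#         lhs_ragged_dimensions=[1],
#         rhs_group_dimensions=[])
#     out = ragged_dot_general(jnp.zeros((11, 5)), jnp.zeros((5, 7)),
#                              jnp.array([2, 2, 1], dtype=jnp.int32), dnums)
#     # out.shape == (3, 11, 7)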
def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
@ -4593,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
out_sharding):
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
raise NotImplementedError
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):
msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
@ -4654,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
batch_shape = tuple(lhs_shape[i] for i in lhs_batch)
lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch)
rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
rhs_group = ()
if isinstance(dimension_numbers, RaggedDotDimensionNumbers):
rhs_group = tuple(dimension_numbers.rhs_group_dimensions)
rhs_contract_or_batch_or_group = tuple(
sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group)
)
rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group)
return batch_shape + lhs_tensored_shape + rhs_tensored_shape
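# Example for the basic ragged_dot case: lhs (11, 5), rhs (3, 5, 7) with
# rhs_group_dimensions=(0,). The group dimension is dropped from the rhs along
# with the contracting dimension, so rhs_tensored_shape == (7,) and the output
# shape is (11, 7). The leading `g` for the ragged-contracting mode is
# prepended later, in _ragged_dot_general_shape_rule.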
@ -4723,7 +4841,7 @@ def tuple_delete(tup, idx):
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
out_sharding):
out_sharding, name: str = 'lax.dot_general'):
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
raise NotImplementedError
del dimension_numbers # unused
@ -4744,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
result_dtype = rhs.dtype
else:
if lhs.dtype != rhs.dtype:
raise TypeError(
f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}')
result_dtype = lhs.dtype
has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
return _maybe_upcast(result_dtype, preferred_element_type,
@ -4884,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
# explicitly present dimensions that this dot_general is zipping together.
lbd, rbd = batch_dims
assert lbd is not None or rbd is not None
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers)
def bump_dims(dims, b):
return tuple(np.add(dims, np.greater_equal(dims, b)))
@ -4908,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
elif (type(rbd) is int and lbd is None):
# The right vmapped dimension becomes an additional tensor dimension in the
# batched dot_general.
rhs_tensor = [d for d in range(rhs_ndim)
if d not in rhs_batch and d not in rhs_contract]
rhs_tensor = list(
remaining(
range(rhs_ndim),
rhs_batch,
rhs_contract,
dimension_numbers.rhs_group_dimensions if is_ragged_dot else [],
)
)
result_batch_dim = (lhs_ndim - len(lhs_contract) +
int(sum(np.less(rhs_tensor, rbd))))
rhs_batch = bump_dims(rhs_batch, rbd)
@ -4919,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
assert False
new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
if is_ragged_dot:
new_dimension_numbers = RaggedDotDimensionNumbers(
dot_dimension_numbers=new_dimension_numbers,
lhs_ragged_dimensions=bump_dims(
dimension_numbers.lhs_ragged_dimensions, lbd
),
rhs_group_dimensions=bump_dims(
dimension_numbers.rhs_group_dimensions, rbd
),
)
return new_dimension_numbers, result_batch_dim
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
@ -5010,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims):
lbd, rbd = batch_dims
return (lbd, rbd)
# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
([2, 0], [1, 0]),
([], []),
)
_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
([3, 1], [2, 1]),
([0], [0]),
)
ad.defbilinear(dot_general_p,
_dot_general_transpose_lhs, _dot_general_transpose_rhs)
@ -5186,58 +5311,181 @@ for platform in ["cpu", "tpu"]:
platform=platform)
def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
if len(lhs.shape) == 3:
# Batched case
b, m, k = lhs.shape
b2, group_count, rk, n = rhs.shape
b3 = group_sizes.shape[0]
if b != b2:
raise TypeError(
f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
f' {b2}.'
)
if b3 != b:
raise TypeError(
'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
f' {b3} and {b}.'
)
if k != rk:
raise TypeError(
f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and'
f' {rk}.'
)
num_groups = group_sizes.shape[1]
if group_count != num_groups:
raise TypeError(
'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
f' {group_count} and {num_groups}.'
)
return (b, m, n)
class RaggedDotMode(enum.Enum):
RAGGED_NONCONTRACTING = 1 # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]
RAGGED_CONTRACTING = 2 # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n]
RAGGED_BATCH = 3 # [b,m,k], [b,k,n], [g] -> [b,m,n]
m, k = lhs.shape
group_count, rk, n = rhs.shape
if k != rk:
raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
num_groups = group_sizes.shape[0]
if group_count != num_groups:
raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
return (m, n)
def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
precision, preferred_element_type: DTypeLike | None,
**_) -> np.dtype:
def _ragged_dot_mode_and_dim(
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> tuple[RaggedDotMode, int]:
assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1
lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
(lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers
lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch)
if lhs_ragged_dim in lhs_noncontracting:
mode = RaggedDotMode.RAGGED_NONCONTRACTING
elif lhs_ragged_dim in lhs_contracting:
mode = RaggedDotMode.RAGGED_CONTRACTING
elif lhs_ragged_dim in lhs_batch:
mode = RaggedDotMode.RAGGED_BATCH
else:
raise TypeError(
f'lhs_ragged_dim {lhs_ragged_dim} not found in '
f'lhs_noncontracting {lhs_noncontracting}, '
f'lhs_contracting {lhs_contracting}, or '
f'lhs_batch {lhs_batch}.'
)
return mode, lhs_ragged_dim
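# For example, with _BASIC_RAGGED_DOT_DIMENSION_NUMBERS and a rank-2 lhs:
# lhs_contracting=(1,) and lhs_batch=() give lhs_noncontracting=[0], so ragged
# dimension 0 is noncontracting and the mode is RAGGED_NONCONTRACTING.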
def _ragged_dot_mode(
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> RaggedDotMode:
return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0]
def _is_ragged_contracting(
lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> bool:
return (
_ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers)
== RaggedDotMode.RAGGED_CONTRACTING
)
def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract):
batch, contract = map(list, (batch, contract))
noncontract = remaining(range(rank), contract, batch)
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
return batch + noncontract[: noncontract.index(ragged_dim)]
case RaggedDotMode.RAGGED_CONTRACTING:
return batch + contract[: contract.index(ragged_dim)]
case RaggedDotMode.RAGGED_BATCH:
return batch[: batch.index(ragged_dim)]
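# Example (matching one of the shape-inference tests added in this change):
# for a lhs of shape (19, 17, 11, 5) with lhs_batch=(0,), lhs_contracting=(3,),
# and ragged dimension 2, the noncontracting dims are [1, 2], so the prefix
# dims are [0, 1] and group_sizes is expected to have shape (19, 17, g).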
def _ragged_dot_general_shape_rule(
lhs,
rhs,
group_sizes,
*,
ragged_dot_dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
**_,
):
def _check_in_range(dim, rank, dim_name, arg_name):
if dim < 0 or dim >= rank:
raise TypeError(
f'ragged_dot_general requires {dim_name} numbers to be nonnegative '
f'and less than the number of axes of the {arg_name} value, '
f'got {dim} for {arg_name} of rank {rank}.'
)
# Validate the lhs ragged dimension, and find out which mode we're in.
if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1:
raise TypeError(
'ragged_dot_general expects exactly one lhs ragged dimension.'
)
lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
_check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs')
mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers)
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = (
ragged_dot_dimension_numbers.dot_dimension_numbers
)
# Validate the shape of group_sizes, if it is something other than [g].
if group_sizes.ndim == 0:
raise TypeError('expected rank of group_sizes to be >=1.')
if group_sizes.ndim != 1:
# Construct the expected shape [b...,x...,g] of group_sizes.
prefix_dims = _ragged_dot_prefix_dims(
mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting
)
expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims)
expected_gs_shape += (group_sizes.shape[-1],)
# TODO(pravnar): Permit other broadcastable shapes.
if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape):
raise TypeError(
'expected group_sizes to have shape '
f'{expected_gs_shape}, got {group_sizes.shape}.'
)
num_groups = group_sizes.shape[-1]
# Validate properties of the rhs group dimension(s).
rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
match mode:
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
if len(rhs_group_dims) != 0:
raise TypeError(
'ragged_dot_general requires zero group dimensions in the rhs '
'when lhs ragged dimension is contracting or batch.'
)
case RaggedDotMode.RAGGED_NONCONTRACTING:
if len(rhs_group_dims) != 1:
raise TypeError(
'ragged_dot_general requires exactly one rhs group dimension '
'when lhs ragged dimension is noncontracting.'
)
rhs_group_dim = rhs_group_dims[0]
_check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs')
if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting:
raise TypeError(
'ragged_dot_general requires rhs group dimension numbers to be '
'distinct from contracting and batch dimensions.'
)
if rhs.shape[rhs_group_dim] != num_groups:
raise TypeError(
'expected rhs group dimension size to be '
f'{num_groups}, got {rhs.shape[rhs_group_dim]}.'
)
out_shape = _dot_general_shape_rule(
lhs,
rhs,
dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
out_sharding=None,
)
if mode == RaggedDotMode.RAGGED_CONTRACTING:
out_shape = (num_groups,) + out_shape
return out_shape
def _ragged_dot_general_dtype_rule(
lhs: Array,
rhs: Array,
group_sizes: Array,
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
precision,
preferred_element_type: DTypeLike | None,
**_,
) -> np.dtype:
if not dtypes.issubdtype(group_sizes.dtype, np.integer):
raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
# defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
raise TypeError(
'ragged_dot_general requires that '
'group_sizes.dtype is subtype of np.integer.'
)
# defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl.
return _dot_general_dtype_rule(
lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision, preferred_element_type=preferred_element_type,
out_sharding=None)
lhs,
rhs,
dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
out_sharding=None,
name='lax.ragged_dot_general',
)
def _ragged_dot_jvp_rule(
primals, tangents, precision, preferred_element_type, group_offset
def _ragged_dot_general_jvp_rule(
primals, tangents, ragged_dot_dimension_numbers,
precision, preferred_element_type, group_offset
):
# note - we could ostensibly just get this by passing the value on to
# ragged_dot below, but this feels cleaner.
@ -5247,20 +5495,22 @@ def _ragged_dot_jvp_rule(
dx, dy, _ = tangents # no tan on the gs
# primal
primal_out = ragged_dot(
primal_out = ragged_dot_general(
x,
y,
gs,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
# tangent
dx_out = (
ragged_dot(
ragged_dot_general(
dx,
y,
gs,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
@ -5268,10 +5518,11 @@ def _ragged_dot_jvp_rule(
else _zeros(primal_out)
)
dy_out = (
ragged_dot(
ragged_dot_general(
x,
dy,
gs,
ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
@ -5283,58 +5534,111 @@ def _ragged_dot_jvp_rule(
return primal_out, tangent_out
def _ragged_to_dense(x, y, group_sizes):
from jax._src.lax import control_flow # avoid circular imports
shape = (y.shape[0], x.shape[0], x.shape[1])
x = broadcast_in_dim(x, shape, [1, 2])
iota = broadcasted_iota(group_sizes.dtype, shape, 1)
group_ends = control_flow.cumsum(group_sizes)
group_starts = concatenate(
[_zeros(group_sizes)[:1], group_ends[:-1]],
dimension=0,
)
group_ends = broadcast_in_dim(group_ends, shape, (0,))
group_starts = broadcast_in_dim(group_starts, shape, (0,))
mask = bitwise_and(group_starts <= iota, iota < group_ends)
x = select(mask, x, _zeros(x))
return x
def _ragged_dot_transpose_rule(
ct, *operands, precision, preferred_element_type, group_offset
def _ragged_dot_general_transpose_rule(
ct,
x,
y,
group_sizes,
*,
ragged_dot_dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
group_offset: Array | None,
):
x, y, gs = operands
if group_offset is not None:
raise NotImplementedError('Unimplemented group_offset support.')
if ad.is_undefined_primal(y):
grad_x = None
else:
y_t = _matrix_transpose(y)
grad_x = ragged_dot(
ct,
y_t,
gs,
precision=precision,
preferred_element_type=preferred_element_type,
)
(x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers
x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x)
y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y)
x_kept = remaining(range(x_ndim), x_contract, x_batch)
y_group = ragged_dot_dimension_numbers.rhs_group_dimensions
y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group)
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
x_ndim, ragged_dot_dimension_numbers
)
if ad.is_undefined_primal(x):
grad_y = None
else:
y = y.aval if ad.is_undefined_primal(y) else y
x_dense = _ragged_to_dense(x, y, group_sizes=gs)
ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
dimension_numbers = (([1], [1]), ([0], [0]))
grad_y = dot_general(
x_dense,
ct_dense,
dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError(
f'Unimplemented {fn_name} for ragged dot general in mode '
f'{ragged_dot_mode.name}.'
)
return grad_x, grad_y, None
# This is a hack to ensure we continue to emit the `_matrix_transpose` for the
# grad_x case. This isn't strictly necessary since we have dot_dim_nums.
# TODO(pravnar): Remove this once we no longer care to emit the transpose.
_is_basic_ragged_dot = (
x_ndim == 2
and y_ndim == 3
and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS
)
def grad_x_dims():
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
dims = (
ragged_dot_dimension_numbers
if _is_basic_ragged_dot
else RaggedDotDimensionNumbers(
dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)),
lhs_ragged_dimensions=[
len(x_batch) + x_kept.index(lhs_ragged_dim)
],
rhs_group_dimensions=y_group,
)
)
x_contract_sorted_by_y = list(
np.take(x_contract, np.argsort(y_contract))
)
unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
raise unimplemented('grad_x_dims', mode)
return dims, unsorted_axes
def grad_y_dims():
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept)
dims = RaggedDotDimensionNumbers(
dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)),
lhs_ragged_dimensions=[lhs_ragged_dim],
rhs_group_dimensions=[],
)
y_contract_sorted_by_x = list(
np.take(y_contract, np.argsort(x_contract))
)
unsorted_axes = (
list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept
)
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
raise unimplemented('grad_y_dims', mode)
return dims, unsorted_axes
def _ragged_dot_grad(lhs, rhs, dims_fn, aval):
dims, unsorted_axes = dims_fn()
ragged_dot_general_out = ragged_dot_general(
lhs, rhs, group_sizes, dims, precision=precision,
preferred_element_type=preferred_element_type,
group_offset=group_offset)
result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes)))
if result.dtype != aval.dtype:
result = _convert_element_type(result, aval.dtype, aval.weak_type)
return result
x_bar = (
None
if ad.is_undefined_primal(y)
else _ragged_dot_grad(ct,
_matrix_transpose(y) if _is_basic_ragged_dot else y,
grad_x_dims,
x.aval)
)
y_bar = (
None
if ad.is_undefined_primal(x)
else _ragged_dot_grad(x, ct, grad_y_dims, y.aval)
)
return x_bar, y_bar, None
def _ragged_dot_batch_unpack_args(batched_args):
@ -5349,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims):
return (lbd, rbd)
def _ragged_dot_invoke_prim(
def _ragged_dot_general_invoke_prim(
group_sizes,
lhs,
rhs,
new_dimension_numbers,
new_ragged_dot_dimension_numbers,
precision,
preferred_element_type,
out_sharding,
):
del out_sharding
return ragged_dot(
return ragged_dot_general(
lhs,
rhs,
group_sizes,
ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
def _ragged_dot_batch_rule(
def _ragged_dot_general_batch_rule(
axis_data,
batched_args,
batch_dims,
*,
ragged_dot_dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
**_,
):
invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
return _dot_batch_rule(
invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2])
batched_out, result_batch_dim = _dot_batch_rule(
_ragged_dot_batch_unpack_args,
_ragged_dot_batch_unpack_dims,
invoke,
axis_data,
batched_args,
batch_dims,
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
dimension_numbers=ragged_dot_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
out_sharding=None,
)
if _is_ragged_contracting(batched_args[0].ndim - 1,
ragged_dot_dimension_numbers):
result_batch_dim += 1
return batched_out, result_batch_dim
ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
_ragged_dot_dtype_rule, 'ragged_dot')
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
batching.skippable_batchers[ragged_dot_p] = lambda _: ()
ragged_dot_general_p = standard_primitive(
_ragged_dot_general_shape_rule,
_ragged_dot_general_dtype_rule,
'ragged_dot_general',
)
ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule
ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule
batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule
batching.skippable_batchers[ragged_dot_general_p] = lambda _: ()
def _ragged_dot_impl(
def _ragged_dot_general_impl(
lhs: Array,
rhs: Array,
group_sizes: Array,
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
group_offset: Array | None = None,
@ -5412,24 +5725,100 @@ def _ragged_dot_impl(
if group_offset is not None:
raise NotImplementedError("Unimplemented group_offset support.")
if len(lhs.shape) == 3:
ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
else:
ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
ragged_to_dense = _ragged_to_dense
def ragged_to_dense(x: Array, gs: Array, *, dim: int):
from jax._src.lax import control_flow # avoid circular imports
assert gs.ndim == 1
shape = gs.shape + x.shape
x = broadcast_in_dim(x, shape, list(range(1, len(shape))))
iota = broadcasted_iota(gs.dtype, shape, dim+1)
group_ends = control_flow.cumsum(gs)
group_starts = concatenate(
[_zeros(gs)[:1], group_ends[:-1]],
dimension=0,
)
group_ends = broadcast_in_dim(group_ends, shape, (0,))
group_starts = broadcast_in_dim(group_starts, shape, (0,))
mask = bitwise_and(group_starts <= iota, iota < group_ends)
x = select(mask, x, _zeros(x))
return x
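# For example, with x of shape (3, k), gs = [2, 1], and dim=0, the result has
# shape (2, 3, k): group 0 keeps rows 0 and 1 of x and zeros out row 2, while
# group 1 keeps row 2 and zeros out rows 0 and 1.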
lhs = ragged_to_dense(lhs, rhs, group_sizes)
def batched_ragged_to_dense(dim, *x_in_axes: int):
if not x_in_axes:
return partial(ragged_to_dense, dim=dim)
x_axis, *rest = x_in_axes
decr = lambda d: d - 1 if d >= x_axis else d
return api.vmap(
batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]),
in_axes=(x_axis, 0),
)
return dot_general(
lhs,
rhs,
dimension_numbers=ragged_dot_dims,
incr = lambda dims: [d + 1 for d in dims]
# Expand the ragged `dim` of `x`, given its batching `axes`.
# The group axis from `gs` becomes the outermost axis of the result.
# Some examples:
# x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k]
# x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k]
def expand(x, dim, gs, *axes):
expanded = batched_ragged_to_dense(dim, *axes)(x, gs)
unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes))
return transpose(expanded, np.argsort(unsorted_dims))
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
lhs.ndim, ragged_dot_dimension_numbers
)
(l_contract, r_contract), (l_batch, r_batch) = (
ragged_dot_dimension_numbers.dot_dimension_numbers
)
l_prefix = _ragged_dot_prefix_dims(
mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract
)
_dot_general = partial(
dot_general,
precision=precision,
preferred_element_type=preferred_element_type,
)
# TODO(pravnar): Permit other broadcastable shapes.
if group_sizes.ndim == 1:
group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix])
mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False))
match mode:
case RaggedDotMode.RAGGED_NONCONTRACTING:
rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
assert len(rhs_group_dims) == 1
return _dot_general(
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
rhs,
dimension_numbers=(
(incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]),
(incr(l_batch), r_batch),
),
)
case RaggedDotMode.RAGGED_CONTRACTING:
rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)]
r_prefix = _ragged_dot_prefix_dims(
mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract
)
return _dot_general(
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix),
dimension_numbers=(
(incr(l_contract), incr(r_contract)),
([0] + incr(l_batch), [0] + incr(r_batch)),
),
)
case RaggedDotMode.RAGGED_BATCH:
return _dot_general(
lhs,
rhs,
dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
)
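# Sanity check: in the basic lax.ragged_dot case (RAGGED_NONCONTRACTING with
# lhs [m, k] and rhs [g, k, n]), the expanded lhs has shape [g, m, k] and the
# dot_general above contracts dims (([2, 0], [1, 0]), ([], [])), which are
# exactly the dimension numbers the previous ragged_dot lowering used.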
mlir.register_lowering(ragged_dot_general_p,
mlir.lower_fun(_ragged_dot_general_impl,
multiple_results=False))
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions,


@ -1541,6 +1541,7 @@ tf_not_yet_impl = [
"assert_consumed_value",
"consume",
"ragged_dot",
"ragged_dot_general",
"cholesky_update",
"symmetric_product",
"from_edtype",


@ -17,6 +17,7 @@
from jax._src.lax.lax import (
DotDimensionNumbers as DotDimensionNumbers,
RaggedDotDimensionNumbers as RaggedDotDimensionNumbers,
Precision as Precision,
PrecisionLike as PrecisionLike,
DotAlgorithm as DotAlgorithm,
@ -158,6 +159,7 @@ from jax._src.lax.lax import (
pow as pow,
pow_p as pow_p,
ragged_dot as ragged_dot,
ragged_dot_general as ragged_dot_general,
real as real,
real_p as real_p,
reciprocal as reciprocal,


@ -4820,5 +4820,228 @@ class RaggedTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(
lax_reference.ragged_dot, lax.ragged_dot, args_maker)
@parameterized.parameters(
{
"lhs_shape": lhs_shape,
"rhs_shape": rhs_shape,
"group_sizes_shape": group_sizes_shape,
"ragged_dot_dimension_numbers": ragged_dot_dimension_numbers,
"err_msg": err_msg,
}
for lhs_shape, rhs_shape, group_sizes_shape, ragged_dot_dimension_numbers, err_msg in [
(
[11, 5],
[3, 5, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[0, 1],
rhs_group_dimensions=[0],
),
"ragged_dot_general expects exactly one lhs ragged dimension",
),
(
[11, 5],
[3, 5, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[2],
rhs_group_dimensions=[0],
),
(
"ragged_dot_general requires lhs ragged dimension numbers to "
"be nonnegative and less than the number of axes of the lhs"
),
),
(
[11, 5],
[3, 5, 7],
[2, 3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[0],
),
r"expected group_sizes to have shape \(3,\), got \(2, 3\)",
),
(
[19, 17, 11, 5],
[3, 19, 5, 7],
[19, 11, 3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([3], [2]), ([0], [1])),
lhs_ragged_dimensions=[2],
rhs_group_dimensions=[0],
),
(
r"expected group_sizes to have shape \(19, 17, 3\), "
r"got \(19, 11, 3\)"
),
),
(
[19, 11, 17, 5],
[19, 17, 5, 7],
[19, 11, 3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([2, 3], [1, 2]), ([0], [0])),
lhs_ragged_dimensions=[3],
rhs_group_dimensions=[],
),
(
r"expected group_sizes to have shape \(19, 17, 3\), "
r"got \(19, 11, 3\)"
),
),
(
[17, 19, 11, 5],
[17, 19, 5, 7],
[19, 3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([3], [2]), ([0, 1], [0, 1])),
lhs_ragged_dimensions=[1],
rhs_group_dimensions=[],
),
(
r"expected group_sizes to have shape \(17, 3\), "
r"got \(19, 3\)"
),
),
(
[19, 11, 5],
[19, 5, 7],
[19, 3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([2], [1]), ([0], [0])),
lhs_ragged_dimensions=[1],
rhs_group_dimensions=[0],
),
(
"ragged_dot_general requires rhs group dimension numbers to "
"be distinct from contracting and batch dimensions"
),
),
(
[11, 3],
[3, 3, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[1],
),
(
"ragged_dot_general requires rhs group dimension numbers to "
"be distinct from contracting and batch dimensions"
),
),
(
[11, 5],
[3, 5, 7],
[2],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[0],
),
"expected rhs group dimension size to be 2, got 3",
),
(
[2, 11, 5],
[3, 2, 5, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([2], [2]), ([0], [1])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[0],
),
(
"ragged_dot_general requires zero group dimensions in "
"the rhs when lhs ragged dimension is contracting or batch"
),
),
(
[11, 5],
[3, 5, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[1],
rhs_group_dimensions=[0],
),
(
"ragged_dot_general requires zero group dimensions in "
"the rhs when lhs ragged dimension is contracting or batch"
),
),
(
[11, 5],
[5, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [0]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[],
),
(
"ragged_dot_general requires exactly one rhs group dimension "
"when lhs ragged dimension is noncontracting"
),
),
]
)
def test_ragged_dot_general_shape_inference_failure(
self, lhs_shape, rhs_shape, group_sizes_shape,
ragged_dot_dimension_numbers, err_msg):
lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
with self.assertRaisesRegex(TypeError, err_msg):
lax.ragged_dot_general(lhs, rhs, group_sizes,
ragged_dot_dimension_numbers)
@parameterized.parameters(
{
"lhs_shape": lhs_shape,
"rhs_shape": rhs_shape,
"group_sizes_shape": group_sizes_shape,
"ragged_dnums": ragged_dnums,
"out_shape": out_shape,
}
for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [
(
[11, 5],
[3, 5, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [1]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[0],
),
(11, 7),
),
(
[11, 5],
[5, 7],
[3],
lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([1], [0]), ([], [])),
lhs_ragged_dimensions=[1],
rhs_group_dimensions=[],
),
(3, 11, 7),
),
]
)
def test_ragged_dot_general_shape_inference_success(
self, lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape):
lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
self.assertEqual(
lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape,
out_shape,
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())