Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.

PiperOrigin-RevId: 735471245

parent 64beebbfb0
commit b6d4fe5387
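What the change looks like from the user side, as a minimal sketch (shapes and values are illustrative and not part of the diff): lax.ragged_dot is now the special case of the new lax.ragged_dot_general.

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((11, 5), jnp.float32)           # 11 rows, split into 3 ragged groups
rhs = jnp.ones((3, 5, 7), jnp.float32)         # one (5, 7) matrix per group
group_sizes = jnp.array([4, 4, 3], jnp.int32)  # sums to lhs.shape[0]

out = lax.ragged_dot(lhs, rhs, group_sizes)    # shape (11, 7)
# Equivalent call through the new general API:
same = lax.ragged_dot_general(
    lhs, rhs, group_sizes,
    ragged_dot_dimension_numbers=lax.RaggedDotDimensionNumbers(
        dot_dimension_numbers=(([1], [1]), ([], [])),
        lhs_ragged_dimensions=[0],
        rhs_group_dimensions=[0],
    ),
)                                              # also (11, 7)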
jax/_src/lax/lax.py

@@ -16,6 +16,7 @@ from __future__ import annotations
import builtins
from collections.abc import Callable, Sequence
import dataclasses
+import enum
import functools
from functools import partial
@@ -2227,10 +2228,122 @@ def ragged_dot(
  Results:
    (m, n) shaped array with preferred_element_type element type.
  """
-  return ragged_dot_p.bind(lhs, rhs, group_sizes,
-                           precision=canonicalize_precision(precision),
-                           preferred_element_type=preferred_element_type,
-                           group_offset=group_offset)
+  return ragged_dot_general(
+      lhs,
+      rhs,
+      group_sizes,
+      ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS,
+      precision=canonicalize_precision(precision),
+      preferred_element_type=preferred_element_type,
+      group_offset=group_offset,
+  )
+
+
+@dataclasses.dataclass(frozen=True)
+class RaggedDotDimensionNumbers():
+  """Describes ragged, group, and dot dimensions for ragged dot general.
+
+  Args:
+    dot_dimension_numbers: a tuple of tuples of sequences of ints of the form
+      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
+      rhs_batch_dims))`.
+    lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged
+      dimensions.
+    rhs_group_dimensions: a sequence of ints indicating the 'rhs' group
+      dimensions.
+  """
+  dot_dimension_numbers: DotDimensionNumbers
+  lhs_ragged_dimensions: Sequence[int]
+  rhs_group_dimensions: Sequence[int]
+
+  def __init__(
+      self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions
+  ):
+    super().__setattr__(
+        'dot_dimension_numbers',
+        tuple(tuple(map(tuple, t)) for t in dot_dimension_numbers),
+    )
+    super().__setattr__('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions))
+    super().__setattr__('rhs_group_dimensions', tuple(rhs_group_dimensions))
+
+
+def _from_maybe_ragged(
+    dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers,
+) -> DotDimensionNumbers:
+  return (
+      dot_dimension_numbers.dot_dimension_numbers
+      if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers)
+      else dot_dimension_numbers
+  )
+
+
+# RaggedDotDimensionNumbers that specify the simple case (i.e., lax.ragged_dot.)
+_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers(
+    dot_dimension_numbers=(([1], [1]), ([], [])),
+    lhs_ragged_dimensions=[0],
+    rhs_group_dimensions=[0],
+)
+
+
+def ragged_dot_general(
+    lhs: Array,
+    rhs: Array,
+    group_sizes: Array,
+    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
+    precision: PrecisionLike = None,
+    preferred_element_type: DTypeLike | None = None,
+    group_offset: Array | None = None,
+) -> Array:
+  """Ragged matrix multiplication.
+
+  Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and
+  a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and
+  ``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally,
+  ``lhs`` is required to have one ragged dimension, and ``rhs`` may have at
+  most one group dimension.
+
+  Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has
+  three modes, depending on the kind of the lhs ragged dimension:
+  1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`.
+     Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``,
+     and `x...` are the lhs non-contracting dims outer to the ragged dim.
+  2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`.
+     Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and
+     ``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim.
+  3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`.
+     Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and
+     ``rhs``, and `x...` are the lhs batch dims outer to the ragged dim.
+  If ``group_sizes`` is passed in with shape `[g]`, it is broadcast according
+  to the rules above.
+
+  Args:
+    lhs: an array
+    rhs: an array
+    group_sizes: an array with integer element type
+    ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to
+      specify the dot dimension numbers, lhs ragged dimension, and rhs group
+      dimension.
+    precision: Optional. Consistent with the precision argument for
+      :func:`jax.lax.dot`.
+    preferred_element_type: Optional. Consistent with the
+      preferred_element_type argument for :func:`jax.lax.dot`.
+    group_offset: Optional. (1,) shaped array that indicates the group in
+      group_sizes to start computing from. If not specified, defaults to [0].
+
+  Results:
+    An array whose shape is the same as that produced by `dot_general`, with an
+    extra leading dimension of size `g` in the case where the lhs ragged
+    dimension is a contracting dimension.
+  """
+  return ragged_dot_general_p.bind(
+      lhs,
+      rhs,
+      group_sizes,
+      ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
+      precision=canonicalize_precision(precision),
+      preferred_element_type=preferred_element_type,
+      group_offset=group_offset,
+  )


def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
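The contracting-dim case (mode 2 in the docstring above) is the genuinely new behavior; a hedged sketch, with shapes borrowed from the tests further down:

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((11, 5), jnp.float32)
rhs = jnp.ones((5, 7), jnp.float32)
group_sizes = jnp.array([2, 2, 1], jnp.int32)  # splits k=5 into g=3 groups

dnums = lax.RaggedDotDimensionNumbers(
    dot_dimension_numbers=(([1], [0]), ([], [])),
    lhs_ragged_dimensions=[1],   # the contracting dim k is ragged
    rhs_group_dimensions=[],     # no rhs group dim in this mode
)
out = lax.ragged_dot_general(lhs, rhs, group_sizes, dnums)
print(out.shape)  # (3, 11, 7): one partial product per group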
@@ -4593,7 +4706,7 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                            out_sharding):
  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
    raise NotImplementedError
-  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
  if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
             for d in (lhs_contracting, lhs_batch)):
    msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
@@ -4654,12 +4767,17 @@ def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
  return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)

def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
-  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)
  batch_shape = tuple(lhs_shape[i] for i in lhs_batch)
  lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
  lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch)
-  rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
-  rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
+  rhs_group = ()
+  if isinstance(dimension_numbers, RaggedDotDimensionNumbers):
+    rhs_group = tuple(dimension_numbers.rhs_group_dimensions)
+  rhs_contract_or_batch_or_group = tuple(
+      sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group)
+  )
+  rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch_or_group)
  return batch_shape + lhs_tensored_shape + rhs_tensored_shape
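The output-shape logic above can be summarized in plain Python; a small sketch under the same conventions (batch dims first, then surviving lhs dims, then surviving rhs dims, with rhs group dims deleted):

def shape_of_ragged_dot(lhs_shape, rhs_shape, contract, batch, rhs_group):
  (lc, rc), (lb, rb) = contract, batch
  out_batch = [lhs_shape[i] for i in lb]
  lhs_keep = [d for i, d in enumerate(lhs_shape) if i not in set(lc) | set(lb)]
  drop = set(rc) | set(rb) | set(rhs_group)
  rhs_keep = [d for i, d in enumerate(rhs_shape) if i not in drop]
  return tuple(out_batch + lhs_keep + rhs_keep)

# lax.ragged_dot's case: group dim 0 of rhs is dropped from the output.
assert shape_of_ragged_dot((11, 5), (3, 5, 7), ([1], [1]), ([], []), [0]) == (11, 7)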
@@ -4723,7 +4841,7 @@ def tuple_delete(tup, idx):

def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: DTypeLike | None,
-                            out_sharding):
+                            out_sharding, name: str = 'lax.dot_general'):
  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
    raise NotImplementedError
  del dimension_numbers  # unused
@@ -4744,8 +4862,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
    result_dtype = rhs.dtype
  else:
    if lhs.dtype != rhs.dtype:
-      raise TypeError(
-          f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
+      raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}')
    result_dtype = lhs.dtype
  has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
  return _maybe_upcast(result_dtype, preferred_element_type,
@@ -4884,8 +5001,9 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  # explicitly present dimensions that this dot_general is zipping together.
  lbd, rbd = batch_dims
  assert lbd is not None or rbd is not None
-  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
+  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)

+  is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers)
  def bump_dims(dims, b):
    return tuple(np.add(dims, np.greater_equal(dims, b)))
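bump_dims simply shifts recorded dimension indices past a newly inserted vmap axis; a tiny self-contained check of that behavior:

import numpy as np

def bump_dims(dims, b):
  return tuple(np.add(dims, np.greater_equal(dims, b)))

# Inserting a new axis at position 1: dim 0 stays, dim 2 becomes dim 3.
assert bump_dims((0, 2), 1) == (0, 3)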
@@ -4908,8 +5026,14 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  elif (type(rbd) is int and lbd is None):
    # The right vmapped dimension becomes an additional tensor dimension in the
    # batched dot_general.
-    rhs_tensor = [d for d in range(rhs_ndim)
-                  if d not in rhs_batch and d not in rhs_contract]
+    rhs_tensor = list(
+        remaining(
+            range(rhs_ndim),
+            rhs_batch,
+            rhs_contract,
+            dimension_numbers.rhs_group_dimensions if is_ragged_dot else [],
+        )
+    )
    result_batch_dim = (lhs_ndim - len(lhs_contract) +
                        int(sum(np.less(rhs_tensor, rbd))))
    rhs_batch = bump_dims(rhs_batch, rbd)
@@ -4919,6 +5043,16 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
    assert False

  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
+  if is_ragged_dot:
+    new_dimension_numbers = RaggedDotDimensionNumbers(
+        dot_dimension_numbers=new_dimension_numbers,
+        lhs_ragged_dimensions=bump_dims(
+            dimension_numbers.lhs_ragged_dimensions, lbd
+        ),
+        rhs_group_dimensions=bump_dims(
+            dimension_numbers.rhs_group_dimensions, rbd
+        ),
+    )
  return new_dimension_numbers, result_batch_dim

def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
@@ -5010,15 +5144,6 @@ def _dot_general_batch_unpack_dims(batch_dims):
  lbd, rbd = batch_dims
  return (lbd, rbd)

-# DotDimensionNumbers used in the dot_general call for ragged_dot().
-_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
-    ([2, 0], [1, 0]),
-    ([], []),
-)
-_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
-    ([3, 1], [2, 1]),
-    ([0], [0]),
-)

ad.defbilinear(dot_general_p,
               _dot_general_transpose_lhs, _dot_general_transpose_rhs)
@@ -5186,58 +5311,181 @@ for platform in ["cpu", "tpu"]:
                                  platform=platform)


-def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
-  if len(lhs.shape) == 3:
-    # Batched case
-    b, m, k = lhs.shape
-    b2, group_count, rk, n = rhs.shape
-    b3 = group_sizes.shape[0]
-    if b != b2:
-      raise TypeError(
-          f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
-          f' {b2}.'
-      )
-    if b3 != b:
-      raise TypeError(
-          'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
-          f' {b3} and {b}.'
-      )
-    if k != rk:
-      raise TypeError(
-          f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and'
-          f' {rk}.'
-      )
-    num_groups = group_sizes.shape[1]
-    if group_count != num_groups:
-      raise TypeError(
-          'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
-          f' {group_count} and {num_groups}.'
-      )
-    return (b, m, n)
-
-  m, k = lhs.shape
-  group_count, rk, n = rhs.shape
-  if k != rk:
-    raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
-  num_groups = group_sizes.shape[0]
-  if group_count != num_groups:
-    raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
-  return (m, n)
-
-def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None,
-                           **_) -> np.dtype:
+class RaggedDotMode(enum.Enum):
+  RAGGED_NONCONTRACTING = 1  # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]
+  RAGGED_CONTRACTING = 2  # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n]
+  RAGGED_BATCH = 3  # [b,m,k], [b,k,n], [g] -> [b,m,n]
+
+
+def _ragged_dot_mode_and_dim(
+    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
+) -> tuple[RaggedDotMode, int]:
+  assert len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) == 1
+  lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
+  (lhs_contracting, _), (lhs_batch, _) = ragged_dot_dimension_numbers.dot_dimension_numbers
+  lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch)
+  if lhs_ragged_dim in lhs_noncontracting:
+    mode = RaggedDotMode.RAGGED_NONCONTRACTING
+  elif lhs_ragged_dim in lhs_contracting:
+    mode = RaggedDotMode.RAGGED_CONTRACTING
+  elif lhs_ragged_dim in lhs_batch:
+    mode = RaggedDotMode.RAGGED_BATCH
+  else:
+    raise TypeError(
+        f'lhs_ragged_dim {lhs_ragged_dim} not found in '
+        f'lhs_noncontracting {lhs_noncontracting}, '
+        f'lhs_contracting {lhs_contracting}, or '
+        f'lhs_batch {lhs_batch}.'
+    )
+  return mode, lhs_ragged_dim
+
+
+def _ragged_dot_mode(
+    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
+) -> RaggedDotMode:
+  return _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)[0]
+
+
+def _is_ragged_contracting(
+    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
+) -> bool:
+  return (
+      _ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers)
+      == RaggedDotMode.RAGGED_CONTRACTING
+  )
+
+
+def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract):
+  batch, contract = map(list, (batch, contract))
+  noncontract = remaining(range(rank), contract, batch)
+  match mode:
+    case RaggedDotMode.RAGGED_NONCONTRACTING:
+      return batch + noncontract[: noncontract.index(ragged_dim)]
+    case RaggedDotMode.RAGGED_CONTRACTING:
+      return batch + contract[: contract.index(ragged_dim)]
+    case RaggedDotMode.RAGGED_BATCH:
+      return batch[: batch.index(ragged_dim)]
+
+
+def _ragged_dot_general_shape_rule(
+    lhs,
+    rhs,
+    group_sizes,
+    *,
+    ragged_dot_dimension_numbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    **_,
+):
+  def _check_in_range(dim, rank, dim_name, arg_name):
+    if dim < 0 or dim >= rank:
+      raise TypeError(
+          f'ragged_dot_general requires {dim_name} numbers to be nonnegative '
+          f'and less than the number of axes of the {arg_name} value, '
+          f'got {dim} for {arg_name} of rank {rank}.'
+      )
+
+  # Validate the lhs ragged dimension, and find out which mode we're in.
+  if len(ragged_dot_dimension_numbers.lhs_ragged_dimensions) != 1:
+    raise TypeError(
+        'ragged_dot_general expects exactly one lhs ragged dimension.'
+    )
+  lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
+  _check_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs')
+  mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers)
+
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = (
+      ragged_dot_dimension_numbers.dot_dimension_numbers
+  )
+
+  # Validate the shape of group_sizes, if it is something other than [g].
+  if group_sizes.ndim == 0:
+    raise TypeError('expected rank of group_sizes to be >=1.')
+  if group_sizes.ndim != 1:
+    # Construct the expected shape [b...,x...,g] of group_sizes.
+    prefix_dims = _ragged_dot_prefix_dims(
+        mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting
+    )
+    expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims)
+    expected_gs_shape += (group_sizes.shape[-1],)
+    # TODO(pravnar): Permit other broadcastable shapes.
+    if not core.definitely_equal_shape(group_sizes.shape, expected_gs_shape):
+      raise TypeError(
+          'expected group_sizes to have shape '
+          f'{expected_gs_shape}, got {group_sizes.shape}.'
+      )
+  num_groups = group_sizes.shape[-1]
+
+  # Validate properties of the rhs group dimension(s).
+  rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
+  match mode:
+    case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
+      if len(rhs_group_dims) != 0:
+        raise TypeError(
+            'ragged_dot_general requires zero group dimensions in the rhs '
+            'when lhs ragged dimension is contracting or batch.'
+        )
+    case RaggedDotMode.RAGGED_NONCONTRACTING:
+      if len(rhs_group_dims) != 1:
+        raise TypeError(
+            'ragged_dot_general requires exactly one rhs group dimension '
+            'when lhs ragged dimension is noncontracting.'
+        )
+      rhs_group_dim = rhs_group_dims[0]
+      _check_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs')
+      if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting:
+        raise TypeError(
+            'ragged_dot_general requires rhs group dimension numbers to be '
+            'distinct from contracting and batch dimensions.'
+        )
+      if rhs.shape[rhs_group_dim] != num_groups:
+        raise TypeError(
+            'expected rhs group dimension size to be '
+            f'{num_groups}, got {rhs.shape[rhs_group_dim]}.'
+        )
+
+  out_shape = _dot_general_shape_rule(
+      lhs,
+      rhs,
+      dimension_numbers=ragged_dot_dimension_numbers,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      out_sharding=None,
+  )
+  if mode == RaggedDotMode.RAGGED_CONTRACTING:
+    out_shape = (num_groups,) + out_shape
+  return out_shape
+
+
+def _ragged_dot_general_dtype_rule(
+    lhs: Array,
+    rhs: Array,
+    group_sizes: Array,
+    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    **_,
+) -> np.dtype:
  if not dtypes.issubdtype(group_sizes.dtype, np.integer):
-    raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
-  # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
+    raise TypeError(
+        'ragged_dot_general requires that '
+        'group_sizes.dtype is subtype of np.integer.'
+    )
+  # defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl.
  return _dot_general_dtype_rule(
-      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-      precision=precision, preferred_element_type=preferred_element_type,
-      out_sharding=None)
+      lhs,
+      rhs,
+      dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      out_sharding=None,
+      name='lax.ragged_dot_general',
+  )


-def _ragged_dot_jvp_rule(
-    primals, tangents, precision, preferred_element_type, group_offset
+def _ragged_dot_general_jvp_rule(
+    primals, tangents, ragged_dot_dimension_numbers,
+    precision, preferred_element_type, group_offset
):
  # note - we could ostensibly just get this by passing on the
  # value to ragged_dot below, but, this feels cleaner.
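A hedged sketch of the group_sizes validation above, mirroring one of the new test cases further down:

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((11, 5), jnp.float32)
rhs = jnp.ones((3, 5, 7), jnp.float32)
bad_gs = jnp.ones((2, 3), jnp.int32)  # expected shape (3,)

try:
  lax.ragged_dot_general(
      lhs, rhs, bad_gs,
      lax.RaggedDotDimensionNumbers(
          dot_dimension_numbers=(([1], [1]), ([], [])),
          lhs_ragged_dimensions=[0],
          rhs_group_dimensions=[0],
      ),
  )
except TypeError as e:
  print(e)  # expected group_sizes to have shape (3,), got (2, 3).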
@@ -5247,20 +5495,22 @@ def _ragged_dot_jvp_rule(
  dx, dy, _ = tangents  # no tan on the gs

  # primal
-  primal_out = ragged_dot(
+  primal_out = ragged_dot_general(
      x,
      y,
      gs,
+      ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )

  # tangent
  dx_out = (
-      ragged_dot(
+      ragged_dot_general(
          dx,
          y,
          gs,
+          ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
          precision=precision,
          preferred_element_type=preferred_element_type,
      )
@@ -5268,10 +5518,11 @@ def _ragged_dot_jvp_rule(
      else _zeros(primal_out)
  )
  dy_out = (
-      ragged_dot(
+      ragged_dot_general(
          x,
          dy,
          gs,
+          ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
          precision=precision,
          preferred_element_type=preferred_element_type,
      )
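The JVP rule just forwards tangents through two more ragged dots; a quick hedged check with jax.jvp (illustrative shapes):

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((11, 5), jnp.float32)
rhs = jnp.ones((3, 5, 7), jnp.float32)
gs = jnp.array([4, 4, 3], jnp.int32)

f = lambda x, y: lax.ragged_dot(x, y, gs)
primal, tangent = jax.jvp(f, (lhs, rhs), (jnp.ones_like(lhs), jnp.zeros_like(rhs)))
print(primal.shape, tangent.shape)  # (11, 7) (11, 7)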
@@ -5283,58 +5534,111 @@ def _ragged_dot_jvp_rule(
  return primal_out, tangent_out


-def _ragged_to_dense(x, y, group_sizes):
-  from jax._src.lax import control_flow  # avoid circular imports
-  shape = (y.shape[0], x.shape[0], x.shape[1])
-  x = broadcast_in_dim(x, shape, [1, 2])
-  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
-  group_ends = control_flow.cumsum(group_sizes)
-  group_starts = concatenate(
-      [_zeros(group_sizes)[:1], group_ends[:-1]],
-      dimension=0,
-  )
-  group_ends = broadcast_in_dim(group_ends, shape, (0,))
-  group_starts = broadcast_in_dim(group_starts, shape, (0,))
-  mask = bitwise_and(group_starts <= iota, iota < group_ends)
-  x = select(mask, x, _zeros(x))
-  return x
-
-
-def _ragged_dot_transpose_rule(
-    ct, *operands, precision, preferred_element_type, group_offset
+def _ragged_dot_general_transpose_rule(
+    ct,
+    x,
+    y,
+    group_sizes,
+    *,
+    ragged_dot_dimension_numbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    group_offset: Array | None,
):
-  x, y, gs = operands
  if group_offset is not None:
    raise NotImplementedError('Unimplemented group_offset support.')

-  if ad.is_undefined_primal(y):
-    grad_x = None
-  else:
-    y_t = _matrix_transpose(y)
-    grad_x = ragged_dot(
-        ct,
-        y_t,
-        gs,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-    )
+  (x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers
+  x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x)
+  y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y)
+  x_kept = remaining(range(x_ndim), x_contract, x_batch)
+  y_group = ragged_dot_dimension_numbers.rhs_group_dimensions
+  y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group)
+  mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
+      x_ndim, ragged_dot_dimension_numbers
+  )

-  if ad.is_undefined_primal(x):
-    grad_y = None
-  else:
-    y = y.aval if ad.is_undefined_primal(y) else y
-    x_dense = _ragged_to_dense(x, y, group_sizes=gs)
-    ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
-    dimension_numbers = (([1], [1]), ([0], [0]))
-    grad_y = dot_general(
-        x_dense,
-        ct_dense,
-        dimension_numbers,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-    )
+  unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError(
+      f'Unimplemented {fn_name} for ragged dot general in mode '
+      f'{ragged_dot_mode.name}.'
+  )

-  return grad_x, grad_y, None
+  # This is a hack to ensure we continue to emit the `_matrix_transpose` for the
+  # grad_x case. This isn't strictly necessary since we have dot_dim_nums.
+  # TODO(pravnar): Remove this once we no longer care to emit the transpose.
+  _is_basic_ragged_dot = (
+      x_ndim == 2
+      and y_ndim == 3
+      and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS
+  )
+
+  def grad_x_dims():
+    match mode:
+      case RaggedDotMode.RAGGED_NONCONTRACTING:
+        ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
+        dims = (
+            ragged_dot_dimension_numbers
+            if _is_basic_ragged_dot
+            else RaggedDotDimensionNumbers(
+                dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)),
+                lhs_ragged_dimensions=[
+                    len(x_batch) + x_kept.index(lhs_ragged_dim)
+                ],
+                rhs_group_dimensions=y_group,
+            )
+        )
+        x_contract_sorted_by_y = list(
+            np.take(x_contract, np.argsort(y_contract))
+        )
+        unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
+      case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
+        raise unimplemented('grad_x_dims', mode)
+    return dims, unsorted_axes
+
+  def grad_y_dims():
+    match mode:
+      case RaggedDotMode.RAGGED_NONCONTRACTING:
+        ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept)
+        dims = RaggedDotDimensionNumbers(
+            dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)),
+            lhs_ragged_dimensions=[lhs_ragged_dim],
+            rhs_group_dimensions=[],
+        )
+        y_contract_sorted_by_x = list(
+            np.take(y_contract, np.argsort(x_contract))
+        )
+        unsorted_axes = (
+            list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept
+        )
+      case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
+        raise unimplemented('grad_y_dims', mode)
+    return dims, unsorted_axes
+
+  def _ragged_dot_grad(lhs, rhs, dims_fn, aval):
+    dims, unsorted_axes = dims_fn()
+    ragged_dot_general_out = ragged_dot_general(
+        lhs, rhs, group_sizes, dims, precision=precision,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset)
+    result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes)))
+    if result.dtype != aval.dtype:
+      result = _convert_element_type(result, aval.dtype, aval.weak_type)
+    return result
+
+  x_bar = (
+      None
+      if ad.is_undefined_primal(y)
+      else _ragged_dot_grad(ct,
+                            _matrix_transpose(y) if _is_basic_ragged_dot else y,
+                            grad_x_dims,
+                            x.aval)
+  )
+  y_bar = (
+      None
+      if ad.is_undefined_primal(x)
+      else _ragged_dot_grad(x, ct, grad_y_dims, y.aval)
+  )
+  return x_bar, y_bar, None


def _ragged_dot_batch_unpack_args(batched_args):
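In the noncontracting mode, both cotangents are themselves ragged dots; a hedged end-to-end check with jax.grad (illustrative shapes):

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((11, 5), jnp.float32)
rhs = jnp.ones((3, 5, 7), jnp.float32)
gs = jnp.array([4, 4, 3], jnp.int32)

loss = lambda x, y: lax.ragged_dot(x, y, gs).sum()
dx, dy = jax.grad(loss, argnums=(0, 1))(lhs, rhs)
print(dx.shape, dy.shape)  # (11, 5) (3, 5, 7)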
@@ -5349,62 +5653,71 @@ def _ragged_dot_batch_unpack_dims(batch_dims):
  return (lbd, rbd)


-def _ragged_dot_invoke_prim(
+def _ragged_dot_general_invoke_prim(
    group_sizes,
    lhs,
    rhs,
-    new_dimension_numbers,
+    new_ragged_dot_dimension_numbers,
    precision,
    preferred_element_type,
    out_sharding,
):
  del out_sharding
-  return ragged_dot(
+  return ragged_dot_general(
      lhs,
      rhs,
      group_sizes,
+      ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )


-def _ragged_dot_batch_rule(
+def _ragged_dot_general_batch_rule(
    axis_data,
    batched_args,
    batch_dims,
    *,
+    ragged_dot_dimension_numbers,
    precision,
    preferred_element_type: DTypeLike | None,
    **_,
):
-  invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
-
-  return _dot_batch_rule(
+  invoke = partial(_ragged_dot_general_invoke_prim, batched_args[2])
+  batched_out, result_batch_dim = _dot_batch_rule(
      _ragged_dot_batch_unpack_args,
      _ragged_dot_batch_unpack_dims,
      invoke,
      axis_data,
      batched_args,
      batch_dims,
-      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      dimension_numbers=ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      out_sharding=None,
  )
+  if _is_ragged_contracting(batched_args[0].ndim - 1,
+                            ragged_dot_dimension_numbers):
+    result_batch_dim += 1
+  return batched_out, result_batch_dim


-ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
-                                  _ragged_dot_dtype_rule, 'ragged_dot')
-ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
-ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
-ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
-batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
-batching.skippable_batchers[ragged_dot_p] = lambda _: ()
+ragged_dot_general_p = standard_primitive(
+    _ragged_dot_general_shape_rule,
+    _ragged_dot_general_dtype_rule,
+    'ragged_dot_general',
+)
+ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule
+ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule
+batching.fancy_primitive_batchers[ragged_dot_general_p] = _ragged_dot_general_batch_rule
+batching.skippable_batchers[ragged_dot_general_p] = lambda _: ()


-def _ragged_dot_impl(
+def _ragged_dot_general_impl(
    lhs: Array,
    rhs: Array,
    group_sizes: Array,
+    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    group_offset: Array | None = None,
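With the batching rule above, ragged_dot_general composes with jax.vmap; a hedged sketch mapping over a leading axis of lhs with per-example group sizes:

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((4, 11, 5), jnp.float32)
rhs = jnp.ones((3, 5, 7), jnp.float32)
gs = jnp.tile(jnp.array([4, 4, 3], jnp.int32), (4, 1))  # shape (4, 3)

out = jax.vmap(lambda x, g: lax.ragged_dot(x, rhs, g))(lhs, gs)
print(out.shape)  # (4, 11, 7)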
@@ -5412,24 +5725,100 @@ def _ragged_dot_impl(
  if group_offset is not None:
    raise NotImplementedError("Unimplemented group_offset support.")

-  if len(lhs.shape) == 3:
-    ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
-    ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
-  else:
-    ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
-    ragged_to_dense = _ragged_to_dense
-
-  lhs = ragged_to_dense(lhs, rhs, group_sizes)
-
-  return dot_general(
-    lhs,
-    rhs,
-    dimension_numbers=ragged_dot_dims,
-    precision=precision,
-    preferred_element_type=preferred_element_type,
-  )
+  def ragged_to_dense(x: Array, gs: Array, *, dim: int):
+    from jax._src.lax import control_flow  # avoid circular imports
+    assert gs.ndim == 1
+    shape = gs.shape + x.shape
+    x = broadcast_in_dim(x, shape, list(range(1, len(shape))))
+    iota = broadcasted_iota(gs.dtype, shape, dim+1)
+    group_ends = control_flow.cumsum(gs)
+    group_starts = concatenate(
+        [_zeros(gs)[:1], group_ends[:-1]],
+        dimension=0,
+    )
+    group_ends = broadcast_in_dim(group_ends, shape, (0,))
+    group_starts = broadcast_in_dim(group_starts, shape, (0,))
+    mask = bitwise_and(group_starts <= iota, iota < group_ends)
+    x = select(mask, x, _zeros(x))
+    return x
+
+  def batched_ragged_to_dense(dim, *x_in_axes: int):
+    if not x_in_axes:
+      return partial(ragged_to_dense, dim=dim)
+    x_axis, *rest = x_in_axes
+    decr = lambda d: d - 1 if d >= x_axis else d
+    return api.vmap(
+        batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]),
+        in_axes=(x_axis, 0),
+    )
+
+  incr = lambda dims: [d + 1 for d in dims]
+
+  # Expand the ragged `dim` of `x`, given its batching `axes`.
+  # The group axis from `gs` becomes the outermost axis of the result.
+  # Some examples:
+  #   x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k]
+  #   x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k]
+  def expand(x, dim, gs, *axes):
+    expanded = batched_ragged_to_dense(dim, *axes)(x, gs)
+    unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes))
+    return transpose(expanded, np.argsort(unsorted_dims))
+
+  mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
+      lhs.ndim, ragged_dot_dimension_numbers
+  )
+  (l_contract, r_contract), (l_batch, r_batch) = (
+      ragged_dot_dimension_numbers.dot_dimension_numbers
+  )
+  l_prefix = _ragged_dot_prefix_dims(
+      mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract
+  )
+
+  _dot_general = partial(
+      dot_general,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+  )
+  # TODO(pravnar): Permit other broadcastable shapes.
+  if group_sizes.ndim == 1:
+    group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix])
+
+  match mode:
+    case RaggedDotMode.RAGGED_NONCONTRACTING:
+      rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
+      assert len(rhs_group_dims) == 1
+      return _dot_general(
+          expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
+          rhs,
+          dimension_numbers=(
+              (incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]),
+              (incr(l_batch), r_batch),
+          ),
+      )
+    case RaggedDotMode.RAGGED_CONTRACTING:
+      rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)]
+      r_prefix = _ragged_dot_prefix_dims(
+          mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract
+      )
+      return _dot_general(
+          expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
+          expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix),
+          dimension_numbers=(
+              (incr(l_contract), incr(r_contract)),
+              ([0] + incr(l_batch), [0] + incr(r_batch)),
+          ),
+      )
+    case RaggedDotMode.RAGGED_BATCH:
+      return _dot_general(
+          lhs,
+          rhs,
+          dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
+      )


-mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False))
+mlir.register_lowering(ragged_dot_general_p,
+                       mlir.lower_fun(_ragged_dot_general_impl,
+                                      multiple_results=False))


def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions,
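The lowering's ragged_to_dense makes the ragged dimension dense by masking: each group keeps only the rows inside its [start, end) window (cumsum bounds compared against an iota) and zeros the rest. A standalone jax.numpy sketch of the same idea (names are illustrative, not the diff's helper):

import jax.numpy as jnp

def ragged_to_dense_sketch(x, group_sizes):
  ends = jnp.cumsum(group_sizes)
  starts = jnp.concatenate([jnp.zeros((1,), group_sizes.dtype), ends[:-1]])
  rows = jnp.arange(x.shape[0])
  mask = (starts[:, None] <= rows) & (rows < ends[:, None])   # [g, m]
  return jnp.where(mask[:, :, None], x[None], 0)              # [g, m, k]

x = jnp.arange(12.0).reshape(6, 2)
print(ragged_to_dense_sketch(x, jnp.array([2, 1, 3])).shape)  # (3, 6, 2)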
jax/experimental/jax2tf/jax2tf.py

@@ -1541,6 +1541,7 @@ tf_not_yet_impl = [
    "assert_consumed_value",
    "consume",
    "ragged_dot",
+    "ragged_dot_general",
    "cholesky_update",
    "symmetric_product",
    "from_edtype",
jax/lax/__init__.py

@@ -17,6 +17,7 @@

from jax._src.lax.lax import (
  DotDimensionNumbers as DotDimensionNumbers,
+  RaggedDotDimensionNumbers as RaggedDotDimensionNumbers,
  Precision as Precision,
  PrecisionLike as PrecisionLike,
  DotAlgorithm as DotAlgorithm,
@@ -158,6 +159,7 @@ from jax._src.lax.lax import (
  pow as pow,
  pow_p as pow_p,
  ragged_dot as ragged_dot,
+  ragged_dot_general as ragged_dot_general,
  real as real,
  real_p as real_p,
  reciprocal as reciprocal,
tests/lax_test.py

@@ -4820,5 +4820,228 @@ class RaggedTest(jtu.JaxTestCase):
    self._CheckAgainstNumpy(
        lax_reference.ragged_dot, lax.ragged_dot, args_maker)

+  @parameterized.parameters(
+      {
+          "lhs_shape": lhs_shape,
+          "rhs_shape": rhs_shape,
+          "group_sizes_shape": group_sizes_shape,
+          "ragged_dot_dimension_numbers": ragged_dot_dimension_numbers,
+          "err_msg": err_msg,
+      }
+      for lhs_shape, rhs_shape, group_sizes_shape, ragged_dot_dimension_numbers, err_msg in [
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0, 1],
+                  rhs_group_dimensions=[0],
+              ),
+              "ragged_dot_general expects exactly one lhs ragged dimension",
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[2],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires lhs ragged dimension numbers to "
+                  "be nonnegative and less than the number of axes of the lhs"
+              ),
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [2, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              r"expected group_sizes to have shape \(3,\), got \(2, 3\)",
+          ),
+          (
+              [19, 17, 11, 5],
+              [3, 19, 5, 7],
+              [19, 11, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([3], [2]), ([0], [1])),
+                  lhs_ragged_dimensions=[2],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  r"expected group_sizes to have shape \(19, 17, 3\), "
+                  r"got \(19, 11, 3\)"
+              ),
+          ),
+          (
+              [19, 11, 17, 5],
+              [19, 17, 5, 7],
+              [19, 11, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([2, 3], [1, 2]), ([0], [0])),
+                  lhs_ragged_dimensions=[3],
+                  rhs_group_dimensions=[],
+              ),
+              (
+                  r"expected group_sizes to have shape \(19, 17, 3\), "
+                  r"got \(19, 11, 3\)"
+              ),
+          ),
+          (
+              [17, 19, 11, 5],
+              [17, 19, 5, 7],
+              [19, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([3], [2]), ([0, 1], [0, 1])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[],
+              ),
+              (
+                  r"expected group_sizes to have shape \(17, 3\), "
+                  r"got \(19, 3\)"
+              ),
+          ),
+          (
+              [19, 11, 5],
+              [19, 5, 7],
+              [19, 3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([2], [1]), ([0], [0])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires rhs group dimension numbers to "
+                  "be distinct from contracting and batch dimensions"
+              ),
+          ),
+          (
+              [11, 3],
+              [3, 3, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[1],
+              ),
+              (
+                  "ragged_dot_general requires rhs group dimension numbers to "
+                  "be distinct from contracting and batch dimensions"
+              ),
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [2],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              "expected rhs group dimension size to be 2, got 3",
+          ),
+          (
+              [2, 11, 5],
+              [3, 2, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([2], [2]), ([0], [1])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires zero group dimensions in "
+                  "the rhs when lhs ragged dimension is contracting or batch"
+              ),
+          ),
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[0],
+              ),
+              (
+                  "ragged_dot_general requires zero group dimensions in "
+                  "the rhs when lhs ragged dimension is contracting or batch"
+              ),
+          ),
+          (
+              [11, 5],
+              [5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [0]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[],
+              ),
+              (
+                  "ragged_dot_general requires exactly one rhs group dimension "
+                  "when lhs ragged dimension is noncontracting"
+              ),
+          ),
+      ]
+  )
+  def test_ragged_dot_general_shape_inference_failure(
+      self, lhs_shape, rhs_shape, group_sizes_shape,
+      ragged_dot_dimension_numbers, err_msg):
+    lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
+    rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
+    group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
+    with self.assertRaisesRegex(TypeError, err_msg):
+      lax.ragged_dot_general(lhs, rhs, group_sizes,
+                             ragged_dot_dimension_numbers)
+
+  @parameterized.parameters(
+      {
+          "lhs_shape": lhs_shape,
+          "rhs_shape": rhs_shape,
+          "group_sizes_shape": group_sizes_shape,
+          "ragged_dnums": ragged_dnums,
+          "out_shape": out_shape,
+      }
+      for lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape in [
+          (
+              [11, 5],
+              [3, 5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [1]), ([], [])),
+                  lhs_ragged_dimensions=[0],
+                  rhs_group_dimensions=[0],
+              ),
+              (11, 7),
+          ),
+          (
+              [11, 5],
+              [5, 7],
+              [3],
+              lax.RaggedDotDimensionNumbers(
+                  dot_dimension_numbers=(([1], [0]), ([], [])),
+                  lhs_ragged_dimensions=[1],
+                  rhs_group_dimensions=[],
+              ),
+              (3, 11, 7),
+          ),
+      ]
+  )
+  def test_ragged_dot_general_shape_inference_success(
+      self, lhs_shape, rhs_shape, group_sizes_shape, ragged_dnums, out_shape):
+    lhs = jnp.ones(lhs_shape, dtype=jnp.float32)
+    rhs = jnp.ones(rhs_shape, dtype=jnp.float32)
+    group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32)
+    self.assertEqual(
+        lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape,
+        out_shape,
+    )
+

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
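For readers of these tests, a hedged numpy reference of the lax.ragged_dot semantics exercised above (mirroring the role of lax_reference.ragged_dot; names here are illustrative):

import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
  # Row block g of lhs (of length group_sizes[g]) is multiplied by rhs[g].
  out = np.zeros((lhs.shape[0], rhs.shape[-1]), lhs.dtype)
  start = 0
  for g, size in enumerate(group_sizes):
    out[start:start + size] = lhs[start:start + size] @ rhs[g]
    start += size
  return out

lhs = np.random.rand(11, 5).astype(np.float32)
rhs = np.random.rand(3, 5, 7).astype(np.float32)
print(ragged_dot_reference(lhs, rhs, [4, 4, 3]).shape)  # (11, 7)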