Add very simple batching support for ragged_dot.

PiperOrigin-RevId: 682079947
This commit is contained in:
jax authors 2024-10-03 16:41:31 -07:00
parent 3446337fdc
commit f203a9fc9e

View File

@ -3252,13 +3252,25 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
return y_bar
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
lhs, rhs = batched_args
lbd, rbd = batch_dims
def _dot_batch_rule(
unpack_args,
unpack_dims,
invoke_prim,
batched_args,
batch_dims,
*,
dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None,
**_,
):
lhs, rhs = unpack_args(batched_args)
lbd, rbd = unpack_dims(batch_dims)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
@ -3280,16 +3292,21 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
else:
rhs_shape = np.shape(rhs)
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm)
batched_out = invoke_prim(
lhs,
rhs,
new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm,
)
result_batch_dim = batching.shape_as_bdim(
result_stack_dim,
_dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
return batched_out, result_batch_dim
def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
# There are three kinds of dimensions in a dot_general:
# - contraction dimensions appear in lhs and rhs but not the result
@ -3364,8 +3381,35 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general')
def _dot_general_batch_unpack_args(batch_args):
lhs, rhs = batch_args
return (lhs, rhs)
def _dot_general_batch_unpack_dims(batch_dims):
lbd, rbd = batch_dims
return (lbd, rbd)
# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
([2, 0], [1, 0]),
([], []),
)
_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
([3, 1], [2, 1]),
([0], [0]),
)
ad.defbilinear(dot_general_p,
_dot_general_transpose_lhs, _dot_general_transpose_rhs)
_dot_general_batch_rule = functools.partial(
_dot_batch_rule,
_dot_general_batch_unpack_args,
_dot_general_batch_unpack_dims,
dot_general,
)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
@ -3469,6 +3513,34 @@ for platform in ["cpu", "tpu"]:
def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
if len(lhs.shape) == 3:
# Batched case
b, m, k = lhs.shape
b2, group_count, rk, n = rhs.shape
b3 = group_sizes.shape[0]
if b != b2:
raise TypeError(
f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
f' {b2}.'
)
if b3 != b:
raise TypeError(
'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
f' {b3} and {b}.'
)
if k != rk:
raise TypeError(
f'ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and'
f' {rk}.'
)
num_groups = group_sizes.shape[1]
if group_count != num_groups:
raise TypeError(
'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
f' {group_count} and {num_groups}.'
)
return (b, m, n)
m, k = lhs.shape
group_count, rk, n = rhs.shape
if k != rk:
@ -3478,9 +3550,6 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
return (m, n)
# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))
def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
if not dtypes.issubdtype(group_sizes.dtype, np.integer):
@ -3592,11 +3661,68 @@ def _ragged_dot_transpose_rule(
return grad_x, grad_y, None
def _ragged_dot_batch_unpack_args(batched_args):
lhs, rhs, _ = batched_args
return (lhs, rhs)
def _ragged_dot_batch_unpack_dims(batch_dims):
if not all(dim == 0 for dim in batch_dims):
raise NotImplementedError('ragged_dot vmap over any dim but 0 - NYI')
lbd, rbd, _ = batch_dims
return (lbd, rbd)
def _ragged_dot_invoke_prim(
group_sizes,
lhs,
rhs,
new_dimension_numbers,
precision,
preferred_element_type,
algorithm,
transpose_algorithm,
):
assert algorithm is None
assert transpose_algorithm is None
return ragged_dot(
lhs,
rhs,
group_sizes,
precision=precision,
preferred_element_type=preferred_element_type,
)
def _ragged_dot_batch_rule(
batched_args,
batch_dims,
*,
precision,
preferred_element_type: DTypeLike | None,
**_,
):
invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
return _dot_batch_rule(
_ragged_dot_batch_unpack_args,
_ragged_dot_batch_unpack_dims,
invoke,
batched_args,
batch_dims,
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision,
preferred_element_type=preferred_element_type,
)
ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
_ragged_dot_dtype_rule, 'ragged_dot')
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
batching.primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
def _ragged_dot_impl(
lhs: Array,
@ -3608,11 +3734,20 @@ def _ragged_dot_impl(
) -> Array:
if group_offset is not None:
raise NotImplementedError("Unimplemented group_offset support.")
lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes)
if len(lhs.shape) == 3:
ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
else:
ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
ragged_to_dense = _ragged_to_dense
lhs = ragged_to_dense(lhs, rhs, group_sizes)
return dot_general(
lhs,
rhs,
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
dimension_numbers=ragged_dot_dims,
precision=precision,
preferred_element_type=preferred_element_type,
)