Add very simple batching support for ragged_dot.

PiperOrigin-RevId: 682079947

parent 3446337fdc
commit f203a9fc9e
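In effect, lax.ragged_dot can now be mapped over a leading batch axis with
jax.vmap (batch dimension 0 only, per the check in
_ragged_dot_batch_unpack_dims below). A minimal usage sketch; the shapes and
sizes are illustrative, not taken from this commit:

    import jax
    import jax.numpy as jnp
    from jax import lax

    b, m, k, n, g = 4, 16, 8, 32, 2  # illustrative sizes
    lhs = jnp.ones((b, m, k))         # batch of ragged lhs operands
    rhs = jnp.ones((b, g, k, n))      # batch of grouped rhs operands
    # Rows per group; each row of group_sizes must sum to m.
    group_sizes = jnp.full((b, g), m // g, dtype=jnp.int32)

    # vmap over axis 0 of all three operands, the only case the new
    # batching rule handles.
    out = jax.vmap(lax.ragged_dot)(lhs, rhs, group_sizes)
    assert out.shape == (b, m, n)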
@@ -3252,13 +3252,25 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
   y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar

-def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
-                            precision,
-                            preferred_element_type: DTypeLike | None,
-                            algorithm: _DotAlgorithmLike = None,
-                            transpose_algorithm: DotTransposeAlgorithm | None = None):
-  lhs, rhs = batched_args
-  lbd, rbd = batch_dims
+
+def _dot_batch_rule(
+    unpack_args,
+    unpack_dims,
+    invoke_prim,
+    batched_args,
+    batch_dims,
+    *,
+    dimension_numbers,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    algorithm: _DotAlgorithmLike = None,
+    transpose_algorithm: DotTransposeAlgorithm | None = None,
+    **_,
+):
+
+  lhs, rhs = unpack_args(batched_args)
+  lbd, rbd = unpack_dims(batch_dims)
+
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
   right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
@@ -3280,16 +3292,21 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
-  batched_out = dot_general(lhs, rhs, new_dimension_numbers,
-                            precision=precision,
-                            preferred_element_type=preferred_element_type,
-                            algorithm=algorithm,
-                            transpose_algorithm=transpose_algorithm)
+  batched_out = invoke_prim(
+      lhs,
+      rhs,
+      new_dimension_numbers,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      algorithm=algorithm,
+      transpose_algorithm=transpose_algorithm,
+  )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
       _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
   return batched_out, result_batch_dim


 def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
   # There are three kinds of dimensions in a dot_general:
   # - contraction dimensions appear in lhs and rhs but not the result
@@ -3364,8 +3381,35 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:

 dot_general_p = standard_primitive(_dot_general_shape_rule,
                                    _dot_general_dtype_rule, 'dot_general')
+
+
+def _dot_general_batch_unpack_args(batch_args):
+  lhs, rhs = batch_args
+  return (lhs, rhs)
+
+
+def _dot_general_batch_unpack_dims(batch_dims):
+  lbd, rbd = batch_dims
+  return (lbd, rbd)
+
+# DotDimensionNumbers used in the dot_general call for ragged_dot().
+_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
+    ([2, 0], [1, 0]),
+    ([], []),
+)
+_RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (
+    ([3, 1], [2, 1]),
+    ([0], [0]),
+)
+
 ad.defbilinear(dot_general_p,
                _dot_general_transpose_lhs, _dot_general_transpose_rhs)
+_dot_general_batch_rule = functools.partial(
+    _dot_batch_rule,
+    _dot_general_batch_unpack_args,
+    _dot_general_batch_unpack_dims,
+    dot_general,
+)
 batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
 pe.padding_rules[dot_general_p] = _dot_general_padding_rule
 core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
@@ -3469,6 +3513,34 @@ for platform in ["cpu", "tpu"]:


 def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
+  if len(lhs.shape) == 3:
+    # Batched case
+    b, m, k = lhs.shape
+    b2, group_count, rk, n = rhs.shape
+    b3 = group_sizes.shape[0]
+    if b != b2:
+      raise TypeError(
+          f'ragged_dot requires that lhs.shape[0] == rhs.shape[0]: got {b} and'
+          f' {b2}.'
+      )
+    if b3 != b:
+      raise TypeError(
+          'ragged_dot requires that group_sizes.shape[0] == lhs.shape[0]: got'
+          f' {b3} and {b}.'
+      )
+    if k != rk:
+      raise TypeError(
+          f'ragged_dot requires that lhs.shape[2] == rhs.shape[2]: got {k} and'
+          f' {rk}.'
+      )
+    num_groups = group_sizes.shape[1]
+    if group_count != num_groups:
+      raise TypeError(
+          'ragged_dot requires that rhs.shape[1] == group_sizes.shape[1]: got'
+          f' {group_count} and {num_groups}.'
+      )
+    return (b, m, n)
+
   m, k = lhs.shape
   group_count, rk, n = rhs.shape
   if k != rk:
@@ -3478,9 +3550,6 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
     raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
   return (m, n)

-# DotDimensionNumbers used in the dot_general call for ragged_dot().
-_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))
-
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
                            precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
@@ -3592,11 +3661,68 @@ def _ragged_dot_transpose_rule(
   return grad_x, grad_y, None


+def _ragged_dot_batch_unpack_args(batched_args):
+  lhs, rhs, _ = batched_args
+  return (lhs, rhs)
+
+
+def _ragged_dot_batch_unpack_dims(batch_dims):
+  if not all(dim == 0 for dim in batch_dims):
+    raise NotImplementedError('ragged_dot vmap over any dim but 0 - NYI')
+  lbd, rbd, _ = batch_dims
+  return (lbd, rbd)
+
+
+def _ragged_dot_invoke_prim(
+    group_sizes,
+    lhs,
+    rhs,
+    new_dimension_numbers,
+    precision,
+    preferred_element_type,
+    algorithm,
+    transpose_algorithm,
+):
+  assert algorithm is None
+  assert transpose_algorithm is None
+
+  return ragged_dot(
+      lhs,
+      rhs,
+      group_sizes,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+  )
+
+
+def _ragged_dot_batch_rule(
+    batched_args,
+    batch_dims,
+    *,
+    precision,
+    preferred_element_type: DTypeLike | None,
+    **_,
+):
+  invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
+
+  return _dot_batch_rule(
+      _ragged_dot_batch_unpack_args,
+      _ragged_dot_batch_unpack_dims,
+      invoke,
+      batched_args,
+      batch_dims,
+      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+  )
+
+
 ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
                                   _ragged_dot_dtype_rule, 'ragged_dot')
 ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
 ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
 ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
+batching.primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule

 def _ragged_dot_impl(
     lhs: Array,
@@ -3608,11 +3734,20 @@ def _ragged_dot_impl(
 ) -> Array:
   if group_offset is not None:
     raise NotImplementedError("Unimplemented group_offset support.")
-  lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes)
+
+  if len(lhs.shape) == 3:
+    ragged_dot_dims = _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS
+    ragged_to_dense = api.vmap(_ragged_to_dense, in_axes=(0, 0, 0))
+  else:
+    ragged_dot_dims = _RAGGED_DOT_DOT_DIMENSION_NUMBERS
+    ragged_to_dense = _ragged_to_dense
+
+  lhs = ragged_to_dense(lhs, rhs, group_sizes)
+
   return dot_general(
       lhs,
       rhs,
-      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      dimension_numbers=ragged_dot_dims,
       precision=precision,
       preferred_element_type=preferred_element_type,
   )
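A note on the two DotDimensionNumbers constants above (my reading of the code,
not text from the commit): _ragged_to_dense expands lhs with a masked group
axis, so the unbatched dense lhs is (g, m, k) against an rhs of (g, k, n), and
([2, 0], [1, 0]) contracts k plus the masked g axis to yield (m, n). In the
batched case the dense lhs is (b, g, m, k) against (b, g, k, n), so
([3, 1], [2, 1]) with batch dims ([0], [0]) yields (b, m, n). A quick shape
check with jax.eval_shape, using illustrative sizes:

    import functools
    import jax
    import jax.numpy as jnp
    from jax import lax

    b, g, m, k, n = 4, 2, 16, 8, 32  # illustrative sizes

    # Dense lhs as produced by the vmapped _ragged_to_dense: (b, g, m, k).
    dense_lhs = jax.ShapeDtypeStruct((b, g, m, k), jnp.float32)
    rhs = jax.ShapeDtypeStruct((b, g, k, n), jnp.float32)

    # _RAGGED_DOT_BATCH_DOT_DIMENSION_NUMBERS from the diff above.
    batched_dims = (([3, 1], [2, 1]), ([0], [0]))
    dot = functools.partial(lax.dot_general, dimension_numbers=batched_dims)
    # eval_shape runs only the abstract shape rule, no FLOPs are spent.
    print(jax.eval_shape(dot, dense_lhs, rhs).shape)  # (4, 16, 32) == (b, m, n)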