[Easy] Refactor ragged_dot transpose, combine ragged_to_dense

PiperOrigin-RevId: 663630185
Author:    jax authors
Committer: jax authors
Date:      2024-08-16 00:32:00 -07:00
parent 417fcd574b
commit 9785368c7f

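This commit moves the ragged-to-dense masking used by ragged_dot into a single module-level helper, _ragged_to_dense, and reuses it from both the transpose rule and the lowering path. As a minimal, hedged illustration of that masking trick (not the library code itself), the sketch below rebuilds it with public jax.numpy / jax.lax ops; the name ragged_to_dense_sketch and the concrete sizes are illustrative only.

import jax
import jax.numpy as jnp


def ragged_to_dense_sketch(x, group_sizes):
  # Copy x once per group: (m, k) -> (g, m, k).
  g = group_sizes.shape[0]
  shape = (g, x.shape[0], x.shape[1])
  x_b = jnp.broadcast_to(x[None], shape)
  # Row index along the m dimension, broadcast to the full shape.
  row = jax.lax.broadcasted_iota(group_sizes.dtype, shape, 1)
  ends = jnp.cumsum(group_sizes)
  starts = jnp.concatenate([jnp.zeros_like(ends[:1]), ends[:-1]])
  # Group i owns rows [starts[i], ends[i]); zero out everything else.
  mask = (starts[:, None, None] <= row) & (row < ends[:, None, None])
  return jnp.where(mask, x_b, 0)


x = jnp.arange(8.0).reshape(4, 2)               # 4 rows, feature dim 2
sizes = jnp.array([2, 1, 1], dtype=jnp.int32)   # row groups {0,1}, {2}, {3}
dense = ragged_to_dense_sketch(x, sizes)        # shape (3, 4, 2)

# Loop-based check: slice g keeps only its own rows of x.
start = 0
for g, size in enumerate([2, 1, 1]):
  expected = jnp.zeros_like(x).at[start:start + size].set(x[start:start + size])
  assert jnp.array_equal(dense[g], expected)
  start += size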

@@ -3039,6 +3039,22 @@ def _ragged_dot_jvp_rule(
   return primal_out, tangent_out
 
 
+def _ragged_to_dense(x, y, group_sizes):
+  shape = (y.shape[0], x.shape[0], x.shape[1])
+  x = broadcast_in_dim(x, shape, [1, 2])
+  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
+  group_ends = jax.lax.cumsum(group_sizes)
+  group_starts = concatenate(
+      [_zeros(group_sizes)[:1], group_ends[:-1]],
+      dimension=0,
+  )
+  group_ends = broadcast_in_dim(group_ends, shape, (0,))
+  group_starts = broadcast_in_dim(group_starts, shape, (0,))
+  mask = bitwise_and(group_starts <= iota, iota < group_ends)
+  x = select(mask, x, _zeros(x))
+  return x
+
+
 def _ragged_dot_transpose_rule(
     ct, *operands, precision, preferred_element_type, group_offset
 ):
@@ -3046,24 +3062,6 @@ def _ragged_dot_transpose_rule(
   if group_offset is not None:
     raise NotImplementedError('Unimplemented group_offset support.')
 
-  def ragged_to_dense(x, group_sizes):
-    group_count = group_sizes.shape[0]
-    shape = (group_count, x.shape[0], x.shape[1])
-    x_broadcasted = jax.lax.broadcast_in_dim(x, shape, (1, 2))
-    iota = jax.lax.broadcasted_iota(group_sizes.dtype, shape, 1)
-    group_ends = jax.lax.cumsum(group_sizes)
-    group_starts = concatenate(
-        [
-            np.zeros_like([group_ends[0]], dtype=group_sizes.dtype),
-            group_ends[:-1],
-        ],
-        0,
-    )
-    group_ends = jax.lax.broadcast_in_dim(group_ends, shape, (0,))
-    group_starts = jax.lax.broadcast_in_dim(group_starts, shape, (0,))
-    mask = (group_starts <= iota) & (iota < group_ends)
-    return jax.numpy.where(mask, x_broadcasted, 0)
-
   if ad.is_undefined_primal(y):
     grad_x = None
   else:
@@ -3079,8 +3077,9 @@ def _ragged_dot_transpose_rule(
   if ad.is_undefined_primal(x):
     grad_y = None
   else:
-    x_dense = ragged_to_dense(x, gs)
-    ct_dense = ragged_to_dense(ct, gs)
+    y = y.aval if ad.is_undefined_primal(y) else y
+    x_dense = _ragged_to_dense(x, y, group_sizes=gs)
+    ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
     dimension_numbers = (([1], [1]), ([0], [0]))
     grad_y = jax.lax.dot_general(
         x_dense,
@@ -3109,17 +3108,7 @@ def _ragged_dot_impl(
 ) -> Array:
   if group_offset is not None:
     raise NotImplementedError("Unimplemented group_offset support.")
-  shape = (rhs.shape[0], lhs.shape[0], lhs.shape[1])
-  lhs = broadcast_in_dim(lhs, shape, [1, 2])
-  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
-  group_ends = jax.lax.cumsum(group_sizes)
-  group_starts = concatenate(
-      [_zeros(group_sizes)[:1], group_ends[:-1]], dimension=0,
-  )
-  group_ends = broadcast_in_dim(group_ends, shape, (0,))
-  group_starts = broadcast_in_dim(group_starts, shape, (0,))
-  mask = bitwise_and(group_starts <= iota, iota < group_ends)
-  lhs = select(mask, lhs, _zeros(lhs))
+  lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes)
   return dot_general(
       lhs,
       rhs,
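
For context, here is a usage sketch of the public entry point this lowering serves, assuming a JAX release that exposes jax.lax.ragged_dot: lhs has shape (m, k), rhs has shape (g, k, n), and group_sizes has shape (g,) summing to m, so each contiguous block of lhs rows is multiplied by its own rhs[g]. The loop-based reference is illustrative only; it computes the same per-group products that the densify-then-dot_general path above produces.

import jax
import jax.numpy as jnp

m, k, n, g = 6, 4, 3, 3
lhs = jax.random.normal(jax.random.PRNGKey(0), (m, k))
rhs = jax.random.normal(jax.random.PRNGKey(1), (g, k, n))
group_sizes = jnp.array([3, 1, 2], dtype=jnp.int32)  # sums to m

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)      # shape (m, n)

# Reference: rows [start, start + size) of lhs are multiplied by rhs[grp].
ref = jnp.zeros((m, n))
start = 0
for grp, size in enumerate([3, 1, 2]):
  ref = ref.at[start:start + size].set(lhs[start:start + size] @ rhs[grp])
  start += size

assert jnp.allclose(out, ref, atol=1e-5)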