mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
[Easy] Refactor ragged_dot transpose, combine ragged_to_dense
PiperOrigin-RevId: 663630185
This commit is contained in:
parent 417fcd574b
commit 9785368c7f
@@ -3039,6 +3039,22 @@ def _ragged_dot_jvp_rule(
   return primal_out, tangent_out
 
 
+def _ragged_to_dense(x, y, group_sizes):
+  shape = (y.shape[0], x.shape[0], x.shape[1])
+  x = broadcast_in_dim(x, shape, [1, 2])
+  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
+  group_ends = jax.lax.cumsum(group_sizes)
+  group_starts = concatenate(
+      [_zeros(group_sizes)[:1], group_ends[:-1]],
+      dimension=0,
+  )
+  group_ends = broadcast_in_dim(group_ends, shape, (0,))
+  group_starts = broadcast_in_dim(group_starts, shape, (0,))
+  mask = bitwise_and(group_starts <= iota, iota < group_ends)
+  x = select(mask, x, _zeros(x))
+  return x
+
+
 def _ragged_dot_transpose_rule(
     ct, *operands, precision, preferred_element_type, group_offset
 ):
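The new shared `_ragged_to_dense` expands the (m, k) ragged operand into a (g, m, k) array in which group i keeps only its own rows and every other row is zeroed. The following standalone sketch reproduces that masking with public `jax.lax`/`jax.numpy` calls instead of the lax-internal helpers (`broadcast_in_dim`, `_zeros`, `select`) used in-tree; the function name `ragged_to_dense_sketch` and the toy shapes are illustrative, not part of the commit.

import jax
import jax.numpy as jnp

def ragged_to_dense_sketch(x, y, group_sizes):
  # x: (m, k) ragged left operand; y: (g, k, n) grouped right operand;
  # group_sizes: (g,) int vector with sum(group_sizes) == m.
  shape = (y.shape[0], x.shape[0], x.shape[1])      # (g, m, k)
  x_b = jax.lax.broadcast_in_dim(x, shape, (1, 2))  # replicate x once per group
  iota = jax.lax.broadcasted_iota(group_sizes.dtype, shape, 1)  # row index
  group_ends = jax.lax.cumsum(group_sizes)
  group_starts = jnp.concatenate(
      [jnp.zeros((1,), group_sizes.dtype), group_ends[:-1]])
  group_ends = jax.lax.broadcast_in_dim(group_ends, shape, (0,))
  group_starts = jax.lax.broadcast_in_dim(group_starts, shape, (0,))
  mask = (group_starts <= iota) & (iota < group_ends)  # does row belong to group?
  return jnp.where(mask, x_b, 0)

x = jnp.arange(8.0).reshape(4, 2)                  # m=4, k=2
y = jnp.ones((2, 2, 3))                            # g=2, k=2, n=3
group_sizes = jnp.array([1, 3], dtype=jnp.int32)
dense = ragged_to_dense_sketch(x, y, group_sizes)  # shape (2, 4, 2)
# dense[0] keeps only row 0 of x; dense[1] keeps rows 1..3; everything else is zero.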
@@ -3046,24 +3062,6 @@ def _ragged_dot_transpose_rule(
   if group_offset is not None:
     raise NotImplementedError('Unimplemented group_offset support.')
 
-  def ragged_to_dense(x, group_sizes):
-    group_count = group_sizes.shape[0]
-    shape = (group_count, x.shape[0], x.shape[1])
-    x_broadcasted = jax.lax.broadcast_in_dim(x, shape, (1, 2))
-    iota = jax.lax.broadcasted_iota(group_sizes.dtype, shape, 1)
-    group_ends = jax.lax.cumsum(group_sizes)
-    group_starts = concatenate(
-        [
-            np.zeros_like([group_ends[0]], dtype=group_sizes.dtype),
-            group_ends[:-1],
-        ],
-        0,
-    )
-    group_ends = jax.lax.broadcast_in_dim(group_ends, shape, (0,))
-    group_starts = jax.lax.broadcast_in_dim(group_starts, shape, (0,))
-    mask = (group_starts <= iota) & (iota < group_ends)
-    return jax.numpy.where(mask, x_broadcasted, 0)
-
   if ad.is_undefined_primal(y):
     grad_x = None
   else:
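The inner `ragged_to_dense` removed here duplicated the group bookkeeping that the shared `_ragged_to_dense` above now provides (the only cosmetic difference being `np.zeros_like([group_ends[0]], ...)` versus `_zeros(group_sizes)[:1]` for the leading zero). A small illustration of that bookkeeping, with arbitrary example sizes:

import jax.numpy as jnp

group_sizes = jnp.array([2, 0, 3], dtype=jnp.int32)
group_ends = jnp.cumsum(group_sizes)                        # [2, 2, 5]
group_starts = jnp.concatenate(
    [jnp.zeros((1,), group_sizes.dtype), group_ends[:-1]])  # [0, 2, 2]
# Group i owns rows r with group_starts[i] <= r < group_ends[i];
# an empty group (group 1 here) owns no rows, so it masks to all zeros.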
@@ -3079,8 +3077,9 @@ def _ragged_dot_transpose_rule(
   if ad.is_undefined_primal(x):
     grad_y = None
   else:
-    x_dense = ragged_to_dense(x, gs)
-    ct_dense = ragged_to_dense(ct, gs)
+    y = y.aval if ad.is_undefined_primal(y) else y
+    x_dense = _ragged_to_dense(x, y, group_sizes=gs)
+    ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
     dimension_numbers = (([1], [1]), ([0], [0]))
     grad_y = jax.lax.dot_general(
         x_dense,
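With both operands densified to a per-group layout, `grad_y` becomes a single batched contraction: assuming `x_dense` has shape (g, m, k) and `ct_dense` has shape (g, m, n), `dimension_numbers = (([1], [1]), ([0], [0]))` contracts the shared m dimension and batches over g, yielding the (g, k, n) shape of the grouped rhs. A shape-only check of that claim, with arbitrary toy sizes:

import jax
import jax.numpy as jnp

g, m, k, n = 2, 4, 3, 5
x_dense = jnp.ones((g, m, k))   # stands in for _ragged_to_dense(x, y, ...)
ct_dense = jnp.ones((g, m, n))  # stands in for _ragged_to_dense(ct, y, ...)
dimension_numbers = (([1], [1]), ([0], [0]))  # contract over m, batch over g
grad_y = jax.lax.dot_general(x_dense, ct_dense, dimension_numbers)
assert grad_y.shape == (g, k, n)  # matches the grouped rhs of ragged_dot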
@@ -3109,17 +3108,7 @@ def _ragged_dot_impl(
 ) -> Array:
   if group_offset is not None:
     raise NotImplementedError("Unimplemented group_offset support.")
-  shape = (rhs.shape[0], lhs.shape[0], lhs.shape[1])
-  lhs = broadcast_in_dim(lhs, shape, [1, 2])
-  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
-  group_ends = jax.lax.cumsum(group_sizes)
-  group_starts = concatenate(
-      [_zeros(group_sizes)[:1], group_ends[:-1]], dimension=0,
-  )
-  group_ends = broadcast_in_dim(group_ends, shape, (0,))
-  group_starts = broadcast_in_dim(group_starts, shape, (0,))
-  mask = bitwise_and(group_starts <= iota, iota < group_ends)
-  lhs = select(mask, lhs, _zeros(lhs))
+  lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes)
   return dot_general(
     lhs,
     rhs,
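After this change the reference implementation is simply "densify the lhs, then run an ordinary `dot_general`". The end-to-end semantics that implies, sketched with the public `jax.lax.ragged_dot` entry point (assuming its (lhs, rhs, group_sizes) signature) and arbitrary toy shapes:

import jax
import jax.numpy as jnp

m, k, n, g = 4, 3, 2, 2
lhs = jnp.arange(m * k, dtype=jnp.float32).reshape(m, k)
rhs = jnp.arange(g * k * n, dtype=jnp.float32).reshape(g, k, n)
group_sizes = jnp.array([1, 3], dtype=jnp.int32)  # row 0 -> group 0, rows 1..3 -> group 1

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)   # shape (m, n)
# Each block of lhs rows is multiplied by the rhs slice of its group.
expected = jnp.concatenate([lhs[:1] @ rhs[0], lhs[1:] @ rhs[1]])
assert jnp.allclose(out, expected)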