mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
fix to ragged_all_to_all transpose
PiperOrigin-RevId: 738110447
This commit is contained in:
parent
875099b25d
commit
942ff38e36
@ -1301,6 +1301,7 @@ def _ragged_all_to_all_transpose(
|
||||
mask = jax.numpy.cumsum(
|
||||
jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
|
||||
.at[output_offsets_ + recv_sizes].add(-1))
|
||||
mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),))
|
||||
output_t = jax.numpy.where(mask, 0, t)
|
||||
return [operand_t, output_t] + [None] * 4
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user