fix to ragged_all_to_all transpose

PiperOrigin-RevId: 738110447
This commit is contained in:
Matthew Johnson 2025-03-18 12:50:36 -07:00 committed by jax authors
parent 875099b25d
commit 942ff38e36

View File

@ -1301,6 +1301,7 @@ def _ragged_all_to_all_transpose(
mask = jax.numpy.cumsum(
jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
.at[output_offsets_ + recv_sizes].add(-1))
mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),))
output_t = jax.numpy.where(mask, 0, t)
return [operand_t, output_t] + [None] * 4