diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 764e4dcbe..221fe2a9e 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1301,6 +1301,7 @@ def _ragged_all_to_all_transpose( mask = jax.numpy.cumsum( jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ .at[output_offsets_ + recv_sizes].add(-1)) + mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),)) output_t = jax.numpy.where(mask, 0, t) return [operand_t, output_t] + [None] * 4