From 8ce0b55c86f3f16ea2c3285b82c8e760a6c69d75 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 18 Jun 2024 05:28:12 -0700 Subject: [PATCH] [jax:pallas] Fix Pallas kernel batching rule where an input is aliased with an output and the input is batched on a non-zero axis. PiperOrigin-RevId: 644348136 --- jax/_src/pallas/pallas_call.py | 8 ++++++-- tests/pallas/pallas_test.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 973ca4054..100a1b65d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -409,16 +409,20 @@ def _broadcast_input_output_aliases( When we have input/output aliasing, since the output will be mapped, we need to make sure to broadcast the input across that dimension if it is not - mapped. + mapped. If the input is mapped, but on a different axis, we tranpose the input + to match the output. """ args_ = list(args) dims_ = list(dims) for input_index, _ in input_output_aliases: dim = dims_[input_index] + dims_[input_index] = 0 if dim is batching.not_mapped: - dims_[input_index] = 0 args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) + elif dim != 0: + # TODO(cjfj): Change output batching axis instead? + args_[input_index] = jnp.moveaxis(args[input_index], dim, 0) return tuple(args_), tuple(dims_) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index b672879a0..39d638caf 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -982,6 +982,21 @@ class PallasCallVmapTest(PallasTest): out_ref = jnp.arange(2, 10) np.testing.assert_allclose(out, out_ref) + def test_vmap_of_kernel_with_input_output_aliases_different_axes(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + debug=False, + input_output_aliases={0: 0}, + grid=(), + ) + def add(x_ref, o_ref): + o_ref[()] = x_ref[()] + 1 + + out = jax.vmap(add, in_axes=1)(jnp.arange(8).reshape((4, 2))) + out_ref = jnp.arange(1, 9).reshape((4, 2)).swapaxes(0, 1) + np.testing.assert_allclose(out, out_ref) + def test_vmap_of_slicing_kernel_different_axes(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),