[jax:pallas] Fix Pallas kernel batching rule where an input is aliased with an output and the input is batched on a non-zero axis.

PiperOrigin-RevId: 644348136
This commit is contained in:
Chris Jones 2024-06-18 05:28:12 -07:00 committed by jax authors
parent f4158ace93
commit 8ce0b55c86
2 changed files with 21 additions and 2 deletions

View File

@ -409,16 +409,20 @@ def _broadcast_input_output_aliases(
When we have input/output aliasing, since the output will be mapped, we need
to make sure to broadcast the input across that dimension if it is not
mapped.
mapped. If the input is mapped, but on a different axis, we transpose the input
to match the output.
"""
args_ = list(args)
dims_ = list(dims)
for input_index, _ in input_output_aliases:
dim = dims_[input_index]
dims_[input_index] = 0
if dim is batching.not_mapped:
dims_[input_index] = 0
args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
elif dim != 0:
# TODO(cjfj): Change output batching axis instead?
args_[input_index] = jnp.moveaxis(args[input_index], dim, 0)
return tuple(args_), tuple(dims_)

View File

@ -982,6 +982,21 @@ class PallasCallVmapTest(PallasTest):
out_ref = jnp.arange(2, 10)
np.testing.assert_allclose(out, out_ref)
def test_vmap_of_kernel_with_input_output_aliases_different_axes(self):
  """Checks vmap of an aliased kernel whose input is batched on axis 1.

  The kernel's single input is aliased with its single output
  (`input_output_aliases={0: 0}`) and the input is mapped with `in_axes=1`,
  i.e. on a non-zero axis.
  """
  as_pallas_kernel = functools.partial(
      self.pallas_call,
      out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
      debug=False,
      input_output_aliases={0: 0},
      grid=(),
  )

  @as_pallas_kernel
  def add(x_ref, o_ref):
    o_ref[()] = x_ref[()] + 1

  x = jnp.arange(8).reshape((4, 2))
  actual = jax.vmap(add, in_axes=1)(x)

  # The expected value is the incremented input with its two axes exchanged:
  # the input is mapped over axis 1, while the result is compared with the
  # batch axis leading.
  expected = jnp.arange(1, 9).reshape((4, 2)).swapaxes(0, 1)
  np.testing.assert_allclose(actual, expected)
def test_vmap_of_slicing_kernel_different_axes(self):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),