mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[jax:pallas] Fix Pallas kernel batching rule where an input is aliased with an output and the input is batched on a non-zero axis.
PiperOrigin-RevId: 644348136
This commit is contained in:
parent
f4158ace93
commit
8ce0b55c86
@ -409,16 +409,20 @@ def _broadcast_input_output_aliases(
|
||||
|
||||
When we have input/output aliasing, since the output will be mapped, we need
|
||||
to make sure to broadcast the input across that dimension if it is not
|
||||
mapped.
|
||||
mapped. If the input is mapped, but on a different axis, we tranpose the input
|
||||
to match the output.
|
||||
"""
|
||||
|
||||
args_ = list(args)
|
||||
dims_ = list(dims)
|
||||
for input_index, _ in input_output_aliases:
|
||||
dim = dims_[input_index]
|
||||
dims_[input_index] = 0
|
||||
if dim is batching.not_mapped:
|
||||
dims_[input_index] = 0
|
||||
args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
|
||||
elif dim != 0:
|
||||
# TODO(cjfj): Change output batching axis instead?
|
||||
args_[input_index] = jnp.moveaxis(args[input_index], dim, 0)
|
||||
|
||||
return tuple(args_), tuple(dims_)
|
||||
|
||||
|
@ -982,6 +982,21 @@ class PallasCallVmapTest(PallasTest):
|
||||
out_ref = jnp.arange(2, 10)
|
||||
np.testing.assert_allclose(out, out_ref)
|
||||
|
||||
def test_vmap_of_kernel_with_input_output_aliases_different_axes(self):
|
||||
@functools.partial(
|
||||
self.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
|
||||
debug=False,
|
||||
input_output_aliases={0: 0},
|
||||
grid=(),
|
||||
)
|
||||
def add(x_ref, o_ref):
|
||||
o_ref[()] = x_ref[()] + 1
|
||||
|
||||
out = jax.vmap(add, in_axes=1)(jnp.arange(8).reshape((4, 2)))
|
||||
out_ref = jnp.arange(1, 9).reshape((4, 2)).swapaxes(0, 1)
|
||||
np.testing.assert_allclose(out, out_ref)
|
||||
|
||||
def test_vmap_of_slicing_kernel_different_axes(self):
|
||||
@functools.partial(
|
||||
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user