[mgpu] Fixed off-by-one issue in pointwise argument shuffling when leading argument is splat.

Also adapted the test to catch a possible regression. The issue appeared in >2 operands.

PiperOrigin-RevId: 701306731
This commit is contained in:
Christos Perivolaropoulos 2024-11-29 09:34:04 -08:00 committed by jax authors
parent f10d3eb312
commit ea69401eff
2 changed files with 3 additions and 3 deletions

View File

@ -632,7 +632,7 @@ class FragmentedArray:
continue
elif not isinstance(o.layout, WGSplatFragLayout):
return o._pointwise(
lambda o, *args: op(*args[:i], o, *args[i:]),
lambda o, this, *args: op(this, *args[:i], o, *args[i:]),
self,
*other[:i],
*other[i + 1 :],

View File

@ -1565,14 +1565,14 @@ class FragmentedArrayTest(TestCase):
assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout)
pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq
assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout)
(pi_arr_sq + pi_arr_cube).store_untiled(dst)
(pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
inp = jnp.ones_like(out_shape) * 3.14
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, ()
)(inp)
np.testing.assert_allclose(result, np.full((128, 32), 3.14 ** 2 + 3.14 ** 3, np.float32))
np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))