mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mgpu] Fixed off-by-one issue in pointwise argument shuffling when leading argument is splat.
Also adapted the test to catch a possible regression. The issue appeared in >2 operands. PiperOrigin-RevId: 701306731
This commit is contained in:
parent
f10d3eb312
commit
ea69401eff
@ -632,7 +632,7 @@ class FragmentedArray:
|
||||
continue
|
||||
elif not isinstance(o.layout, WGSplatFragLayout):
|
||||
return o._pointwise(
|
||||
lambda o, *args: op(*args[:i], o, *args[i:]),
|
||||
lambda o, this, *args: op(this, *args[:i], o, *args[i:]),
|
||||
self,
|
||||
*other[:i],
|
||||
*other[i + 1 :],
|
||||
|
@ -1565,14 +1565,14 @@ class FragmentedArrayTest(TestCase):
|
||||
assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout)
|
||||
pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq
|
||||
assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout)
|
||||
(pi_arr_sq + pi_arr_cube).store_untiled(dst)
|
||||
(pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst)
|
||||
|
||||
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
|
||||
inp = jnp.ones_like(out_shape) * 3.14
|
||||
result = mgpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, ()
|
||||
)(inp)
|
||||
np.testing.assert_allclose(result, np.full((128, 32), 3.14 ** 2 + 3.14 ** 3, np.float32))
|
||||
np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
|
||||
|
||||
|
||||
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
|
||||
|
Loading…
x
Reference in New Issue
Block a user