[mosaic_gpu] Allow calling reduce_sum on a fragmented array in splat layout

PiperOrigin-RevId: 706668018
Sergei Lebedev 2024-12-16 05:03:35 -08:00 committed by jax authors
parent c7d1c3d2d8
commit ee7226d564
2 changed files with 48 additions and 7 deletions
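
The fast path added here rests on a simple identity: summing a splat (constant) array needs no cross-lane communication, because the result is just the splat value times the element count, i.e. a single mulf/muli per thread. A minimal NumPy sketch of that identity (illustration only, not Mosaic GPU code):

    import math
    import numpy as np

    # sum(full(shape, v)) == v * prod(shape): the whole reduction collapses
    # into one multiply, which is what the new splat branch below emits.
    shape, value = (128, 64), 3.0
    splat = np.full(shape, value, dtype=np.float32)
    assert splat.sum() == value * math.prod(shape)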

jax/experimental/mosaic/gpu/fragmented_array.py

@@ -1130,7 +1130,27 @@ class FragmentedArray:
     )
 
   # NOTE: scratch can be reused immediately once this function returns.
-  def reduce_sum(self, scratch):
+  def reduce_sum(self, scratch: ir.Value | None = None):
+    if isinstance(self.layout, WGSplatFragLayout):
+      [reg] = self.registers.flat
+      if ir.FloatType.isinstance(self.mlir_dtype):
+        op = arith.mulf
+      elif ir.IntegerType.isinstance(self.mlir_dtype):
+        op = arith.muli
+      else:
+        raise NotImplementedError(self.mlir_dtype)
+      return FragmentedArray.splat(
+          op(reg, utils.c(math.prod(self.shape), self.mlir_dtype)),
+          (),
+          is_signed=self.is_signed,
+      )
+
+    if not isinstance(self.layout, WGStridedFragLayout):
+      raise NotImplementedError(f"Unsupported layout {self.layout}")
+    if scratch is None:
+      raise ValueError("scratch must be provided")
+
     if ir.FloatType.isinstance(self.mlir_dtype):
       op = arith.addf
     elif ir.IntegerType.isinstance(self.mlir_dtype):
@@ -1138,9 +1158,6 @@ class FragmentedArray:
     else:
       raise NotImplementedError(self.mlir_dtype)
-    index = ir.IndexType.get()
-    if not isinstance(self.layout, WGStridedFragLayout):
-      raise NotImplementedError(f"Unsupported layout {self.layout}")
     result = c(0, self.mlir_dtype)
     for reg in self.registers:
       result = op(
@@ -1151,6 +1168,7 @@ class FragmentedArray:
     if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
       raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
+    index = ir.IndexType.get()
     warp_result = utils.warp_tree_reduce(result, op, 32)
     warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
     memref.store(warp_result, scratch, [warp_id])
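
For comparison, the strided path above first folds each thread's registers locally, then combines the 32 lanes of every warp with utils.warp_tree_reduce and stores one partial per warp into the 4-element SMEM scratch. A rough plain-Python model of such a tree reduction (hypothetical helper; this assumes the usual XOR-shuffle butterfly pattern, whereas the real code uses warp shuffle instructions):

    # Model of a 32-lane butterfly reduction: at each step lane i combines its
    # value with lane i ^ offset; after log2(32) steps every lane holds the
    # full result, so the per-warp partial can be stored from any lane.
    def butterfly_reduce(lanes, op):
        offset = len(lanes) // 2
        while offset:
            lanes = [op(lanes[i], lanes[i ^ offset]) for i in range(len(lanes))]
            offset //= 2
        return lanes

    assert butterfly_reduce(list(range(32)), lambda a, b: a + b)[0] == sum(range(32))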

tests/mosaic/gpu_test.py

@@ -1476,7 +1476,7 @@ class FragmentedArrayTest(TestCase):
       m=[128],
       n=[32, 64],
   )
-  def test_reduce_sum(self, dtype, m, n):
+  def test_strided_reduce_sum(self, dtype, m, n):
     def kernel(ctx, src, dst, scratch):
       src = mgpu.FragmentedArray.load_strided(
           src, is_signed=utils.is_signed(dtype)
@@ -1497,6 +1497,31 @@ class FragmentedArrayTest(TestCase):
     x = np.arange(m * n, dtype=dtype).reshape(m, n)
     np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum()))
 
+  @parameterized.product(
+      dtype=[jnp.float32, jnp.int32],
+      m=[128],
+      n=[32, 64],
+  )
+  def test_splat_reduce_sum(self, dtype, m, n):
+    def kernel(ctx, dst, _):
+      src = mgpu.FragmentedArray.splat(
+          utils.c(1, utils.dtype_to_ir_type(dtype)),
+          (m, n),
+          is_signed=utils.is_signed(dtype),
+      )
+      acc = src.reduce_sum().broadcast((m,))
+      acc.store_untiled(dst)
+
+    kernel_fn = mgpu.as_gpu_kernel(
+        kernel,
+        (1, 1, 1),
+        (128, 1, 1),
+        in_shape=(),
+        out_shape=jax.ShapeDtypeStruct((m,), dtype),
+        smem_scratch_shape=(),
+    )
+    np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), m * n * 1.0))
+
   @parameterized.product(
       op=(arith.addf, arith.maximumf),
       m=(64, 128),
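
Note what the new test exercises: in splat layout, reduce_sum() runs without any SMEM scratch (hence smem_scratch_shape=()), and it returns a rank-0 splat that the kernel broadcasts back to (m,) before storing. In NumPy terms (illustration only, mirroring the test's values):

    import numpy as np

    m, n = 128, 64
    acc = np.int32(1) * (m * n)        # rank-0 result of reduce_sum on a splat of 1
    out = np.broadcast_to(acc, (m,))   # what acc.broadcast((m,)) then produces
    np.testing.assert_array_equal(out, np.full((m,), m * n))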
@@ -1548,7 +1573,6 @@ class FragmentedArrayTest(TestCase):
     )()
     np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32))
 
-
   def test_splat_binary_ops(self):
     def kernel(ctx, src, dst, _):
       f32 = ir.F32Type.get()
@@ -1570,7 +1594,6 @@ class FragmentedArrayTest(TestCase):
     )(inp)
     np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
 
-
   @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
   def test_strided_load_store(self, in_shape):
     def kernel(ctx, *args):