mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mosaic_gpu] Allow calling reduce_sum
on a fragmented array in splat layout
PiperOrigin-RevId: 706668018
This commit is contained in:
parent
c7d1c3d2d8
commit
ee7226d564
@ -1130,7 +1130,27 @@ class FragmentedArray:
|
||||
)
|
||||
|
||||
# NOTE: scratch can be reused immediately once this function returns.
|
||||
def reduce_sum(self, scratch):
|
||||
def reduce_sum(self, scratch: ir.Value | None = None):
|
||||
if isinstance(self.layout, WGSplatFragLayout):
|
||||
[reg] = self.registers.flat
|
||||
if ir.FloatType.isinstance(self.mlir_dtype):
|
||||
op = arith.mulf
|
||||
elif ir.IntegerType.isinstance(self.mlir_dtype):
|
||||
op = arith.muli
|
||||
else:
|
||||
raise NotImplementedError(self.mlir_dtype)
|
||||
return FragmentedArray.splat(
|
||||
op(reg, utils.c(math.prod(self.shape), self.mlir_dtype)),
|
||||
(),
|
||||
is_signed=self.is_signed,
|
||||
)
|
||||
|
||||
if not isinstance(self.layout, WGStridedFragLayout):
|
||||
raise NotImplementedError(f"Unsupported layout {self.layout}")
|
||||
|
||||
if scratch is None:
|
||||
raise ValueError("scratch must be provided")
|
||||
|
||||
if ir.FloatType.isinstance(self.mlir_dtype):
|
||||
op = addf
|
||||
elif ir.IntegerType.isinstance(self.mlir_dtype):
|
||||
@ -1138,9 +1158,6 @@ class FragmentedArray:
|
||||
else:
|
||||
raise NotImplementedError(self.mlir_dtype)
|
||||
|
||||
index = ir.IndexType.get()
|
||||
if not isinstance(self.layout, WGStridedFragLayout):
|
||||
raise NotImplementedError(f"Unsupported layout {self.layout}")
|
||||
result = c(0, self.mlir_dtype)
|
||||
for reg in self.registers:
|
||||
result = op(
|
||||
@ -1151,6 +1168,7 @@ class FragmentedArray:
|
||||
if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
|
||||
raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
|
||||
|
||||
index = ir.IndexType.get()
|
||||
warp_result = utils.warp_tree_reduce(result, op, 32)
|
||||
warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
|
||||
memref.store(warp_result, scratch, [warp_id])
|
||||
|
@ -1476,7 +1476,7 @@ class FragmentedArrayTest(TestCase):
|
||||
m=[128],
|
||||
n=[32, 64],
|
||||
)
|
||||
def test_reduce_sum(self, dtype, m, n):
|
||||
def test_strided_reduce_sum(self, dtype, m, n):
|
||||
def kernel(ctx, src, dst, scratch):
|
||||
src = mgpu.FragmentedArray.load_strided(
|
||||
src, is_signed=utils.is_signed(dtype)
|
||||
@ -1497,6 +1497,31 @@ class FragmentedArrayTest(TestCase):
|
||||
x = np.arange(m * n, dtype=dtype).reshape(m, n)
|
||||
np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum()))
|
||||
|
||||
@parameterized.product(
    dtype=[jnp.float32, jnp.int32],
    m=[128],
    n=[32, 64],
)
def test_splat_reduce_sum(self, dtype, m, n):
  """Checks reduce_sum on a splat-constructed array of ones.

  An (m, n) array is built by splatting the constant 1, reduced with
  reduce_sum() (called without a scratch argument), broadcast to (m,) and
  stored. Every output element is expected to equal m * n.
  """

  def kernel(ctx, dst, _):
    # Build the constant inside the kernel body, then splat it to (m, n).
    one = utils.c(1, utils.dtype_to_ir_type(dtype))
    ones = mgpu.FragmentedArray.splat(
        one, (m, n), is_signed=utils.is_signed(dtype)
    )
    # No scratch is passed to reduce_sum for this input.
    total = ones.reduce_sum()
    total.broadcast((m,)).store_untiled(dst)

  out_type = jax.ShapeDtypeStruct((m,), dtype)
  run = mgpu.as_gpu_kernel(
      kernel,
      (1, 1, 1),
      (128, 1, 1),
      in_shape=(),
      out_shape=out_type,
      smem_scratch_shape=(),
  )
  expected = jnp.full((m,), m * n * 1.0)
  np.testing.assert_array_equal(run(), expected)
|
||||
|
||||
@parameterized.product(
|
||||
op=(arith.addf, arith.maximumf),
|
||||
m=(64, 128),
|
||||
@ -1548,7 +1573,6 @@ class FragmentedArrayTest(TestCase):
|
||||
)()
|
||||
np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32))
|
||||
|
||||
|
||||
def test_splat_binary_ops(self):
|
||||
def kernel(ctx, src, dst, _):
|
||||
f32 = ir.F32Type.get()
|
||||
@ -1570,7 +1594,6 @@ class FragmentedArrayTest(TestCase):
|
||||
)(inp)
|
||||
np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
|
||||
|
||||
|
||||
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
|
||||
def test_strided_load_store(self, in_shape):
|
||||
def kernel(ctx, *args):
|
||||
|
Loading…
x
Reference in New Issue
Block a user