mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fixed mgpu.FragmentedArray.reduce_sum for integer types

The implementation previously assumed a floating-point element type and
unconditionally used arith.addf, which does not verify for integer operands.

PiperOrigin-RevId: 678718871
parent a43c7f2ace
commit a373e37be2
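
The core of the fix is visible in the hunks below: MLIR's arith.addf verifies
only for floating-point operands, so summing an integer FragmentedArray must
go through arith.addi. A minimal sketch of the dispatch pattern the fixed code
uses, standalone on the jaxlib MLIR bindings; add_op_for is a hypothetical
helper name introduced here for illustration, not part of the commit:

# Minimal sketch (not the commit's code) of the dtype dispatch the fix
# relies on; add_op_for is a hypothetical name used only for illustration.
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith

def add_op_for(dtype: ir.Type):
  # arith.addf is only valid for floating-point operands and arith.addi
  # only for integers, so the builder must be picked from the element type.
  if ir.FloatType.isinstance(dtype):
    return arith.addf
  elif ir.IntegerType.isinstance(dtype):
    return arith.addi
  else:
    raise NotImplementedError(dtype)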
@@ -660,19 +660,6 @@ class FragmentedArray:
     )

   def reduce_sum(self, scratch) -> ir.Value:
-    index = ir.IndexType.get()
-    if not isinstance(self.layout, WGStridedFragLayout):
-      raise NotImplementedError(f"Unsupported layout {self.layout}")
-    result = c(0, self.mlir_dtype)
-    for reg in self.registers:
-      result = arith.addf(
-          result,
-          vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
-      )
-    scratch_ty = ir.MemRefType(scratch.type)
-    if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
-      raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
-
     if ir.FloatType.isinstance(self.mlir_dtype):
       op = arith.addf
     elif ir.IntegerType.isinstance(self.mlir_dtype):
@@ -680,6 +667,19 @@ class FragmentedArray:
     else:
       raise NotImplementedError(self.mlir_dtype)

+    index = ir.IndexType.get()
+    if not isinstance(self.layout, WGStridedFragLayout):
+      raise NotImplementedError(f"Unsupported layout {self.layout}")
+    result = c(0, self.mlir_dtype)
+    for reg in self.registers:
+      result = op(
+          result,
+          vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
+      )
+    scratch_ty = ir.MemRefType(scratch.type)
+    if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
+      raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
+
     warp_result = utils.warp_tree_reduce(result, op, 32)
     warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
     memref.store(warp_result, scratch, [warp_id])
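
Reading the fixed method top to bottom: each thread first folds its own
registers into one scalar via vector.reduction, each 32-thread warp then
tree-reduces those partials with utils.warp_tree_reduce, and each warp's
result lands in one slot of the 4-element SMEM scratch buffer (hence the
shape check). A pure-NumPy model of that staging, an illustration only,
assuming a 128-thread warpgroup of 4 warps:

# Pure-NumPy model (illustration, not kernel code) of the reduction staging
# above, assuming 128 threads = 4 warps of 32 lanes.
import numpy as np

def reduce_sum_model(per_thread_values):
  # Stage 1: vector.reduction -- each thread sums its own registers.
  thread_partials = per_thread_values.sum(axis=1)
  # Stage 2: utils.warp_tree_reduce -- each warp combines its 32 partials.
  # Stage 3: memref.store -- one value per warp goes into the scratch slot
  # indexed by warp_id = thread_id // 32, which is why scratch has shape (4,).
  scratch = np.array([thread_partials[w * 32:(w + 1) * 32].sum()
                      for w in range(4)])
  # The final combine of the 4 slots happens after the lines shown in the hunk.
  return scratch.sum()

x = np.arange(128 * 8, dtype=np.int32).reshape(128, 8)
assert reduce_sum_model(x) == x.sum()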
@@ -1311,6 +1311,36 @@ class FragmentedArrayTest(TestCase):
     rtol = 4e-6 if approx else 2e-7
     np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol)

+  @parameterized.product(
+      dtype=[jnp.float32, jnp.int32],
+      m=[128],
+      n=[32, 64],
+  )
+  def test_reduce_sum(self, dtype, m, n):
+    def kernel(ctx, src, dst, scratch):
+      src = mgpu.FragmentedArray.load_strided(
+          src, is_signed=utils.is_signed(dtype)
+      )
+      acc = mgpu.FragmentedArray.splat(
+          src.reduce_sum(scratch),
+          (m,),
+          is_signed=src.is_signed
+      )
+      acc.store_untiled(dst)
+
+    in_shape = jax.ShapeDtypeStruct((m, n), dtype)
+    out_shape = jax.ShapeDtypeStruct((m,), dtype)
+    kernel_fn = mgpu.as_gpu_kernel(
+        kernel,
+        (1, 1, 1),
+        (128, 1, 1),
+        in_shape,
+        out_shape,
+        smem_scratch_shape=jax.ShapeDtypeStruct((4,), dtype),
+    )
+    x = np.arange(m * n, dtype=dtype).reshape(m, n)
+    np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum()))
+
   @parameterized.product(
       op=(arith.addf, arith.maximumf),
       m=(64, 128),
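
As a sanity check on what the new test asserts (an illustration, not part of
the commit): the kernel reduces the whole (m, n) input to one scalar and
splats it across the output, so the reference result is a constant vector.

import numpy as np

m, n = 128, 32  # one of the parameterized cases
x = np.arange(m * n, dtype=np.int32).reshape(m, n)
# Every output element equals the sum of the whole input: 8_386_560 here
# (33_550_336 for n=64), comfortably inside int32 range, so no overflow
# masks the integer code path being exercised.
expected = np.full((m,), x.sum(), dtype=np.int32)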