Fixed mgpu.FragmentedArray.reduce_sum for integer types

The implementation previously assumed a floating-point element type and always used arith.addf, even for integer arrays.

PiperOrigin-RevId: 678718871
This commit is contained in:
Sergei Lebedev 2024-09-25 08:49:36 -07:00 committed by jax authors
parent a43c7f2ace
commit a373e37be2
2 changed files with 43 additions and 13 deletions

View File

@ -660,19 +660,6 @@ class FragmentedArray:
)
def reduce_sum(self, scratch) -> ir.Value:
index = ir.IndexType.get()
if not isinstance(self.layout, WGStridedFragLayout):
raise NotImplementedError(f"Unsupported layout {self.layout}")
result = c(0, self.mlir_dtype)
for reg in self.registers:
result = arith.addf(
result,
vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
)
scratch_ty = ir.MemRefType(scratch.type)
if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
if ir.FloatType.isinstance(self.mlir_dtype):
op = arith.addf
elif ir.IntegerType.isinstance(self.mlir_dtype):
@ -680,6 +667,19 @@ class FragmentedArray:
else:
raise NotImplementedError(self.mlir_dtype)
index = ir.IndexType.get()
if not isinstance(self.layout, WGStridedFragLayout):
raise NotImplementedError(f"Unsupported layout {self.layout}")
result = c(0, self.mlir_dtype)
for reg in self.registers:
result = op(
result,
vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
)
scratch_ty = ir.MemRefType(scratch.type)
if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
warp_result = utils.warp_tree_reduce(result, op, 32)
warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
memref.store(warp_result, scratch, [warp_id])

View File

@ -1311,6 +1311,36 @@ class FragmentedArrayTest(TestCase):
rtol = 4e-6 if approx else 2e-7
np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol)
@parameterized.product(
    dtype=[jnp.float32, jnp.int32],
    m=[128],
    n=[32, 64],
)
def test_reduce_sum(self, dtype, m, n):
  """Check ``FragmentedArray.reduce_sum`` against ``x.sum()``.

  The kernel loads an ``(m, n)`` input with ``load_strided``, reduces it to a
  single value, splats that value to shape ``(m,)`` and stores it. Every
  element of the output is therefore expected to equal the sum over the whole
  input. Runs for both a floating-point and an integer dtype.
  """

  def kernel(ctx, src, dst, scratch):
    # `ctx` is part of the kernel signature but is not used by this test.
    loaded = mgpu.FragmentedArray.load_strided(
        src, is_signed=utils.is_signed(dtype)
    )
    total = loaded.reduce_sum(scratch)
    # Broadcast the scalar sum so it can be written to the (m,) output.
    broadcast = mgpu.FragmentedArray.splat(
        total, (m,), is_signed=loaded.is_signed
    )
    broadcast.store_untiled(dst)

  kernel_fn = mgpu.as_gpu_kernel(
      kernel,
      (1, 1, 1),
      # NOTE(review): presumably the grid and block dimensions -- confirm
      # against the as_gpu_kernel signature.
      (128, 1, 1),
      jax.ShapeDtypeStruct((m, n), dtype),
      jax.ShapeDtypeStruct((m,), dtype),
      # reduce_sum rejects scratch buffers whose shape is not [4] or whose
      # element type differs from the array's dtype.
      smem_scratch_shape=jax.ShapeDtypeStruct((4,), dtype),
  )
  x = np.arange(m * n, dtype=dtype).reshape(m, n)
  actual = kernel_fn(x)
  expected = jnp.full((m,), x.sum())
  np.testing.assert_array_equal(actual, expected)
@parameterized.product(
op=(arith.addf, arith.maximumf),
m=(64, 128),