[Mosaic GPU] Remove TMA inputs

This test configuration dates back to the time when we were very unsure
about how to use TMA. At this point we have plenty of experience, and it
makes more sense to focus this test on verifying WGMMA. This also
simplifies adding support for smaller RHS tiling.

PiperOrigin-RevId: 732040900
Adam Paszke authored 2025-02-28 01:18:47 -08:00; committed by jax authors
parent 092ea35301
commit 832f5a3aff


@@ -650,7 +650,6 @@ class WGMMATest(TestCase):
       m=(64, 128, 192),
       n=(64, 128, 192),
       k_steps=(1, 2),
-      tma_inputs=(False, True),
       swizzle=(32, 64, 128),
       jax_out_dtype=(jnp.float16, jnp.float32),
   )
@@ -662,7 +661,6 @@ class WGMMATest(TestCase):
       in_mlir_dtype_cls,
       lhs_transpose,
       rhs_transpose,
-      tma_inputs,
       swizzle,
       jax_out_dtype,
   ):
@@ -670,8 +668,6 @@ class WGMMATest(TestCase):
       raise self.skipTest("Only f16 input is supported for f16 output.")
     if swizzle != 128 and lhs_transpose:
       raise self.skipTest("Transpose only supported in 128B swizzled WGMMA")
-    if swizzle != 128 and not tma_inputs:
-      raise self.skipTest("Copy with non-128B swizzles not implemented")
     in_mlir_dtype = in_mlir_dtype_cls.get()
     out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype)
@@ -696,59 +692,32 @@ class WGMMATest(TestCase):
     nk_tile = swizzle // bytewidth(in_mlir_dtype)
     k = nk_tile * k_steps
     assert m % 64 == 0 and n % nk_tile == 0
-    index = ir.IndexType.get()

     def kernel(ctx, lhs, rhs, out, scratch):
       lhs_smem, rhs_smem, barriers = scratch
-      if tma_inputs:
-        lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
-        if lhs_transpose:
-          assert nk_tile == 64  # Make sure we didn't have to transpose tiling.
-          lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-        rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
-        if rhs_transpose:
-          rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-        ctx.async_copy(
-            src_ref=lhs,
-            dst_ref=lhs_smem,
-            swizzle=swizzle,
-            gmem_transform=lhs_transform,
-            barrier=barriers[0],
-        )
-        ctx.async_copy(
-            src_ref=rhs,
-            dst_ref=rhs_smem,
-            swizzle=swizzle,
-            gmem_transform=rhs_transform,
-            barrier=barriers[1],
-        )
-        for i in range(2):
-          barriers[i].wait()
-      else:
-        for mi in range(m // 64):
-          for ki in range(k // nk_tile):
-            lhs_slice = (
-                ds(c(mi * 64, index), 64),
-                ds(c(ki * nk_tile, index), nk_tile),
-            )
-            if lhs_transpose:
-              lhs_slice = lhs_slice[::-1]
-            copy(
-                src=memref_slice(lhs, lhs_slice),
-                dst=memref_slice(lhs_smem, (mi, ki)),
-                swizzle=swizzle,
-            )
-        for ki in range(k // nk_tile):
-          k_slice = ds(c(ki * nk_tile, index), nk_tile)
-          for ni in range(n // nk_tile):
-            rhs_slice = (k_slice, ds(c(ni * nk_tile, index), nk_tile))
-            if rhs_transpose:
-              rhs_slice = rhs_slice[::-1]
-            copy(
-                src=memref_slice(rhs, rhs_slice),
-                dst=memref_slice(rhs_smem, (ki, ni)),
-                swizzle=swizzle,
-            )
+      lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
+      if lhs_transpose:
+        assert nk_tile == 64  # Make sure we didn't have to transpose tiling.
+        lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
+      rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
+      if rhs_transpose:
+        rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
+      ctx.async_copy(
+          src_ref=lhs,
+          dst_ref=lhs_smem,
+          swizzle=swizzle,
+          gmem_transform=lhs_transform,
+          barrier=barriers[0],
+      )
+      ctx.async_copy(
+          src_ref=rhs,
+          dst_ref=rhs_smem,
+          swizzle=swizzle,
+          gmem_transform=rhs_transform,
+          barrier=barriers[1],
+      )
+      for i in range(2):
+        barriers[i].wait()
       init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype)
       if lhs_transpose:
         lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
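
For reference, the staging pattern the test now always exercises is condensed
below into a standalone helper. This is only an illustrative sketch assembled
from the calls visible in the diff above: the name stage_wgmma_inputs is
hypothetical, and ctx, the SMEM refs, and the barriers are assumed to be
provided by the Mosaic GPU test harness.

    import jax.experimental.mosaic.gpu as mgpu  # assumed import path

    def stage_wgmma_inputs(
        ctx, lhs, rhs, lhs_smem, rhs_smem, barriers,
        swizzle, nk_tile, lhs_transpose, rhs_transpose,
    ):
      # Hypothetical helper (not part of the test): stage both WGMMA operands
      # into SMEM with TMA instead of manual per-tile copies.
      # LHS is tiled into (64, nk_tile) blocks; a transposed input additionally
      # swaps the two leading (tile-grid) dimensions.
      lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
      if lhs_transpose:
        lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
      # RHS uses square (nk_tile, nk_tile) tiles; in the test above, nk_tile is
      # derived from the swizzle as swizzle // bytewidth(dtype).
      rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
      if rhs_transpose:
        rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
      # One TMA copy per operand; each copy signals its own barrier when done.
      ctx.async_copy(
          src_ref=lhs, dst_ref=lhs_smem, swizzle=swizzle,
          gmem_transform=lhs_transform, barrier=barriers[0],
      )
      ctx.async_copy(
          src_ref=rhs, dst_ref=rhs_smem, swizzle=swizzle,
          gmem_transform=rhs_transform, barrier=barriers[1],
      )
      for i in range(2):
        barriers[i].wait()

With both operands staged this way, the manual copy helpers (copy,
memref_slice, ds, c) and the index constant are no longer needed in the
kernel, which is why the last hunk deletes them.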