mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic GPU] Remove TMA inputs
This test configuration dates back to the time when we were very unsure about how to use TMA. At this point we have plenty of experience and it makes more sense to focus the test in question on verifying WGMMA. This also simplifies adding support for smaller RHS tiling. PiperOrigin-RevId: 732040900
This commit is contained in:
parent
092ea35301
commit
832f5a3aff
@ -650,7 +650,6 @@ class WGMMATest(TestCase):
|
||||
m=(64, 128, 192),
|
||||
n=(64, 128, 192),
|
||||
k_steps=(1, 2),
|
||||
tma_inputs=(False, True),
|
||||
swizzle=(32, 64, 128),
|
||||
jax_out_dtype=(jnp.float16, jnp.float32),
|
||||
)
|
||||
@ -662,7 +661,6 @@ class WGMMATest(TestCase):
|
||||
in_mlir_dtype_cls,
|
||||
lhs_transpose,
|
||||
rhs_transpose,
|
||||
tma_inputs,
|
||||
swizzle,
|
||||
jax_out_dtype,
|
||||
):
|
||||
@ -670,8 +668,6 @@ class WGMMATest(TestCase):
|
||||
raise self.skipTest("Only f16 input is supported for f16 output.")
|
||||
if swizzle != 128 and lhs_transpose:
|
||||
raise self.skipTest("Transpose only supported in 128B swizzled WGMMA")
|
||||
if swizzle != 128 and not tma_inputs:
|
||||
raise self.skipTest("Copy with non-128B swizzles not implemented")
|
||||
|
||||
in_mlir_dtype = in_mlir_dtype_cls.get()
|
||||
out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype)
|
||||
@ -696,59 +692,32 @@ class WGMMATest(TestCase):
|
||||
nk_tile = swizzle // bytewidth(in_mlir_dtype)
|
||||
k = nk_tile * k_steps
|
||||
assert m % 64 == 0 and n % nk_tile == 0
|
||||
index = ir.IndexType.get()
|
||||
|
||||
def kernel(ctx, lhs, rhs, out, scratch):
|
||||
lhs_smem, rhs_smem, barriers = scratch
|
||||
if tma_inputs:
|
||||
lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
|
||||
if lhs_transpose:
|
||||
assert nk_tile == 64 # Make sure we didn't have to transpose tiling.
|
||||
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
||||
rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
|
||||
if rhs_transpose:
|
||||
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
||||
ctx.async_copy(
|
||||
src_ref=lhs,
|
||||
dst_ref=lhs_smem,
|
||||
swizzle=swizzle,
|
||||
gmem_transform=lhs_transform,
|
||||
barrier=barriers[0],
|
||||
)
|
||||
ctx.async_copy(
|
||||
src_ref=rhs,
|
||||
dst_ref=rhs_smem,
|
||||
swizzle=swizzle,
|
||||
gmem_transform=rhs_transform,
|
||||
barrier=barriers[1],
|
||||
)
|
||||
for i in range(2):
|
||||
barriers[i].wait()
|
||||
else:
|
||||
for mi in range(m // 64):
|
||||
for ki in range(k // nk_tile):
|
||||
lhs_slice = (
|
||||
ds(c(mi * 64, index), 64),
|
||||
ds(c(ki * nk_tile, index), nk_tile),
|
||||
)
|
||||
if lhs_transpose:
|
||||
lhs_slice = lhs_slice[::-1]
|
||||
copy(
|
||||
src=memref_slice(lhs, lhs_slice),
|
||||
dst=memref_slice(lhs_smem, (mi, ki)),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
for ki in range(k // nk_tile):
|
||||
k_slice = ds(c(ki * nk_tile, index), nk_tile)
|
||||
for ni in range(n // nk_tile):
|
||||
rhs_slice = (k_slice, ds(c(ni * nk_tile, index), nk_tile))
|
||||
if rhs_transpose:
|
||||
rhs_slice = rhs_slice[::-1]
|
||||
copy(
|
||||
src=memref_slice(rhs, rhs_slice),
|
||||
dst=memref_slice(rhs_smem, (ki, ni)),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
|
||||
if lhs_transpose:
|
||||
assert nk_tile == 64 # Make sure we didn't have to transpose tiling.
|
||||
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
||||
rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
|
||||
if rhs_transpose:
|
||||
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
||||
ctx.async_copy(
|
||||
src_ref=lhs,
|
||||
dst_ref=lhs_smem,
|
||||
swizzle=swizzle,
|
||||
gmem_transform=lhs_transform,
|
||||
barrier=barriers[0],
|
||||
)
|
||||
ctx.async_copy(
|
||||
src_ref=rhs,
|
||||
dst_ref=rhs_smem,
|
||||
swizzle=swizzle,
|
||||
gmem_transform=rhs_transform,
|
||||
barrier=barriers[1],
|
||||
)
|
||||
for i in range(2):
|
||||
barriers[i].wait()
|
||||
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype)
|
||||
if lhs_transpose:
|
||||
lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
|
||||
|
Loading…
x
Reference in New Issue
Block a user