[Mosaic GPU] Add support for WGMMA lhs in registers for swizzles other than 128

PiperOrigin-RevId: 653626991
This commit is contained in:
Adam Paszke 2024-07-18 08:22:34 -07:00 committed by jax authors
parent 47e6da3332
commit a07b9adcb2
2 changed files with 18 additions and 9 deletions

View File

@ -171,7 +171,8 @@ def wgmma_m64(
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64):
# Column count must be equal to swizzle // bytewidth.
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, swizzle // 2):
raise ValueError("Unsupported A register array layout")
if a_k_stride is not None or a_transpose is not None:
raise ValueError("Unsupported WGMMA features with A in registers")

View File

@ -613,31 +613,39 @@ class WGMMATest(TestCase):
n=(64, 128, 192),
k_steps=(1, 2),
rhs_transpose=(False, True),
swizzle=(32, 64, 128),
mlir_dtype_cls=(ir.F16Type, ir.BF16Type),
)
def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, mlir_dtype_cls):
k = 64 * k_steps
def test_wgmma_reg_lhs(
self, m, n, k_steps, rhs_transpose, swizzle, mlir_dtype_cls
):
index = ir.IndexType.get()
row_major = mgpu.WGMMALayout.ROW_MAJOR
col_major = mgpu.WGMMALayout.COL_MAJOR
rhs_order = col_major if rhs_transpose else row_major
bytewidth = 2
nk_tile = swizzle // bytewidth
k = nk_tile * k_steps
def kernel(ctx, rhs, out, rhs_smem):
del ctx
for ki in range(k_steps):
for ni in range(n // 64):
rhs_slice = (ds(c(ki * 64, index), 64), ds(c(ni * 64, index), 64))
for ni in range(n // nk_tile):
rhs_slice = (
ds(c(ki * nk_tile, index), nk_tile),
ds(c(ni * nk_tile, index), nk_tile),
)
if rhs_transpose:
rhs_slice = rhs_slice[::-1]
copy(
src=memref_slice(rhs, rhs_slice),
dst=memref_slice(rhs_smem, (ki, ni)),
swizzle=128,
swizzle=swizzle,
)
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get())
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order)
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
acc.value.store_untiled(out)
@ -647,7 +655,7 @@ class WGMMATest(TestCase):
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
scratch_shape = jax.ShapeDtypeStruct(
(k_steps, n // 64, 64, 64), jax_dtype
(k_steps, n // nk_tile, nk_tile, nk_tile), jax_dtype
)
z = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape
@ -656,7 +664,7 @@ class WGMMATest(TestCase):
ref = jax.lax.dot(
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
)
rtol = 0 if k_steps == 1 else 2.2e-4
rtol = 5e-4
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)