mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Add support for WGMMA lhs in registers for swizzles other than 128
PiperOrigin-RevId: 653626991
This commit is contained in:
parent
47e6da3332
commit
a07b9adcb2
@ -171,7 +171,8 @@ def wgmma_m64(
|
||||
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
|
||||
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
|
||||
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
|
||||
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64):
|
||||
# Column count must be equal to swizzle // bytewidth.
|
||||
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, swizzle // 2):
|
||||
raise ValueError("Unsupported A register array layout")
|
||||
if a_k_stride is not None or a_transpose is not None:
|
||||
raise ValueError("Unsupported WGMMA features with A in registers")
|
||||
|
@ -613,31 +613,39 @@ class WGMMATest(TestCase):
|
||||
n=(64, 128, 192),
|
||||
k_steps=(1, 2),
|
||||
rhs_transpose=(False, True),
|
||||
swizzle=(32, 64, 128),
|
||||
mlir_dtype_cls=(ir.F16Type, ir.BF16Type),
|
||||
)
|
||||
def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, mlir_dtype_cls):
|
||||
k = 64 * k_steps
|
||||
def test_wgmma_reg_lhs(
|
||||
self, m, n, k_steps, rhs_transpose, swizzle, mlir_dtype_cls
|
||||
):
|
||||
index = ir.IndexType.get()
|
||||
|
||||
row_major = mgpu.WGMMALayout.ROW_MAJOR
|
||||
col_major = mgpu.WGMMALayout.COL_MAJOR
|
||||
rhs_order = col_major if rhs_transpose else row_major
|
||||
bytewidth = 2
|
||||
nk_tile = swizzle // bytewidth
|
||||
k = nk_tile * k_steps
|
||||
|
||||
def kernel(ctx, rhs, out, rhs_smem):
|
||||
del ctx
|
||||
for ki in range(k_steps):
|
||||
for ni in range(n // 64):
|
||||
rhs_slice = (ds(c(ki * 64, index), 64), ds(c(ni * 64, index), 64))
|
||||
for ni in range(n // nk_tile):
|
||||
rhs_slice = (
|
||||
ds(c(ki * nk_tile, index), nk_tile),
|
||||
ds(c(ni * nk_tile, index), nk_tile),
|
||||
)
|
||||
if rhs_transpose:
|
||||
rhs_slice = rhs_slice[::-1]
|
||||
copy(
|
||||
src=memref_slice(rhs, rhs_slice),
|
||||
dst=memref_slice(rhs_smem, (ki, ni)),
|
||||
swizzle=128,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
|
||||
lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get())
|
||||
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order)
|
||||
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle)
|
||||
nvvm.wgmma_commit_group_sync_aligned()
|
||||
nvvm.wgmma_wait_group_sync_aligned(0)
|
||||
acc.value.store_untiled(out)
|
||||
@ -647,7 +655,7 @@ class WGMMATest(TestCase):
|
||||
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
|
||||
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
||||
scratch_shape = jax.ShapeDtypeStruct(
|
||||
(k_steps, n // 64, 64, 64), jax_dtype
|
||||
(k_steps, n // nk_tile, nk_tile, nk_tile), jax_dtype
|
||||
)
|
||||
z = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape
|
||||
@ -656,7 +664,7 @@ class WGMMATest(TestCase):
|
||||
ref = jax.lax.dot(
|
||||
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
|
||||
)
|
||||
rtol = 0 if k_steps == 1 else 2.2e-4
|
||||
rtol = 5e-4
|
||||
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user