[Mosaic GPU][NFC] Move the calculation of group strides into _validate_mma

This allows us to unify this logic between Hopper and Blackwell.

PiperOrigin-RevId: 732862875
Adam Paszke 2025-03-03 03:49:56 -08:00 committed by jax authors
parent bbadf99054
commit 11e6cfbc6a
2 changed files with 77 additions and 80 deletions
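As the commit message says, the per-group byte strides needed by both the Hopper (`wgmma`) and Blackwell (`tcgen05` `mma`) paths are now computed once inside `_validate_mma` and returned to the callers. The arithmetic itself is simple: a group is the block consumed by a single MMA instruction, so its stride is the per-tile byte stride scaled by the number of tiles that make up one group. A minimal sketch of that computation (the helper name and the concrete numbers are illustrative, not taken from the diff):

```python
# Illustrative sketch only: mirrors the stride arithmetic that _validate_mma
# now performs; the function name and example numbers are hypothetical.
def group_stride_bytes(tile_byte_stride: int, tiling: int, group_size: int) -> int:
  """Byte offset between consecutive groups along one dimension.

  A "group" is the block consumed by one MMA instruction. It may span several
  memory tiles, so the per-tile byte stride is scaled by the number of tiles
  per group.
  """
  assert group_size % tiling == 0
  tiles_per_group = group_size // tiling
  return tile_byte_stride * tiles_per_group

# Made-up example: bf16 (2-byte) elements with swizzle=128 give a K group of
# 128 // 2 = 64 elements; with 64-element K tiles whose byte stride is 8192,
# consecutive K groups are 8192 bytes apart.
a_k_group_stride = group_stride_bytes(8192, tiling=64, group_size=128 // 2)
```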


@@ -106,83 +106,61 @@ def mma(
raise ValueError(f"B must be a memref, got: {b.type}")
if a_swizzle != b_swizzle:
raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
swizzle = a_swizzle
if isinstance(accumulate, bool):
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
num_cta = 2 if collective else 1
m_group_size = d.layout.elements_in_tile[0]
if m_group_size != 128:
raise NotImplementedError("Only 128-row accumulators supported for now")
(
a_desc_base,
b_desc_base,
(m, k, n),
(m_mem_tiling, a_k_mem_tiling, b_k_mem_tiling, n_mem_tiling),
element_type,
(m_groups, k_groups, n_groups),
(a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
mma_params,
) = _wgmma._validate_mma(
a,
b,
a_swizzle,
m_group_size=m_group_size,
descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
)
element_bytewidth = utils.bytewidth(element_type)
k_group_tiling = swizzle // element_bytewidth
if (m_group_tiling := d.layout.elements_in_tile[0]) != m_mem_tiling:
raise ValueError(
f"A's row tiling must be equal to {m_group_tiling} (inferred from"
f" accumulator's TMEM layout), got: {m_mem_tiling}"
n_group_size = n // n_groups
if n > 512:
raise ValueError(f"N is too big: at most 512 is supported, but got {n}")
num_cta = 2 if collective else 1
if num_cta == 2 and n > 256:
raise NotImplementedError(
"N is too big for collective MMA. Only up to 256 is supported."
)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_group_stride = a_strides[0] * element_bytewidth
a_k_tiles_per_group = k_group_tiling // a_k_mem_tiling
a_k_group_stride = a_strides[1] * element_bytewidth * a_k_tiles_per_group
if n * num_cta <= 256:
n_group_tiling = n
elif n * num_cta == 512:
if collective:
raise NotImplementedError("Collective MMA with effective N=512 is unsupported")
n_group_tiling = 256
else:
raise NotImplementedError("The only supported N larger than 256 is 512")
b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
b_k_tiles_per_group = k_group_tiling // b_k_mem_tiling
b_k_group_stride = b_strides[0] * element_bytewidth * b_k_tiles_per_group
n_tiles_per_group = n_group_tiling // n_mem_tiling
b_n_group_stride = b_strides[1] * element_bytewidth * n_tiles_per_group
groups_k = k // k_group_tiling
groups_m = m // m_group_tiling
groups_n = n // n_group_tiling
# TODO(apaszke): Verify that the cluster shape matches the expectation of
# collective MMA.
expected_acc_shape = (m, n * (2 if collective else 1))
expected_acc_shape = (m, n * num_cta)
if d.shape != expected_acc_shape:
raise ValueError(
f"Accumulator shape mismatch: expected {expected_acc_shape}, got {d.shape}"
)
true = arith.constant(ir.IntegerType.get_signless(1), 1)
for mi, ni, ki in np.ndindex(groups_m, groups_n, groups_k):
for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64))
b_offset = ni * b_n_group_stride + ki * b_k_group_stride
b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64))
if groups_m != 1:
if m_groups != 1:
raise NotImplementedError("D needs to be sliced")
acc = accumulate if ki == 0 else true
_do_mma(
d.slice(
slice(None), utils.ds(ni * n_group_tiling, n_group_tiling)
slice(None), utils.ds(ni * n_group_size, n_group_size)
).address,
a_mk,
b_nk,
d_type=ir.F32Type.get(),
m=m_group_tiling,
n=n_group_tiling,
m=m_group_size,
collective=collective,
**mma_params,
accumulate=acc,

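To make the loop in the hunk above easier to follow: each `(mi, ni, ki)` step addresses one group of A and B by adding multiples of the returned byte strides to the base SMEM descriptors, and slices the accumulator by `n_group_size` columns. A small illustrative trace of the addressing pattern (the group counts and strides below are invented for the example, not values from the commit):

```python
# Illustrative only: invented group counts and byte strides, showing the
# offsets the loop above feeds into the SMEM descriptors at each step.
m_groups, n_groups, k_groups = 1, 2, 2
a_m_group_stride, a_k_group_stride = 0, 8192      # bytes
b_k_group_stride, b_n_group_stride = 8192, 16384  # bytes

for mi in range(m_groups):          # same iteration order as np.ndindex
  for ni in range(n_groups):
    for ki in range(k_groups):
      a_offset = mi * a_m_group_stride + ki * a_k_group_stride
      b_offset = ni * b_n_group_stride + ki * b_k_group_stride
      print(f"step {(mi, ni, ki)}: A offset {a_offset} B, B offset {b_offset} B")
```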

@@ -306,6 +306,7 @@ def _validate_mma(
a: Any,
b: ir.Value,
swizzle: int,
m_group_size: int, # The M used by a single instruction.
descriptor_const_init: int = 0,
):
# We need swizzle >= 32 to ensure that our K tiling is larger than the MMA
@@ -399,27 +400,49 @@ def _validate_mma(
else:
raise ValueError(b_byte_strides)
if n > 256 and n % 256:
raise ValueError(
f"N group size must be a multiple of 256 when larger than 256, got: {n}"
)
k_group_size = swizzle_elems
n_group_size = min(n, 256)
b_k_tiles_per_group = k_group_size // b_k_tiling
b_k_group_stride = b_k_byte_stride * b_k_tiles_per_group
n_tiles_per_group = n_group_size // n_tiling
b_n_group_stride = b_n_byte_stride * n_tiles_per_group
# Verify the shape and strides of A are as expected.
if not a_in_smem:
m = a_shape[0]
a_order = m_tiling = a_k_tiling = None
a_order = a_m_group_stride = a_k_group_stride = None
else:
a_ty = ir.MemRefType(a.type)
m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape
m = m_tiles * m_tiling
# TODO(apaszke): I'm not actually convinced that we need this check.
if m_tiling != m_group_size:
raise ValueError(
f"A's row tiling must be equal to {m_group_size}, got: {m_tiling}"
)
if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k:
raise ValueError(a_ty.shape)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_byte_strides = [s * element_bytewidth for s in a_strides]
if a_byte_strides[2:] == [swizzle, element_bytewidth]:
a_m_byte_stride, a_k_byte_stride, *a_tile_byte_strides = [
s * element_bytewidth for s in a_strides
]
if a_tile_byte_strides == [swizzle, element_bytewidth]:
a_order = WGMMALayout.ROW_MAJOR
elif a_byte_strides[2:] == [element_bytewidth, swizzle]:
elif a_tile_byte_strides == [element_bytewidth, swizzle]:
a_order = WGMMALayout.COL_MAJOR
else:
raise ValueError(a_byte_strides)
raise ValueError(a_strides)
if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems:
# Not sure what the layout is like, since the tiles aren't square.
raise NotImplementedError
a_m_tiles_per_group = m_group_size // m_tiling
a_m_group_stride = a_m_byte_stride * a_m_tiles_per_group
a_k_tiles_per_group = k_group_size // a_k_tiling
a_k_group_stride = a_k_byte_stride * a_k_tiles_per_group
b_k_fastest = b_order == WGMMALayout.COL_MAJOR
a_k_fastest = a_order == WGMMALayout.ROW_MAJOR
@@ -488,6 +511,7 @@ def _validate_mma(
a_k_stride=32 if a_k_fastest else swizzle * 16,
b_k_stride=b_k_wgmma_stride,
swizzle=swizzle,
n=n_group_size,
element_type=ir.FloatTF32Type.get()
if ir.F32Type.isinstance(element_type)
else element_type,
@@ -503,12 +527,23 @@ def _validate_mma(
b, **b_desc_fields, const_init=descriptor_const_init
)
if m % m_group_size:
raise ValueError(f"m must be a multiple of {m_group_size}, got: {m}")
m_groups = m // m_group_size
if k % k_group_size:
raise ValueError(f"k must be a multiple of {k_group_size}, got: {k}")
k_groups = k // k_group_size
if n % n_group_size:
raise ValueError(f"n must be a multiple of {n_group_size}, got: {n}")
n_groups = n // n_group_size
return (
a_desc_base,
b_desc_base,
(m, k, n),
(m_tiling, a_k_tiling, b_k_tiling, n_tiling),
element_type,
(m_groups, k_groups, n_groups),
# Group strides are always in bytes!
(a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
wgmma_params,
)
@@ -538,49 +573,32 @@ def wgmma(
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")
m_group_size = 64 # Hopper has a fixed M instruction shape.
(
a_desc_base,
b_desc_base,
(m, k, n),
(m_mem_tiling, a_k_mem_tiling, b_k_mem_tiling, _),
element_type,
(m_groups, k_groups, n_groups),
(a_m_group_stride, a_k_group_stride, b_k_group_stride, _),
wgmma_params,
) = _validate_mma(a, b, swizzle)
element_bytewidth = bytewidth(element_type)
) = _validate_mma(a, b, swizzle, m_group_size=m_group_size)
if n > 256:
raise ValueError(f"N must be smaller than 256, got {n}")
if n_groups > 1:
raise ValueError("N is too big for WGMMA. Only up to 256 is supported.")
m_group_tiling = 64
k_group_tiling = swizzle // element_bytewidth
if a_in_regs:
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(
f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
)
if a.shape[0] % 64:
if a.shape[0] % m_group_size:
raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
a_m_group_stride = a_k_group_stride = None
else:
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
if m_mem_tiling != 64:
raise ValueError(f"A must have rows tiled by 64, got: {m_mem_tiling}")
a_m_group_stride = a_strides[0] * element_bytewidth # ^ One group per tile
a_k_tiles_per_group = k_group_tiling // a_k_mem_tiling
a_k_group_stride = a_strides[1] * element_bytewidth * a_k_tiles_per_group
b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
b_k_tiles_per_group = k_group_tiling // b_k_mem_tiling
b_k_group_stride = b_strides[0] * element_bytewidth * b_k_tiles_per_group
groups_m = m // m_group_tiling
groups_k = k // k_group_tiling
expected_acc_shape = (groups_m * 64, n)
if acc.value.shape != expected_acc_shape:
if acc.value.shape != (m, n):
raise ValueError(
f"Accumulator shape mismatch: expected {expected_acc_shape}, got"
f" {acc.value.shape}"
f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}"
)
if a_in_regs:
@@ -588,12 +606,13 @@ def wgmma(
i64 = ir.IntegerType.get_signless(64)
new_acc_regs = acc.value.registers.copy()
for mi in range(groups_m):
for ki in range(groups_k):
k_group_size = k // k_groups
for mi in range(m_groups):
for ki in range(k_groups):
if a_in_regs:
a_mk = a[
mi * m_group_tiling : (mi + 1) * m_group_tiling,
ki * k_group_tiling : (ki + 1) * k_group_tiling
mi * m_group_size : (mi + 1) * m_group_size,
ki * k_group_size : (ki + 1) * k_group_size,
]
else:
a_mk = llvm_add(
@@ -602,7 +621,7 @@ def wgmma(
)
b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64))
new_acc_regs[mi : mi + 1] = wgmma_m64(
new_acc_regs[mi : mi + 1], a_mk, b_k, n=n, **wgmma_params
new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
)
return WGMMAAccumulator(
_value=fa.FragmentedArray(