Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU][NFC] Move the calculation of group strides into _validate_mma

This allows us to unify this logic between Hopper and Blackwell.

PiperOrigin-RevId: 732862875
commit 11e6cfbc6a
parent bbadf99054
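
For orientation, every group stride this change centralizes is computed with the same arithmetic: a group of group_size elements along a tiled dimension spans group_size // tiling tiles, so the byte stride between consecutive groups is the per-tile byte stride times that tile count. The sketch below is a minimal, hypothetical illustration of that rule in plain Python; the function name and the example numbers are assumptions for illustration, not part of the Mosaic GPU code.

# Hypothetical sketch, not the Mosaic GPU API: the stride-per-group rule used
# throughout this diff (e.g. a_k_group_stride = a_k_byte_stride * tiles_per_group).

def group_stride(tile_byte_stride: int, tiling: int, group_size: int) -> int:
  """Bytes between consecutive groups along one tiled dimension."""
  if group_size % tiling:
    raise ValueError(
        f"group size {group_size} must be a multiple of the tiling {tiling}"
    )
  tiles_per_group = group_size // tiling
  return tile_byte_stride * tiles_per_group


if __name__ == "__main__":
  # Assumed example: 16-bit elements with a 128-byte swizzle, so one K tile
  # holds 128 // 2 = 64 elements and one MMA group consumes one K tile.
  element_bytewidth = 2
  swizzle = 128
  k_tiling = swizzle // element_bytewidth   # 64 elements per K tile
  k_group_size = swizzle // element_bytewidth
  a_k_tile_byte_stride = 8192               # made-up byte stride of one K tile in SMEM
  print(group_stride(a_k_tile_byte_stride, k_tiling, k_group_size))  # -> 8192
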
@@ -106,83 +106,61 @@ def mma(
     raise ValueError(f"B must be a memref, got: {b.type}")
   if a_swizzle != b_swizzle:
     raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
   swizzle = a_swizzle
   if isinstance(accumulate, bool):
     accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
+  num_cta = 2 if collective else 1
+
+  m_group_size = d.layout.elements_in_tile[0]
+  if m_group_size != 128:
+    raise NotImplementedError("Only 128-row accumulators supported for now")

   (
       a_desc_base,
       b_desc_base,
       (m, k, n),
       (m_mem_tiling, a_k_mem_tiling, b_k_mem_tiling, n_mem_tiling),
       element_type,
+      (m_groups, k_groups, n_groups),
+      (a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
       mma_params,
   ) = _wgmma._validate_mma(
       a,
       b,
       a_swizzle,
+      m_group_size=m_group_size,
       descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
   )
-  element_bytewidth = utils.bytewidth(element_type)
-
-  k_group_tiling = swizzle // element_bytewidth
-
-  if (m_group_tiling := d.layout.elements_in_tile[0]) != m_mem_tiling:
-    raise ValueError(
-        f"A's row tiling must be equal to {m_group_tiling} (inferred from"
-        f" accumulator's TMEM layout), got: {m_mem_tiling}"
+  n_group_size = n // n_groups
   if n > 512:
     raise ValueError(f"N is too big: at most 512 is supported, but got {n}")
-  num_cta = 2 if collective else 1
   if num_cta == 2 and n > 256:
     raise NotImplementedError(
         "N is too big for collective MMA. Only up to 256 is supported."
     )
-  a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-  a_m_group_stride = a_strides[0] * element_bytewidth
-  a_k_tiles_per_group = k_group_tiling // a_k_mem_tiling
-  a_k_group_stride = a_strides[1] * element_bytewidth * a_k_tiles_per_group
-
-  if n * num_cta <= 256:
-    n_group_tiling = n
-  elif n * num_cta == 512:
-    if collective:
-      raise NotImplementedError("Collective MMA with effective N=512 is unsupported")
-    n_group_tiling = 256
-  else:
-    raise NotImplementedError("The only supported N larger than 256 is 512")
-
-  b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
-  b_k_tiles_per_group = k_group_tiling // b_k_mem_tiling
-  b_k_group_stride = b_strides[0] * element_bytewidth * b_k_tiles_per_group
-  n_tiles_per_group = n_group_tiling // n_mem_tiling
-  b_n_group_stride = b_strides[1] * element_bytewidth * n_tiles_per_group
-
-  groups_k = k // k_group_tiling
-  groups_m = m // m_group_tiling
-  groups_n = n // n_group_tiling

   # TODO(apaszke): Verify that the cluster shape matches the expectation of
   # collective MMA.
-  expected_acc_shape = (m, n * (2 if collective else 1))
+  expected_acc_shape = (m, n * num_cta)
   if d.shape != expected_acc_shape:
     raise ValueError(
         f"Accumulator shape mismatch: expected {expected_acc_shape}, got {d.shape}"
     )

   true = arith.constant(ir.IntegerType.get_signless(1), 1)
-  for mi, ni, ki in np.ndindex(groups_m, groups_n, groups_k):
+  for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
     a_offset = mi * a_m_group_stride + ki * a_k_group_stride
     a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64))
     b_offset = ni * b_n_group_stride + ki * b_k_group_stride
     b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64))
-    if groups_m != 1:
+    if m_groups != 1:
       raise NotImplementedError("D needs to be sliced")
     acc = accumulate if ki == 0 else true
     _do_mma(
         d.slice(
-            slice(None), utils.ds(ni * n_group_tiling, n_group_tiling)
+            slice(None), utils.ds(ni * n_group_size, n_group_size)
         ).address,
         a_mk,
         b_nk,
         d_type=ir.F32Type.get(),
-        m=m_group_tiling,
-        n=n_group_tiling,
+        m=m_group_size,
         collective=collective,
         **mma_params,
         accumulate=acc,
@@ -306,6 +306,7 @@ def _validate_mma(
     a: Any,
     b: ir.Value,
     swizzle: int,
+    m_group_size: int,  # The M used by a single instruction.
     descriptor_const_init: int = 0,
 ):
   # We need swizzle >= 32 to ensure that our K tiling is larger than the MMA
@@ -399,27 +400,49 @@ def _validate_mma(
   else:
     raise ValueError(b_byte_strides)

+  if n > 256 and n % 256:
+    raise ValueError(
+        f"N group size must be a multiple of 256 when larger than 256, got: {n}"
+    )
+  k_group_size = swizzle_elems
+  n_group_size = min(n, 256)
+  b_k_tiles_per_group = k_group_size // b_k_tiling
+  b_k_group_stride = b_k_byte_stride * b_k_tiles_per_group
+  n_tiles_per_group = n_group_size // n_tiling
+  b_n_group_stride = b_n_byte_stride * n_tiles_per_group
+
   # Verify the shape and strides of A are as expected.
   if not a_in_smem:
     m = a_shape[0]
-    a_order = m_tiling = a_k_tiling = None
+    a_order = a_m_group_stride = a_k_group_stride = None
   else:
     a_ty = ir.MemRefType(a.type)
     m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape
     m = m_tiles * m_tiling
+    # TODO(apaszke): I'm not actually convinced that we need this check.
+    if m_tiling != m_group_size:
+      raise ValueError(
+          f"A's row tiling must be equal to {m_group_size}, got: {m_tiling}"
+      )
     if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k:
       raise ValueError(a_ty.shape)
     a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-    a_byte_strides = [s * element_bytewidth for s in a_strides]
-    if a_byte_strides[2:] == [swizzle, element_bytewidth]:
+    a_m_byte_stride, a_k_byte_stride, *a_tile_byte_strides = [
+        s * element_bytewidth for s in a_strides
+    ]
+    if a_tile_byte_strides == [swizzle, element_bytewidth]:
       a_order = WGMMALayout.ROW_MAJOR
-    elif a_byte_strides[2:] == [element_bytewidth, swizzle]:
+    elif a_tile_byte_strides == [element_bytewidth, swizzle]:
       a_order = WGMMALayout.COL_MAJOR
     else:
-      raise ValueError(a_byte_strides)
+      raise ValueError(a_strides)
     if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems:
       # Not sure what the layout is like, since the tiles aren't square.
       raise NotImplementedError
+    a_m_tiles_per_group = m_group_size // m_tiling
+    a_m_group_stride = a_m_byte_stride * a_m_tiles_per_group
+    a_k_tiles_per_group = k_group_size // a_k_tiling
+    a_k_group_stride = a_k_byte_stride * a_k_tiles_per_group

   b_k_fastest = b_order == WGMMALayout.COL_MAJOR
   a_k_fastest = a_order == WGMMALayout.ROW_MAJOR
@@ -488,6 +511,7 @@ def _validate_mma(
       a_k_stride=32 if a_k_fastest else swizzle * 16,
       b_k_stride=b_k_wgmma_stride,
       swizzle=swizzle,
+      n=n_group_size,
       element_type=ir.FloatTF32Type.get()
       if ir.F32Type.isinstance(element_type)
       else element_type,
@@ -503,12 +527,23 @@ def _validate_mma(
         b, **b_desc_fields, const_init=descriptor_const_init
     )

+  if m % m_group_size:
+    raise ValueError(f"m must be a multiple of {m_group_size}, got: {m}")
+  m_groups = m // m_group_size
+  if k % k_group_size:
+    raise ValueError(f"k must be a multiple of {k_group_size}, got: {k}")
+  k_groups = k // k_group_size
+  if n % n_group_size:
+    raise ValueError(f"n must be a multiple of {n_group_size}, got: {n}")
+  n_groups = n // n_group_size
+
   return (
       a_desc_base,
       b_desc_base,
       (m, k, n),
       (m_tiling, a_k_tiling, b_k_tiling, n_tiling),
       element_type,
+      (m_groups, k_groups, n_groups),
+      # Group strides are always in bytes!
+      (a_m_group_stride, a_k_group_stride, b_k_group_stride, b_n_group_stride),
       wgmma_params,
   )

@@ -538,49 +573,32 @@ def wgmma(
   if not ir.MemRefType.isinstance(b.type):
     raise ValueError(f"B must be a memref, got: {b.type}")

+  m_group_size = 64  # Hopper has a fixed M instruction shape.
+
   (
       a_desc_base,
       b_desc_base,
       (m, k, n),
       (m_mem_tiling, a_k_mem_tiling, b_k_mem_tiling, _),
       element_type,
+      (m_groups, k_groups, n_groups),
+      (a_m_group_stride, a_k_group_stride, b_k_group_stride, _),
       wgmma_params,
-  ) = _validate_mma(a, b, swizzle)
-  element_bytewidth = bytewidth(element_type)
+  ) = _validate_mma(a, b, swizzle, m_group_size=m_group_size)

-  if n > 256:
-    raise ValueError(f"N must be smaller than 256, got {n}")
+  if n_groups > 1:
+    raise ValueError("N is too big for WGMMA. Only up to 256 is supported.")

-  m_group_tiling = 64
-  k_group_tiling = swizzle // element_bytewidth
   if a_in_regs:
     if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
       raise ValueError(
           f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
       )
-    if a.shape[0] % 64:
+    if a.shape[0] % m_group_size:
       raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
-    a_m_group_stride = a_k_group_stride = None
-  else:
-    a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-    if m_mem_tiling != 64:
-      raise ValueError(f"A must have rows tiled by 64, got: {m_mem_tiling}")
-    a_m_group_stride = a_strides[0] * element_bytewidth  # ^ One group per tile
-    a_k_tiles_per_group = k_group_tiling // a_k_mem_tiling
-    a_k_group_stride = a_strides[1] * element_bytewidth * a_k_tiles_per_group
-
-  b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
-  b_k_tiles_per_group = k_group_tiling // b_k_mem_tiling
-  b_k_group_stride = b_strides[0] * element_bytewidth * b_k_tiles_per_group
-
-  groups_m = m // m_group_tiling
-  groups_k = k // k_group_tiling
-
-  expected_acc_shape = (groups_m * 64, n)
-  if acc.value.shape != expected_acc_shape:
+  if acc.value.shape != (m, n):
     raise ValueError(
-        f"Accumulator shape mismatch: expected {expected_acc_shape}, got"
-        f" {acc.value.shape}"
+        f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}"
     )

   if a_in_regs:
@@ -588,12 +606,13 @@ def wgmma(

   i64 = ir.IntegerType.get_signless(64)
   new_acc_regs = acc.value.registers.copy()
-  for mi in range(groups_m):
-    for ki in range(groups_k):
+  k_group_size = k // k_groups
+  for mi in range(m_groups):
+    for ki in range(k_groups):
       if a_in_regs:
         a_mk = a[
-            mi * m_group_tiling : (mi + 1) * m_group_tiling,
-            ki * k_group_tiling : (ki + 1) * k_group_tiling
+            mi * m_group_size : (mi + 1) * m_group_size,
+            ki * k_group_size : (ki + 1) * k_group_size,
         ]
       else:
         a_mk = llvm_add(
@@ -602,7 +621,7 @@ def wgmma(
         )
       b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64))
       new_acc_regs[mi : mi + 1] = wgmma_m64(
-          new_acc_regs[mi : mi + 1], a_mk, b_k, n=n, **wgmma_params
+          new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
       )
   return WGMMAAccumulator(
       _value=fa.FragmentedArray(