Benjamin Chetioui 1e36cbe597 [Mosaic GPU] Raise a NotImplementedError if swizzle=16.
Unswizzled MMAs don't lower correctly, and are not currently intended to be
supported.

PiperOrigin-RevId: 737981373
2025-03-18 06:29:13 -07:00

# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
import functools
import itertools
import math
import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
import numpy as np
from . import fragmented_array as fa
from . import mma_utils
from . import utils
# mypy: ignore-errors
c = utils.c
bytewidth = utils.bytewidth
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class WGMMAAccumulator:
"""A FragmentedArray that has is synchronized with the async proxy.
This implies that it requires no additional synchronization when passed in
as a WGMMA accumulator. In particular, when created from a
FragmentedArray, the necessary synchronization is inserted at construction.
"""
value: fa.FragmentedArray
def __init__(self, *, _value: fa.FragmentedArray, _sync: bool = True):
if _value.layout != fa.WGMMA_LAYOUT:
raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
self.value = _value
if _sync:
self.value = wgmma_fence(_value)
@classmethod
def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None):
if m % 64 or n % 8:
raise ValueError(f"m must be a multiple of 64 and n a multiple of 8, got: m={m}, n={n}")
if is_signed is False:
raise TypeError("PTX does not support unsigned WGMMA accumulators")
f32 = ir.F32Type.get()
if dtype is None:
dtype = f32
zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
return cls(
_value=fa.FragmentedArray.splat(
zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed
)
)
@classmethod
def from_registers(cls, registers):
return cls(_value=registers)
def tree_flatten(self):
return (self.value,), ()
@classmethod
def tree_unflatten(cls, aux, value):
del aux
return cls(_value=value[0], _sync=False)
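# Note: WGMMA only supports a limited set of accumulator/operand type pairs:
# an f32 accumulator accepts tf32, bf16 and f16 operands, while an f16
# accumulator only accepts f16 operands.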
def _supported_wgmma_types(dtype, abtype) -> bool:
input_types_are = lambda ty: ty.isinstance(abtype)
if ir.F32Type.isinstance(dtype):
return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type))
elif ir.F16Type.isinstance(dtype):
return input_types_are(ir.F16Type)
else:
return False
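# Issues the WGMMA instructions for a single 64-row (M=64) slab of the
# accumulator, looping over the K steps contained in one swizzled tile.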
def wgmma_m64(
acc: np.ndarray, # of register Values
a,
b_descriptor: ir.Value,
a_transpose: bool | None,
b_transpose: bool,
a_k_stride: int | None,
b_k_stride: int,
n: int,
swizzle: int,
element_type: ir.Type,
):
out_ty = ir.VectorType(acc.flat[0].type).element_type
if not _supported_wgmma_types(out_ty, element_type):
raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}")
if n % 8:
raise ValueError(f"n must be a multiple of 8, got: {n}")
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
index = ir.IndexType.get()
if b_k_stride % 16:
raise ValueError(f"B K stride must be a multiple of 16, got: {b_k_stride}")
# Only 16-bit types support transposes
supports_transpose = bytewidth(element_type) == 2
if not supports_transpose and (a_transpose or b_transpose):
raise ValueError("Only f16 WGMMA supports transposes")
if a_in_regs := isinstance(a, fa.FragmentedArray):
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
# Column count must be equal to swizzle // bytewidth.
if a.layout != fa.WGMMA_LAYOUT or a.shape != (64, swizzle // 2):
raise ValueError("Unsupported A register array layout")
if a_k_stride is not None or a_transpose is not None:
raise ValueError("Unsupported WGMMA features with A in registers")
else:
if a_k_stride is None or a_k_stride % 16:
raise ValueError(f"A K stride must be specified and a multiple of 16, got: {a_k_stride}")
if a_transpose is None:
raise ValueError("a_transpose must be specified when A is not in registers")
if ir.F32Type.isinstance(out_ty):
num_acc_regs = n // 2
out_ty_field = out_ty
acc_regs = [ # pylint: disable=g-complex-comprehension
vector.extractelement(reg, position=c(pos, index))
for reg in acc.flat
for pos in range(2)
]
to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape)
acc_constraint = "f"
elif ir.F16Type.isinstance(out_ty):
num_acc_regs = n // 4
out_ty_field = i32
acc_regs = [_as_i32_reg(reg) for reg in acc.flat]
vec_ty = ir.VectorType(acc.flat[0].type)
to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape)
acc_constraint = "r"
else:
raise ValueError(f"WGMMA instruciton only supports f32 and f16 out (got {out_ty})")
num_imm_regs = 4 if supports_transpose else 2
if a_in_regs:
a_reg_constraints = ["r"] * 4 # 4x f16x2 registers
num_imm_regs -= 1 # transpose not supported for a in registers
else:
a_reg_constraints = ["l"] # descriptor
# Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
# Seems like it's not actually documented in LLVM IR docs.
reg_constraints_list = (
[f"={acc_constraint}"] * num_acc_regs # accumulator registers
+ [str(i) for i in range(num_acc_regs)] # we alias outputs as inputs, too.
+ a_reg_constraints # a descriptor / registers
+ ["l"] * 1 # b descriptor
+ ["n"] * (1 + num_imm_regs) # literal constants
)
reg_constraints = ",".join(reg_constraints_list)
reg_count = itertools.count()
def take_regs(n):
return (f"${i}" for i in itertools.islice(reg_count, n))
acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
for _ in take_regs(num_acc_regs): # Ignore next entries: aliasing.
pass
if a_in_regs:
a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
else:
a_regs, = take_regs(1)
b_desc_reg, use_out_reg = take_regs(2)
imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...).
assert next(reg_count) == len(reg_constraints_list)
el_ty = element_type
k_instr = 32 // bytewidth(element_type)
wgmma_instr = (
f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} "
f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};"
)
ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"
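# The predicate `p` drives the scale-d operand of wgmma.mma_async: when it is
# zero the previous accumulator contents are ignored. It is derived from the
# `use_out` constant below, which we always set to 1, so the instruction
# accumulates into the existing registers.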
def lc(x):
return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
use_out = scale_a = scale_b = lc(1)
imms = [use_out, scale_a, scale_b]
if supports_transpose and a_transpose is not None:
imms += [lc(int(a_transpose)), lc(int(b_transpose))]
elif supports_transpose:
imms += [lc(int(b_transpose))]
if acc.ndim != 10 or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != 2:
raise ValueError(acc.shape)
acc_struct_type = ir.Type.parse(
f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>"
)
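# Each inline asm call returns the updated accumulator as an LLVM struct with
# one field per register; we unpack it after every K step and feed the
# registers back in as inputs, keeping the accumulation chain in registers.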
for i in range((swizzle // bytewidth(element_type)) // k_instr):
# Slice out the relevant part of A or advance the A descriptor.
if a_in_regs:
a_slice = a[:, (i * 16) : ((i + 1) * 16)]
a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
else:
if i > 0:
a = _llvm_add(
a,
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
)
a_args = [a]
# Advance the B descriptor.
if i > 0:
b_descriptor = _llvm_add(
b_descriptor,
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
)
assert len(a_args) == len(a_reg_constraints)
acc_struct = llvm.inline_asm(
acc_struct_type,
[*acc_regs, *a_args, b_descriptor, *imms],
ptx,
reg_constraints,
asm_dialect=0,
has_side_effects=True,
)
acc_regs = [
llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs))
]
return to_acc_vec_regs(acc_regs)
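# A rough usage sketch (variable names other than the ones defined in this
# module are hypothetical):
#   acc = WGMMAAccumulator.zero(m, n)
#   acc = wgmma(acc, a_smem_ref, b_smem_ref, swizzle=128)
#   # ... commit and wait on the WGMMA group before reading acc.value ...
# Note that wgmma() only issues the async instructions; waiting for their
# completion is the caller's responsibility.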
def wgmma(
acc: WGMMAAccumulator,
a: fa.FragmentedArray | ir.Value,
b: ir.Value,
*,
swizzle: int = 128,
):
"""Perform acc += a @ b using the WGMMA instruction.
The expected memref shapes are:
a: (m, k, 64, S)
b: (k, n, S, S)
where S = swizzle // bytewidth(element_type).
The refs must either be contiguous, or contiguous except for having their two
minor dimensions swapped.
"""
if swizzle == 16:
raise NotImplementedError("No swizzle is not supported")
# Step 1. Establish the shape and element type of the operation.
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")
(k, n), element_type = mma_utils.tiled_memref_shape(b)
if a_in_regs := isinstance(a, fa.FragmentedArray):
m, k2 = a.shape
element_type2 = a.mlir_dtype
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(
f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
)
elif ir.MemRefType.isinstance(a.type):
(m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
else:
raise ValueError(f"Unsupported A type: {type(a)}")
if k != k2:
raise ValueError(
"WGMMA requires A and B to have the same contraction dimension (K),"
f" got: {k2} and {k}"
)
if element_type != element_type2:
raise ValueError(
"WGMMA requires A and B to have the same element type, got:"
f" {element_type2} and {element_type}"
)
if acc.value.shape != (m, n):
raise ValueError(
f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}"
)
f32 = ir.F32Type.get()
if element_type == f32 or element_type == ir.BF16Type.get():
if acc.value.mlir_dtype != f32:
raise ValueError(
f"WGMMA with element type {element_type} only supports accumulators"
f" of type f32, but got: {acc.value.mlir_dtype}"
)
elif element_type == ir.F16Type.get():
if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32:
raise ValueError(
"WGMMA with element type f16 only supports accumulators of type f32"
f" or f16, but got: {acc.value.mlir_dtype}"
)
# Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
# instructions must be issued in groups of the same width as the swizzle.
m_group_elems = 64 # Hopper has a fixed M instruction shape.
k_group_elems = swizzle // utils.bytewidth(element_type)
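# For example, swizzle=128 with a 16-bit element type gives
# k_group_elems = 128 // 2 = 64, i.e. each group of instructions consumes 64
# elements along K.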
if n > 256 or n % 8:
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
n_group_elems = n # We assume only one N group below.
if m % m_group_elems:
raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
if k % k_group_elems:
raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
m_groups = m // m_group_elems
k_groups = k // k_group_elems
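# WGMMA has no true f32 input path: f32 operands are consumed as tf32, which
# only keeps 10 bits of mantissa.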
# TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
wgmma_element_type = (
ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
)
# Step 3. Compute the operand descriptors.
if a_in_regs:
a_desc_base = a_m_group_stride = a_k_group_stride = None
a_instr_params = dict(a_transpose=None, a_k_stride=None)
else:
(
(a_desc_base, a_k_instr_stride),
(a_m_group_stride, a_k_group_stride),
a_fastest,
) = mma_utils.create_descriptor(
a,
swizzle=swizzle,
large_tile=(m_group_elems, k_group_elems),
group_size=(m_group_elems, k_group_elems),
logical_k_major=False,
)
a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K,
a_k_stride=a_k_instr_stride)
(
(b_desc_base, b_k_instr_stride),
(b_n_group_stride, b_k_group_stride),
b_fastest,
) = mma_utils.create_descriptor(
b,
swizzle=swizzle,
large_tile=(k_group_elems,) * 2, # It's not a typo that we use k for n.
group_size=(k_group_elems, n_group_elems),
logical_k_major=True,
)
del b_n_group_stride # We only support one N group.
# Step 4. Issue the instructions.
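# The descriptor bases encode the SMEM address and swizzle of each operand;
# stepping through the M and K groups amounts to adding the corresponding
# encoded group offsets to these bases.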
if a_in_regs:
a = wgmma_fence(a) # Make sure the registers are ready.
i64 = ir.IntegerType.get_signless(64)
new_acc_regs = acc.value.registers.copy()
for mi in range(m_groups):
for ki in range(k_groups):
if a_in_regs:
a_mk = a[
mi * m_group_elems : (mi + 1) * m_group_elems,
ki * k_group_elems : (ki + 1) * k_group_elems,
]
else:
a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride
a_mk = _llvm_add(
a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64),
)
b_k = _llvm_add(
b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64)
)
new_acc_regs[mi : mi + 1] = wgmma_m64(
new_acc_regs[mi : mi + 1],
a_mk,
b_k,
swizzle=swizzle,
n=n_group_elems,
element_type=wgmma_element_type,
b_transpose=b_fastest != mma_utils.Dim.K,
b_k_stride=b_k_instr_stride,
**a_instr_params,
)
return WGMMAAccumulator(
_value=fa.FragmentedArray(
_registers=new_acc_regs,
_layout=fa.WGMMA_LAYOUT,
_is_signed=acc.value.is_signed,
),
_sync=False,
)
def wgmma_fence(array: fa.FragmentedArray):
"""Fences the array construction from WGMMA instructions.
LLVM treats in-register computation as pure and can move it after the fence,
which is explicitly disallowed by the PTX programming model. For that reason,
we insert an LLVM optimization barrier before the fence.
"""
array = fa.optimization_barrier(array)
nvvm.wgmma_fence_aligned()
return array
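# Packs a flat sequence of scalar accumulator registers back into an ndarray
# of 2-element vector registers of the given shape.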
def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]):
vec_regs = []
for first, second in zip(flat_regs[::2], flat_regs[1::2]):
vec = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
vec = llvm.insertelement(vec, first, position=_lc(0))
vec = llvm.insertelement(vec, second, position=_lc(1))
vec_regs.append(vec)
return np.asarray(vec_regs, dtype=object).reshape(shape)
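# Reinterprets a 32-bit vector register (e.g. a 2 x f16 vector) as a single
# i32 value.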
def _as_i32_reg(v):
i32 = ir.IntegerType.get_signless(32)
return llvm.extractelement(
vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0)
)
def _lc(x):
i32 = ir.IntegerType.get_signless(32)
return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
def _llvm_add(x, y):
return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)
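# Inverse of _as_i32_reg: broadcasts the i32 into a 1-element vector and
# bitcasts it to the requested vector type (e.g. 2 x f16).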
def _unpack_i32(vec_ty, r):
i32 = ir.IntegerType.get_signless(32)
return vector.bitcast(
vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
)