# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import dataclasses
import functools
import itertools

import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
import numpy as np

from . import fragmented_array as fa
from . import mma_utils
from . import utils

# mypy: ignore-errors

c = utils.c
bytewidth = utils.bytewidth


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class WGMMAAccumulator:
  """A FragmentedArray that is synchronized with the async proxy.

  This implies that it requires no additional synchronization when passed in
  as a WGMMA accumulator. In particular, when created from a
  FragmentedArray, the necessary synchronization is inserted at construction.
  """
  value: fa.FragmentedArray

  def __init__(self, *, _value: fa.FragmentedArray, _sync: bool = True):
    if _value.layout not in (fa.WGMMA_LAYOUT, fa.TILED_LAYOUT_WGMMA):
      raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
    self.value = _value
    if _sync:
      self.value = wgmma_fence(_value)

  @classmethod
  def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None):
    if m % 64 or n % 8:
      raise ValueError(f"m must be a multiple of 64 and n a multiple of 8, got: {m=}, {n=}")
    if is_signed is False:
      raise TypeError("PTX does not support unsigned WGMMA accumulators")
    f32 = ir.F32Type.get()
    if dtype is None:
      dtype = f32
    zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
    return cls(
        _value=fa.FragmentedArray.splat(
            zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed
        )
    )

  @classmethod
  def from_registers(cls, registers):
    return cls(_value=registers)

  def tree_flatten(self):
    return (self.value,), ()

  @classmethod
  def tree_unflatten(cls, aux, value):
    del aux
    return cls(_value=value[0], _sync=False)

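
# A minimal usage sketch (hypothetical operand refs; assumes the caller issues
# the NVVM commit/wait ops to await the async result):
#
#   acc = WGMMAAccumulator.zero(m=64, n=128)
#   acc = wgmma(acc, a_smem_ref, b_smem_ref, swizzle=128)
#   nvvm.wgmma_commit_group_sync_aligned()
#   nvvm.wgmma_wait_group_sync_aligned(0)
#   result = acc.value  # A FragmentedArray in WGMMA layout.
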

def _supported_wgmma_types(dtype, abtype) -> bool:
  input_types_are = lambda ty: ty.isinstance(abtype)
  if ir.F32Type.isinstance(dtype):
    return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type))
  elif ir.F16Type.isinstance(dtype):
    return input_types_are(ir.F16Type)
  else:
    return False


def wgmma_m64(
    acc: np.ndarray,  # of register Values
    a,
    b_descriptor: ir.Value,
    a_transpose: bool | None,
    b_transpose: bool,
    a_k_stride: int | None,
    b_k_stride: int,
    n: int,
    swizzle: int,
    element_type: ir.Type,
):
  out_ty = ir.VectorType(acc.flat[0].type).element_type
  if not _supported_wgmma_types(out_ty, element_type):
    raise ValueError(f"Unsupported WGMMA types {(out_ty, element_type)=}")
  if n % 8:
    raise ValueError(f"n must be a multiple of 8, got: {n}")

  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  index = ir.IndexType.get()
  if b_k_stride % 16:
    raise ValueError(f"b_k_stride must be a multiple of 16, got: {b_k_stride}")
  # Only 16-bit types support transposes.
  supports_transpose = bytewidth(element_type) == 2
  if not supports_transpose and (a_transpose or b_transpose):
    raise ValueError("Only f16 WGMMA supports transposes")
  if a_in_regs := isinstance(a, fa.FragmentedArray):
    if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
      raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
    # Column count must be equal to swizzle // bytewidth.
    if a.layout not in (fa.TILED_LAYOUT_WGMMA, fa.WGMMA_LAYOUT) or a.shape != (64, swizzle // 2):
      raise ValueError("Unsupported A register array layout")
    if a_k_stride is not None or a_transpose is not None:
      raise ValueError("Unsupported WGMMA features with A in registers")
  else:
    if a_k_stride is None or a_k_stride % 16:
      raise ValueError(f"a_k_stride must be a multiple of 16, got: {a_k_stride}")
    if a_transpose is None:
      raise ValueError("a_transpose must be specified when A is passed as a descriptor")

  if ir.F32Type.isinstance(out_ty):
    num_acc_regs = n // 2
    out_ty_field = out_ty
    acc_regs = [  # pylint: disable=g-complex-comprehension
        vector.extractelement(reg, position=c(pos, index))
        for reg in acc.flat
        for pos in range(2)
    ]
    to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape)
    acc_constraint = "f"
  elif ir.F16Type.isinstance(out_ty):
    num_acc_regs = n // 4
    out_ty_field = i32
    acc_regs = [_as_i32_reg(reg) for reg in acc.flat]
    vec_ty = ir.VectorType(acc.flat[0].type)
    to_acc_vec_regs = lambda regs: np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape)
    acc_constraint = "r"
  else:
    raise ValueError(f"WGMMA instruction only supports f32 and f16 out (got {out_ty})")

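  # In other words: f32 accumulators unpack every vector<2xf32> register into
  # two scalar "f" operands (n // 2 in total), while f16 accumulators keep each
  # vector<2xf16> register packed as a single i32 "r" operand (n // 4 in total).
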
  num_imm_regs = 4 if supports_transpose else 2

  if a_in_regs:
    a_reg_constraints = ["r"] * 4  # 4x f16x2 registers
    num_imm_regs -= 1  # transpose not supported for A in registers
  else:
    a_reg_constraints = ["l"]  # descriptor
  # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
  # It does not seem to be documented in the LLVM IR docs.
  reg_constraints_list = (
      [f"={acc_constraint}"] * num_acc_regs  # accumulator registers
      + [str(i) for i in range(num_acc_regs)]  # we alias outputs as inputs, too.
      + a_reg_constraints  # a descriptor / registers
      + ["l"] * 1  # b descriptor
      + ["n"] * (1 + num_imm_regs)  # literal constants
  )
  reg_constraints = ",".join(reg_constraints_list)

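  # For illustration: with n=8, f32 accumulators, and A passed as a descriptor,
  # this comes out as "=f,=f,=f,=f,0,1,2,3,l,l,n,n,n,n,n": four f32 outputs,
  # four aliased inputs, two 64-bit descriptors, and five immediates
  # (use_out, scale_a, scale_b, and the two transpose flags).
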
  reg_count = itertools.count()

  def take_regs(n):
    return (f"${i}" for i in itertools.islice(reg_count, n))

  acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
  for _ in take_regs(num_acc_regs):  # Ignore next entries: aliasing.
    pass
  if a_in_regs:
    a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
  else:
    a_regs, = take_regs(1)
  b_desc_reg, use_out_reg = take_regs(2)
  imm_regs = ", ".join(take_regs(num_imm_regs))  # Immediate regs (scale, ...).
  assert next(reg_count) == len(reg_constraints_list)
  el_ty = element_type
  k_instr = 32 // bytewidth(element_type)
  wgmma_instr = (
      f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} "
      f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};"
  )
  ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"

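  # Illustrative result for f16 inputs (k_instr = 16), n=64, f32 accumulators:
  #   { .reg .pred p; setp.ne.b32 p, $66, 0;
  #     wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16
  #       {$0,...,$31}, $64, $65, p, $67, $68, $69, $70; }
  # The predicate wrapper lets use_out select acc += a@b versus acc = a@b.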

  def lc(x):
    return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result

  use_out = scale_a = scale_b = lc(1)
  imms = [use_out, scale_a, scale_b]
  if supports_transpose and a_transpose is not None:
    imms += [lc(int(a_transpose)), lc(int(b_transpose))]
  elif supports_transpose:
    imms += [lc(int(b_transpose))]
  if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1):
    raise ValueError(f"Unexpected accumulator register shape: {acc.shape}")
  acc_struct_type = ir.Type.parse(
      f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>"
  )
  for i in range((swizzle // bytewidth(element_type)) // k_instr):
    # Slice out the relevant part of A or advance the A descriptor.
    if a_in_regs:
      a_slice = a[:, (i * 16) : ((i + 1) * 16)]
      a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
    else:
      if i > 0:
        a = _llvm_add(
            a,
            llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
        )
      a_args = [a]
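    # The >> 4 matches the descriptor encoding: PTX shared memory descriptors
    # store addresses and strides in units of 16 bytes.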
    # Advance the B descriptor.
    if i > 0:
      b_descriptor = _llvm_add(
          b_descriptor,
          llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
      )
    assert len(a_args) == len(a_reg_constraints)
    acc_struct = llvm.inline_asm(
        acc_struct_type,
        [*acc_regs, *a_args, b_descriptor, *imms],
        ptx,
        reg_constraints,
        asm_dialect=0,
        has_side_effects=True,
    )
    acc_regs = [
        llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs))
    ]
  return to_acc_vec_regs(acc_regs)


def wgmma(
    acc: WGMMAAccumulator,
    a: fa.FragmentedArray | ir.Value,
    b: ir.Value,
    *,
    swizzle: int = 128,
):
  """Perform acc += a @ b using the WGMMA instruction.

  The expected memref shapes are:
    a: (m, k, 64, S)
    b: (k, n, S, S)
  where S = swizzle // bytewidth(element_type).

  The refs must be contiguous, or contiguous up to a swap of their two minor
  dimensions.
  """
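  # For example (illustrative shapes): f16 operands with swizzle=128 give
  # S = 128 // 2 = 64, so a (128, 256) @ (256, 192) matmul expects tiled refs
  # of shape a: (2, 4, 64, 64) and b: (4, 3, 64, 64).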
  # Step 1. Establish the shape and element type of the operation.
  if not ir.MemRefType.isinstance(b.type):
    raise ValueError(f"B must be a memref, got: {b.type}")
  (k, n), element_type = mma_utils.tiled_memref_shape(b)
  if a_in_regs := isinstance(a, fa.FragmentedArray):
    m, k2 = a.shape
    element_type2 = a.mlir_dtype
    if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
      raise ValueError(
          f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
      )
  elif ir.MemRefType.isinstance(a.type):
    (m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
  else:
    raise ValueError(f"Unsupported A type: {type(a)}")
  if k != k2:
    raise ValueError(
        "WGMMA requires A and B to have the same contraction dimension (K),"
        f" got: {k2} and {k}"
    )
  if element_type != element_type2:
    raise ValueError(
        "WGMMA requires A and B to have the same element type, got:"
        f" {element_type2} and {element_type}"
    )
  if acc.value.shape != (m, n):
    raise ValueError(
        f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}"
    )
  f32 = ir.F32Type.get()
  if element_type == f32 or element_type == ir.BF16Type.get():
    if acc.value.mlir_dtype != f32:
      raise ValueError(
          f"WGMMA with element type {element_type} only supports accumulators"
          f" of type f32, but got: {acc.value.mlir_dtype}"
      )
  elif element_type == ir.F16Type.get():
    if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32:
      raise ValueError(
          "WGMMA with element type f16 only supports accumulators of type f32"
          f" or f16, but got: {acc.value.mlir_dtype}"
      )

  # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
  # instructions must be issued in groups of the same width as the swizzle.
  m_group_elems = 64  # Hopper has a fixed M instruction shape.
  k_group_elems = swizzle // utils.bytewidth(element_type)
  if n > 256 or n % 8:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  n_group_elems = n  # We assume only one N group below.
  if m % m_group_elems:
    raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
  if k % k_group_elems:
    raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
  m_groups = m // m_group_elems
  k_groups = k // k_group_elems
  # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
  wgmma_element_type = (
      ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
  )

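  # For instance, f16 with swizzle=128 gives k_group_elems = 64, so a k=256
  # contraction is issued as k_groups = 4 swizzle-wide instruction groups.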
  # Step 3. Compute the operand descriptors.
  if a_in_regs:
    a_desc_base = a_m_group_stride = a_k_group_stride = None
    a_instr_params = dict(a_transpose=None, a_k_stride=None)
  else:
    (
        (a_desc_base, a_k_instr_stride),
        (a_m_group_stride, a_k_group_stride),
        a_fastest,
    ) = mma_utils.create_descriptor(
        a,
        swizzle=swizzle,
        large_tile=(m_group_elems, k_group_elems),
        group_size=(m_group_elems, k_group_elems),
        logical_k_major=False,
    )
    a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K,
                          a_k_stride=a_k_instr_stride)
  (
      (b_desc_base, b_k_instr_stride),
      (b_n_group_stride, b_k_group_stride),
      b_fastest,
  ) = mma_utils.create_descriptor(
      b,
      swizzle=swizzle,
      large_tile=(k_group_elems,) * 2,  # It's not a typo that we use k for n.
      group_size=(k_group_elems, n_group_elems),
      logical_k_major=True,
  )
  del b_n_group_stride  # We only support one N group.

  # Step 4. Issue the instructions.
  if a_in_regs:
    a = wgmma_fence(a)  # Make sure the registers are ready.

  i64 = ir.IntegerType.get_signless(64)
  new_acc_regs = acc.value.registers.copy()
  for mi in range(m_groups):
    for ki in range(k_groups):
      if a_in_regs:
        a_mk = a[
            mi * m_group_elems : (mi + 1) * m_group_elems,
            ki * k_group_elems : (ki + 1) * k_group_elems,
        ]
      else:
        a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride
        a_mk = _llvm_add(
            a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64),
        )
      b_k = _llvm_add(
          b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64)
      )
      new_acc_regs[mi : mi + 1] = wgmma_m64(
          new_acc_regs[mi : mi + 1],
          a_mk,
          b_k,
          swizzle=swizzle,
          n=n_group_elems,
          element_type=wgmma_element_type,
          b_transpose=b_fastest != mma_utils.Dim.K,
          b_k_stride=b_k_instr_stride,
          **a_instr_params,
      )
  return WGMMAAccumulator(
      _value=fa.FragmentedArray(
          _registers=new_acc_regs,
          _layout=fa.WGMMA_LAYOUT,
          _is_signed=acc.value.is_signed,
      ),
      _sync=False,
  )


def wgmma_fence(array: fa.FragmentedArray):
  """Fences the array construction from WGMMA instructions.

  LLVM treats in-register computation as pure and can move it after the fence,
  which is explicitly disallowed by the PTX programming model. For that reason,
  we insert an LLVM optimization barrier before the fence.
  """
  array = fa.optimization_barrier(array)
  nvvm.wgmma_fence_aligned()
  return array

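# Note: both WGMMAAccumulator.__init__ (with _sync=True) and wgmma() call this
# for register operands, so most callers never need to invoke it directly.
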
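# Packs flat scalar registers pairwise into vector<2xdtype> values, restoring
# the ndarray-of-registers layout used by FragmentedArray.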
def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]):
  vec_regs = []
  for first, second in zip(flat_regs[::2], flat_regs[1::2]):
    vec = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
    vec = llvm.insertelement(vec, first, position=_lc(0))
    vec = llvm.insertelement(vec, second, position=_lc(1))
    vec_regs.append(vec)
  return np.asarray(vec_regs, dtype=object).reshape(shape)


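# Reinterprets a 32-bit-wide vector register (e.g. vector<2xf16>) as a single
# packed i32 value.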
def _as_i32_reg(v):
  i32 = ir.IntegerType.get_signless(32)
  return llvm.extractelement(
      vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0)
  )


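# Builds an i32 LLVM constant; used for insert/extractelement positions.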
def _lc(x):
  i32 = ir.IntegerType.get_signless(32)
  return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result


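# Integer add without overflow flags; used to advance the 64-bit matrix
# descriptors between instruction issues.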
def _llvm_add(x, y):
  return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)


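# Inverse of _as_i32_reg: splats the i32 into a 1-element vector and bitcasts
# it back to the requested vector type (e.g. vector<2xf16>).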
def _unpack_i32(vec_ty, r):
  i32 = ir.IntegerType.get_signless(32)
  return vector.bitcast(
      vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
  )