Wrap wgmma.fence in llvm.inline_asm to constrain LLVM scheduling

wgmma.fence.aligned is an unusual PTX instruction in that it is one of the
few (if not the only one?) that disallows sinking ALU ops on registers
below it. LLVM, however, assumes that all operations on registers are pure
and will often happily sink them below this instruction. By wrapping
the fence in an inline assembly block that simply copies over the
registers, we can force LLVM to construct the registers before the fence,
and ptxas should be able to eliminate the unnecessary register copies.
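
For intuition, the core of the trick looks roughly like this (an
illustrative sketch for 32-bit registers only; fence_regs is a made-up
name, the helper actually added below is wgmma_fence):

  def fence_regs(regs):
    i32 = ir.IntegerType.get_signless(32)
    n = len(regs)
    # One no-op mov per register, followed by the fence itself.
    ptx_lines = [f"mov.b32 ${i}, ${n + i}" for i in range(n)]
    ptx_lines.append("wgmma.fence.sync.aligned")
    ptx = ";\n".join(ptx_lines) + ";\n"
    # The first n constraints are the outputs, the next n the matching inputs.
    constraints = ",".join(["=r"] * n + ["r"] * n)
    struct_ty = ir.Type.parse(f"!llvm.struct<({','.join('i32' for _ in regs)})>")
    res = llvm.inline_asm(
        struct_ty, regs, ptx, constraints,
        asm_dialect=0, has_side_effects=True,
    )
    # Downstream code must use these results instead of the original regs.
    return [llvm.extractvalue(i32, res, [i]) for i in range(n)]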

PiperOrigin-RevId: 639011288
Adam Paszke 2024-05-31 06:07:44 -07:00 committed by jax authors
parent 33c7c8d30e
commit 41685db0cb


@@ -50,7 +50,7 @@ class WGMMAAccumulator:
       raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
     self.value = _value
     if _sync:
-      nvvm.wgmma_fence_aligned()
+      self.value = wgmma_fence(_value)
 
   @classmethod
   def zero(cls, m, n):
@@ -226,11 +226,6 @@ def wgmma_m64k128B(
   def lc(x):
     return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
 
-  def as_i32_reg(v):
-    return llvm.extractelement(
-        vector.bitcast(ir.VectorType.get((1,), i32), v), lc(0)
-    )
-
   use_out = scale_a = scale_b = lc(1)
   imms = [use_out, scale_a, scale_b]
   if supports_transpose and a_transpose is not None:
@@ -251,7 +246,7 @@ def wgmma_m64k128B(
     # Slice out the relevant part of A or advance the A descriptor.
     if a_in_regs:
       a_slice = a[:, (i * 16) : ((i + 1) * 16)]
-      a_args = [as_i32_reg(v) for v in a_slice.registers.flat]
+      a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
     else:
       if i > 0:
         a = llvm_add(
@@ -277,13 +272,7 @@ def wgmma_m64k128B(
   acc_regs = [
       llvm.extractvalue(f32, acc_struct, [i]) for i in range(len(acc_regs))
   ]
-  acc_vec_regs = []
-  for first, second in zip(acc_regs[::2], acc_regs[1::2]):
-    vec = llvm.mlir_undef(ir.VectorType.get((2,), f32))
-    vec = llvm.insertelement(vec, first, position=lc(0))
-    vec = llvm.insertelement(vec, second, position=lc(1))
-    acc_vec_regs.append(vec)
-  return np.asarray(acc_vec_regs, dtype=object).reshape(acc.shape)
+  return _as_fragmented_reg_ndarray(acc_regs, f32, acc.shape)
 
 
 class WGMMALayout(enum.Enum):
@@ -373,7 +362,7 @@ def wgmma(
     wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
   if a_in_regs:
-    nvvm.wgmma_fence_aligned()  # Make sure the registers are ready.
+    a = wgmma_fence(a)  # Make sure the registers are ready.
     a_m_byte_stride = a_k_byte_stride = a_desc_base = None  # Silence pytype.
   else:
     a_desc_base = create_descriptor(a, **a_desc_fields)
@@ -410,3 +399,87 @@ def wgmma(
       ),
       _sync=False,
   )
+
+
+def wgmma_fence(array: mgpu.FragmentedArray):
+  """Fences the array construction from WGMMA instructions.
+
+  This is a little workaround to force LLVM to initialize the PTX registers
+  before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats
+  in-register computation as pure and can move it after the fence, which is
+  explicitly disallowed by the PTX programming model.
+  """
+  i32 = ir.IntegerType.get_signless(32)
+  index = ir.IndexType.get()
+  dtype = array.mlir_dtype
+  src_vec_ty = ir.VectorType(array.registers.flat[0].type)
+  assert src_vec_ty.shape == [2]
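+  # Each fragment register is a <2 x dtype> vector. For f32 it occupies two
+  # .f32 PTX registers, while for f16/bf16 the whole vector fits in a single
+  # .b32 register, so the two branches below unpack it differently.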
+  if dtype == ir.F32Type.get():
+    regs = [  # pylint: disable=g-complex-comprehension
+        vector.extractelement(reg, position=c(pos, index))
+        for reg in array.registers.flat
+        for pos in range(2)
+    ]
+    reg_dtype = dtype
+    reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs)
+    ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
+  elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get():
+    regs = [_as_i32_reg(reg) for reg in array.registers.flat]
+    reg_dtype = i32
+    reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs)
+    ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
+  else:
+    raise NotImplementedError(dtype)
+  reg_constraints = ",".join(reg_constraints_list)
+  # Copy over the registers. ptxas should be able to remove the moves.
+  ptx_lines.append("wgmma.fence.sync.aligned")
+  ptx = ";\n".join(ptx_lines) + ";\n"
+  dtype_str = str(reg_dtype)
+  struct_ty = ir.Type.parse(
+      f"!llvm.struct<({','.join(dtype_str for _ in regs)})>"
+  )
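+  # The copies come back packed in a struct. Since the asm's outputs (and
+  # everything derived from them) depend on its inputs, LLVM must materialize
+  # all the registers before the fence.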
+  acc_struct = llvm.inline_asm(
+      struct_ty, regs, ptx, reg_constraints,
+      asm_dialect=0, has_side_effects=True,
+  )
+  regs = [
+      llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs))
+  ]
+  if dtype == ir.F32Type.get():
+    registers = _as_fragmented_reg_ndarray(
+        regs, array.mlir_dtype, array.registers.shape
+    )
+  elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get():
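+    # Bitcast each 32-bit copy back into a <2 x dtype> vector register.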
+    regs = [
+        vector.bitcast(
+            src_vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
+        )
+        for r in regs
+    ]
+    registers = np.asarray(regs, dtype=object).reshape(array.registers.shape)
+  else:
+    raise NotImplementedError(dtype)
+  return mgpu.FragmentedArray(_registers=registers, _layout=array.layout)
+
+
+def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]):
+  vec_regs = []
+  for first, second in zip(flat_regs[::2], flat_regs[1::2]):
+    vec = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
+    vec = llvm.insertelement(vec, first, position=_lc(0))
+    vec = llvm.insertelement(vec, second, position=_lc(1))
+    vec_regs.append(vec)
+  return np.asarray(vec_regs, dtype=object).reshape(shape)
+
+
+def _as_i32_reg(v):
+  i32 = ir.IntegerType.get_signless(32)
+  return llvm.extractelement(
+      vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0)
+  )
+
+
+def _lc(x):
+  i32 = ir.IntegerType.get_signless(32)
+  return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result