mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Wrap wgmma.fence in llvm.inline_asm to constrain LLVM scheduling
wgmma.fence.sync.aligned is an unusual PTX instruction in that it is one of the few (if not the only one) that disallow sinking ALU ops on registers below it. LLVM, however, assumes that all operations on registers are pure and will often happily sink them below this instruction. By wrapping the fence in an inline assembly block that simply copies over the registers, we can force LLVM to construct the registers before the fence, and ptxas should be able to eliminate the unnecessary register copies. PiperOrigin-RevId: 639011288
This commit is contained in:
parent
33c7c8d30e
commit
41685db0cb
@ -50,7 +50,7 @@ class WGMMAAccumulator:
|
||||
raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
|
||||
self.value = _value
|
||||
if _sync:
|
||||
nvvm.wgmma_fence_aligned()
|
||||
self._value = wgmma_fence(_value)
|
||||
|
||||
@classmethod
|
||||
def zero(cls, m, n):
|
||||
@ -226,11 +226,6 @@ def wgmma_m64k128B(
|
||||
def lc(x):
|
||||
return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
|
||||
|
||||
def as_i32_reg(v):
|
||||
return llvm.extractelement(
|
||||
vector.bitcast(ir.VectorType.get((1,), i32), v), lc(0)
|
||||
)
|
||||
|
||||
use_out = scale_a = scale_b = lc(1)
|
||||
imms = [use_out, scale_a, scale_b]
|
||||
if supports_transpose and a_transpose is not None:
|
||||
@ -251,7 +246,7 @@ def wgmma_m64k128B(
|
||||
# Slice out the relevant part of A or advance the A descriptor.
|
||||
if a_in_regs:
|
||||
a_slice = a[:, (i * 16) : ((i + 1) * 16)]
|
||||
a_args = [as_i32_reg(v) for v in a_slice.registers.flat]
|
||||
a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
|
||||
else:
|
||||
if i > 0:
|
||||
a = llvm_add(
|
||||
@ -277,13 +272,7 @@ def wgmma_m64k128B(
|
||||
acc_regs = [
|
||||
llvm.extractvalue(f32, acc_struct, [i]) for i in range(len(acc_regs))
|
||||
]
|
||||
acc_vec_regs = []
|
||||
for first, second in zip(acc_regs[::2], acc_regs[1::2]):
|
||||
vec = llvm.mlir_undef(ir.VectorType.get((2,), f32))
|
||||
vec = llvm.insertelement(vec, first, position=lc(0))
|
||||
vec = llvm.insertelement(vec, second, position=lc(1))
|
||||
acc_vec_regs.append(vec)
|
||||
return np.asarray(acc_vec_regs, dtype=object).reshape(acc.shape)
|
||||
return _as_fragmented_reg_ndarray(acc_regs, f32, acc.shape)
|
||||
|
||||
|
||||
class WGMMALayout(enum.Enum):
|
||||
@ -373,7 +362,7 @@ def wgmma(
|
||||
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
|
||||
|
||||
if a_in_regs:
|
||||
nvvm.wgmma_fence_aligned() # Make sure the registers are ready.
|
||||
a = wgmma_fence(a) # Make sure the registers are ready.
|
||||
a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype.
|
||||
else:
|
||||
a_desc_base = create_descriptor(a, **a_desc_fields)
|
||||
@ -410,3 +399,87 @@ def wgmma(
|
||||
),
|
||||
_sync=False,
|
||||
)
|
||||
|
||||
|
||||
def wgmma_fence(array: mgpu.FragmentedArray):
  """Routes the registers of `array` through a fenced inline assembly block.

  LLVM regards computation that only touches registers as pure, so it is free
  to sink it below a wgmma.fence.sync.aligned instruction, which the PTX
  programming model explicitly disallows. As a workaround, all registers are
  passed as inputs to one inline assembly block that copies each of them into
  an output and then issues the fence. The returned array is rebuilt from
  those outputs.

  Args:
    array: The array to fence. Its element type must be f32, f16 or bf16, and
      its register vectors are expected to hold 2 elements (asserted on the
      first register only).

  Returns:
    A new FragmentedArray with the same layout as `array`, whose registers
    are built from the outputs of the inline assembly block.

  Raises:
    NotImplementedError: If the element type is not f32, f16 or bf16.
  """
  i32 = ir.IntegerType.get_signless(32)
  index = ir.IndexType.get()
  dtype = array.mlir_dtype
  src_vec_ty = ir.VectorType(array.registers.flat[0].type)
  assert src_vec_ty.shape == [2]

  is_f32 = dtype == ir.F32Type.get()
  if is_f32:
    # Each f32 element of every register vector gets its own asm operand.
    inputs = []
    for reg in array.registers.flat:
      for pos in range(2):
        inputs.append(vector.extractelement(reg, position=c(pos, index)))
    count = len(inputs)
    reg_dtype = dtype
    constraint = "f"
    ptx_lines = [f"mov.f32 ${i}, ${count+i}" for i in range(count)]
  elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get():
    # A pair of 16-bit elements is reinterpreted as one i32 operand.
    inputs = [_as_i32_reg(reg) for reg in array.registers.flat]
    count = len(inputs)
    reg_dtype = i32
    constraint = "r"
    ptx_lines = [f"mov.b32 ${i}, ${count+i}" for i in range(count)]
  else:
    raise NotImplementedError(dtype)

  # Operands are numbered outputs first, then inputs, which is why the i-th
  # move reads operand number `count + i`.
  reg_constraints = ",".join(["=" + constraint] * count + [constraint] * count)
  # Copy over the registers. ptxas should be able to remove the moves.
  ptx_lines.append("wgmma.fence.sync.aligned")
  ptx = ";\n".join(ptx_lines) + ";\n"

  field_types = ",".join([str(reg_dtype)] * count)
  struct_ty = ir.Type.parse(f"!llvm.struct<({field_types})>")
  fenced_struct = llvm.inline_asm(
      struct_ty,
      inputs,
      ptx,
      reg_constraints,
      asm_dialect=0,
      has_side_effects=True,
  )
  outputs = [
      llvm.extractvalue(reg_dtype, fenced_struct, [i]) for i in range(count)
  ]

  if is_f32:
    registers = _as_fragmented_reg_ndarray(
        outputs, array.mlir_dtype, array.registers.shape
    )
  else:
    # Only f16/bf16 can reach this point; any other dtype raised above.
    # Turn every i32 back into a vector of the original register type.
    one_i32_vec_ty = ir.VectorType.get((1,), i32)
    restored = []
    for out in outputs:
      restored.append(
          vector.bitcast(src_vec_ty, vector.splat(one_i32_vec_ty, out))
      )
    registers = np.asarray(restored, dtype=object).reshape(
        array.registers.shape
    )
  return mgpu.FragmentedArray(_registers=registers, _layout=array.layout)
|
||||
|
||||
|
||||
def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]):
  """Packs consecutive pairs of values into 2-element vectors.

  Args:
    flat_regs: Flat sequence of values. Elements 2*i and 2*i + 1 become
      elements 0 and 1 of the i-th vector.
    dtype: Element type of the vectors being built.
    shape: Shape the resulting array of vectors is reshaped to.

  Returns:
    An object-dtype ndarray of shape `shape` holding the packed vectors.
  """

  def pack_pair(low, high):
    # Start from an undefined vector and fill in both lanes.
    pair = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
    pair = llvm.insertelement(pair, low, position=_lc(0))
    pair = llvm.insertelement(pair, high, position=_lc(1))
    return pair

  lows = flat_regs[::2]
  highs = flat_regs[1::2]
  packed = [pack_pair(low, high) for low, high in zip(lows, highs)]
  return np.asarray(packed, dtype=object).reshape(shape)
|
||||
|
||||
|
||||
def _as_i32_reg(v):
  """Reinterprets the vector `v` as a single i32 value.

  `v` is bitcast to a 1-element i32 vector, and element 0 of that vector is
  returned.
  """
  i32 = ir.IntegerType.get_signless(32)
  one_i32_vec_ty = ir.VectorType.get((1,), i32)
  as_i32_vec = vector.bitcast(one_i32_vec_ty, v)
  first_lane = _lc(0)
  return llvm.extractelement(as_i32_vec, first_lane)
|
||||
|
||||
|
||||
def _lc(x):
  """Returns the result of an LLVM constant op holding `x` as a signless i32."""
  i32 = ir.IntegerType.get_signless(32)
  value_attr = ir.IntegerAttr.get(i32, x)
  constant = llvm.ConstantOp(i32, value_attr)
  return constant.result
|
||||
|
Loading…
x
Reference in New Issue
Block a user