[Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
Parent: 1de2f839d5
Commit: 99a12ef9ea
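The heart of the new rule is the lower_carry / recreate_carry pair in the first hunk: a vector loop carry is flattened into its per-thread registers before the new scf.for is built, and reassembled into a FragmentedArray afterwards. Below is a minimal, self-contained sketch of that round trip, assuming nothing beyond numpy; FakeFragmentedArray and both helper names are illustrative stand-ins, not the Mosaic GPU API.

import numpy as np

class FakeFragmentedArray:
  """Stand-in for Mosaic GPU's FragmentedArray: a grid of register values."""

  def __init__(self, registers):
    self.registers = np.asarray(registers, dtype=object)

def lower_carry(carry):
  """Flattening direction: arrays dissolve into their individual registers."""
  flat = []
  for c in carry:
    if isinstance(c, FakeFragmentedArray):
      flat.extend(c.registers.flat)
    else:
      flat.append(c)
  return flat

def recreate_carry(template, flat):
  """Rebuilding direction: consume values from `flat`, shaped by `template`."""
  it = iter(flat)
  out = []
  for t in template:
    if isinstance(t, FakeFragmentedArray):
      regs = np.asarray([next(it) for _ in t.registers.flat], dtype=object)
      out.append(FakeFragmentedArray(regs.reshape(t.registers.shape)))
    else:
      out.append(next(it))
  return out

# Round trip: one scalar carry plus one 2x2-register array carry.
carry = [7, FakeFragmentedArray([["r00", "r01"], ["r10", "r11"]])]
flat = lower_carry(carry)  # [7, "r00", "r01", "r10", "r11"]
rebuilt = recreate_carry(carry, flat)
assert rebuilt[0] == 7 and rebuilt[1].registers.shape == (2, 2)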
@@ -603,10 +603,91 @@ def _for_op_lowering_rule(
  return []


@_register_lowering(scf.ForOp)
def _for_op_lowering_rule(
    ctx: LoweringContext, for_op: scf.ForOp
) -> MlirLoweringRuleResult:
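  # A loop that should not have a layout carries no vectors; fall back to
  # generic traversal.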
  if not layouts.should_have_layout(for_op):
    return _traverse_op_lowering_rule(ctx, for_op)
  in_layouts = layouts.in_layouts(for_op)
  out_layouts = layouts.out_layouts(for_op)
  yield_op = for_op.body.operations[len(for_op.body.operations) - 1]
  yield_layouts = layouts.in_layouts(yield_op)
  if in_layouts != out_layouts or in_layouts != yield_layouts:
    raise ValueError("Layout mismatch")
  fa_layouts = in_layouts
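
  # Pair each initial carry value with its type, materializing vector carries
  # as FragmentedArrays so their register structure is known up front.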
  fa_layouts_it = iter(fa_layouts)
  arg_template = [
      (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type)
      if ir.VectorType.isinstance(arg.type)
      else (arg, arg.type)
      for arg in for_op.initArgs
  ]
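
  # Flatten a carry sequence: vector values dissolve into their individual
  # registers; non-vector values pass through unchanged.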
  def lower_carry(carry):
    fa_layouts_it = iter(fa_layouts)
    carry_with_fas = [
        _fragmented_array_from_ir(arg, next(fa_layouts_it))
        if ir.VectorType.isinstance(arg.type)
        else arg
        for arg in carry
    ]
    lowered_carry = []
    for c in carry_with_fas:
      if isinstance(c, fa.FragmentedArray):
        lowered_carry.extend(c.registers.flat)
      else:
        lowered_carry.append(c)
    return lowered_carry
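
  # Inverse of lower_carry: consume the flattened values and rebuild each
  # FragmentedArray following the shapes recorded in arg_template.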
  def recreate_carry(lowered_carry):
    recreated_carry = []
    arg_it = iter(lowered_carry)
    for arg_value, arg_type in arg_template:
      if isinstance(arg_value, fa.FragmentedArray):
        carry_registers = np.asarray(
            [next(arg_it) for _ in arg_value.registers.flat], dtype=object
        )
        carry_registers = carry_registers.reshape(arg_value.registers.shape)
        carry = fa.FragmentedArray(
            _registers=carry_registers,
            _layout=arg_value.layout,
            _is_signed=arg_value.is_signed,
        )
        recreated_carry.append(_fragmented_array_to_ir(carry, arg_type))
      else:
        recreated_carry.append(next(arg_it))
    return recreated_carry
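
  # Rebuild the loop over the flattened carries and splice the old body in.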
  new_for_op = scf.ForOp(
      for_op.lowerBound,
      for_op.upperBound,
      for_op.step,
      lower_carry(for_op.initArgs),
  )
  with ir.InsertionPoint(new_for_op.body):
    recreated_carry = recreate_carry(new_for_op.body.arguments[1:])
    ops_to_lower = []
    for op in for_op.body:
      if op == yield_op:
        continue
      mgpu.private_operation_remove_from_parent(op)
      mgpu.private_block_append_owned_operation(new_for_op.body, op)
      ops_to_lower.append(op)
    new_args = (new_for_op.induction_variable, *recreated_carry)
    for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True):
      old_carry.replace_all_uses_with(new_carry)
    for op in ops_to_lower:
      ctx.lower_op(op)
    new_yield_operands = lower_carry(yield_op.operands)
    yield_op.erase()
    scf.yield_(new_yield_operands)
  return recreate_carry(new_for_op.results)
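Note the ordering in the rule above: the body ops are moved into the new loop first, the old block arguments are then rewired to the recreated carries, and only after that are the moved ops lowered, so nested rules already see the new values. The old terminator is erased and replaced with a yield of the flattened registers.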

@_register_lowering(func.FuncOp)
@_register_lowering(gpu.LaunchOp)
@_register_lowering(scf.IfOp)  # TODO(apaszke,bchetioui): Add a proper rule.
@_register_lowering(scf.ForOp)  # TODO(apaszke,bchetioui): Add a proper rule.
@_register_lowering(scf.IndexSwitchOp)  # TODO(apaszke,bchetioui): Add a proper rule.
def _traverse_op_lowering_rule(
    ctx: LoweringContext, op: ir.OpView
@@ -661,6 +742,7 @@ def _should_lower(op: ir.OpView) -> bool:
def lower_mgpu_dialect(
    module: ir.Module, launch_context: launch_context.LaunchContext | None
):
  # TODO(apaszke,bchetioui): Make sure the layouts match.
  # TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full
  # module and to traverse all `gpu.LaunchOp`s if we have a `LaunchContext` that
  # references a single `gpu.LaunchOp`.
@@ -317,6 +317,8 @@ def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts:
    if not ir.VectorType.isinstance(result.type):
      continue
    if (layout := _value_layout(result)) is not None:
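      # Splat layouts are punted on: returning None abandons layout inference
      # for the whole yield rather than committing to a splat carry.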
      if layouts_lib.is_splat_fragmented_layout(layout):
        return None
      layouts.append(layout)
    else:
      # Not all layouts could be inferred for vector ops. Return for now.
@@ -35,6 +35,8 @@ NB_MODULE(_mosaic_gpu_ext, m) {
        }
      },
      nb::arg("context"), nb::arg("load") = true);
  m.def("private_operation_remove_from_parent", mlirOperationRemoveFromParent);
  m.def("private_block_append_owned_operation", mlirBlockAppendOwnedOperation);

  mlir::python::nanobind_adaptors::mlir_attribute_subclass(
      m, "TileTransformAttr", mlirMosaicGpuIsATileTransformAttr)
@@ -761,6 +761,47 @@ class DialectLoweringTest(MosaicGpuTest):
    check_type(store1.valueToStore.type)
    check_type(store2.valueToStore.type)

  def test_lowering_for(self):
    shape = (4, 128)
    i32 = ir.IntegerType.get_signless(32)
    vec_ty = ir.VectorType.get(shape, i32)
    splat_layout_attr = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
    strided_layout_attr = layouts.to_layout_attr(
        mgpu.WGStridedFragLayout.from_shaped_type(vec_ty)
    )
    with ir.InsertionPoint(self.module.body):
      i1 = arith.constant(ir.IndexType.get(), 1)
      c1 = arith.constant(i32, 1)
      splat = vector.SplatOp(
          ir.VectorType.get(shape, i32), arith.constant(i32, 1234),
      )
      splat.attributes["out_layouts"] = ir.ArrayAttr.get([splat_layout_attr])
      ptr = llvm.mlir_undef(ir.Type.parse("!llvm.ptr"))
      ref = mgpu_utils.ptr_as_memref(ptr, ir.MemRefType.get(shape, i32))
      i0 = arith.constant(ir.IndexType.get(), 0)
      other_vec = vector.LoadOp(vec_ty, ref, [i0, i0])
      other_vec.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
      for_op = scf.ForOp(i1, i1, i1, [c1, splat.result])
      for_op.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
      for_op.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout_attr])
      with ir.InsertionPoint(for_op.body):
        i, int_carry, vec_carry = for_op.body.arguments
        new_int_carry = arith.addi(int_carry, arith.index_castui(i32, i))
        new_vec_carry = arith.AddIOp(vec_carry, other_vec)
        new_vec_carry.attributes["in_layouts"] = ir.ArrayAttr.get(
            [strided_layout_attr] * 2
        )
        new_vec_carry.attributes["out_layouts"] = ir.ArrayAttr.get(
            [strided_layout_attr]
        )
        yield_op = scf.YieldOp([new_int_carry, new_vec_carry])
        yield_op.attributes["in_layouts"] = ir.ArrayAttr.get([strided_layout_attr])

    mgpu.lower_mgpu_dialect(self.module, None)
    self.module.operation.verify()
    [for_op] = find_if(self.module, lambda op: isinstance(op, scf.ForOp))
    result_types = [r.type for r in for_op.results]
    reg_vec_ty = ir.VectorType.get((2,), i32)
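    # The (4, 128) i32 carry holds 512 elements spread across the 128 threads
    # of a warpgroup, i.e. 4 per thread, held as two vector<2xi32> registers;
    # the scalar carry stays a plain i32.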
    self.assertSequenceEqual(result_types, [i32, reg_vec_ty, reg_vec_ty])


if __name__ == "__main__":
  parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
@@ -902,9 +902,8 @@ class PallasCallTest(PallasTest):
      force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
  )
  def test_fori_loop_array(self, force_while, thread_semantics):
    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
      # TODO(apaszke,bchetioui,slebedev): Support while + array carries.
      self.skipTest("WG semantics unsupported")
    if force_while and thread_semantics == plgpu.ThreadSemantics.Warpgroup:
      self.skipTest("WG semantics does not support force_while.")

    @functools.partial(
        pl.pallas_call,