[Mosaic GPU] Add layout inference for scf.ForOp and scf.YieldOp.

PiperOrigin-RevId: 730873769
This commit is contained in:
Benjamin Chetioui 2025-02-25 07:12:34 -08:00 committed by jax authors
parent 7acd60c867
commit 5024ef213f
2 changed files with 124 additions and 7 deletions

View File

@ -22,7 +22,7 @@ from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from . import fragmented_array as fa
@ -161,17 +161,17 @@ def _value_layout(value: ir.Value) -> ir.Attribute | None:
].index(value)
return layouts_lib.out_layouts(owner)[value_result_number]
# Function block case, useful when attempting to derive layouts for ops
# depending on function parameters.
if isinstance(owner, ir.Block) and isinstance(owner.owner, func.FuncOp):
func_op = owner.owner
# Block case, useful when attempting to derive layouts for ops
# depending on function parameters, or loop block arguments.
if isinstance(owner, ir.Block):
owner_op = owner.owner
block = cast(ir.Block, owner)
if not layouts_lib.has_in_layouts_set(func_op):
if not layouts_lib.has_in_layouts_set(owner_op):
return None
value_arg_number = [
r for r in block.arguments if ir.VectorType.isinstance(r.type)
].index(value)
return layouts_lib.in_layouts(func_op)[value_arg_number]
return layouts_lib.in_layouts(owner_op)[value_arg_number]
raise NotImplementedError(
f"{owner} is not a function block nor an operation.")
@ -303,6 +303,41 @@ def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts:
return [], [layout]
@partial(_add_layout_inference_rule, scf.YieldOp)
def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts:
  """Infers in-layouts for an `scf.YieldOp` from its vector operands.

  A yield op produces no results, so its out-layouts are always empty. Only
  vector-typed operands carry a layout; all other operands are skipped. If
  any vector operand's layout cannot be determined yet, inference is
  postponed by returning `None`.
  """
  in_layouts: list[ir.Attribute] = []
  for operand in op.results_:
    if not ir.VectorType.isinstance(operand.type):
      continue
    layout = _value_layout(operand)
    if layout is None:
      # Not all layouts could be inferred for vector ops. Return for now.
      return None
    in_layouts.append(layout)
  return in_layouts, []
@partial(_add_layout_inference_rule, scf.ForOp)
def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts:
  """Infers carry layouts for an `scf.ForOp` from its terminating yield op.

  Loop-carried values must keep a single layout across iterations, so the
  yield op's in-layouts (when already annotated) are used as both the
  in-layouts and out-layouts of the for op. Splat layouts are never
  propagated this way, since other layouts cannot be converted to a splat
  layout and the initial carries may use a different layout.
  """
  body_ops = op.body.operations
  terminator = body_ops[len(body_ops) - 1]
  assert isinstance(terminator, scf.YieldOp)

  if not layouts_lib.has_in_layouts_set(terminator):
    # TODO(bchetioui): we don't attempt to propagate from outside for the
    # moment. For the existing kernels, propagating from the YieldOp should
    # be enough.
    return None

  carry_layouts = list(layouts_lib.in_layouts(terminator))
  if any(
      layouts_lib.is_splat_fragmented_layout(layout)
      for layout in carry_layouts
  ):
    return None
  return carry_layouts, carry_layouts
@partial(_add_layout_inference_rule, vector.SplatOp)
def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
layout = layouts_lib.to_splat_fragmented_layout_attr(

View File

@ -14,6 +14,8 @@
# ==============================================================================
"""Layout inference tests for the Mosaic GPU MLIR dialect."""
# pylint: disable=g-complex-comprehension
from absl.testing import parameterized
import jax
from jax._src import config
@ -196,6 +198,86 @@ class LayoutInferenceTest(parameterized.TestCase):
self.assertIn("in_layouts", add.attributes)
self.assertIn("out_layouts", add.attributes)
@parameterized.parameters(
    (shape, layout)
    for shape in [(64, 32)]
    for layout in [
        mgpu.WGSplatFragLayout(shape),
        mgpu.WGStridedFragLayout(shape, vec_size=4),
        mgpu.WGMMAFragLayout(),
    ]
)
def test_infer_layout_from_yield_op_in_layouts_for_for_op(
    self, shape, layout
):
  """Checks that a layout set on the body's ops reaches the enclosing for op.

  Builds a loop whose body adds the two carries and yields the sum twice,
  annotates the add with `layout`, runs inference, and verifies the layouts
  assigned to the yield op and the for op's carries.
  """
  add_op = for_op = yield_op = None

  def body(lower_bound, upper_bound, step, a, b):
    # Capture the ops via nonlocals so assertions below can inspect them.
    nonlocal for_op
    for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b])
    [loop_a, loop_b] = list(for_op.inner_iter_args)
    with ir.InsertionPoint(for_op.body):
      nonlocal add_op, yield_op
      add_op = arith.AddFOp(loop_a, loop_b)
      # Both carries are yielded from the same add result.
      yield_op = scf.YieldOp([add_op.result, add_op.result])

  with ir.InsertionPoint(self.module.body):
    ab_type = ir.VectorType.get(shape, ir.BF16Type.get())
    i32 = ir.IntegerType.get_signless(32)
    func.FuncOp.from_py_func(i32, i32, i32, ab_type, ab_type)(body)

  # Seed inference by annotating the add op's output with the test layout.
  add_op.attributes["out_layouts"] = ir.ArrayAttr.get(
      [layouts.to_layout_attr(layout)]
  )

  mgpu.infer_layout(self.module)

  if isinstance(layout, mgpu.WGSplatFragLayout):
    # In the splat case, we should not propagate the splat layout from the
    # yield op. That is because we can not convert other layouts to a splat
    # layout---which could cause trouble if the initial carries have a
    # different layout. Instead, we should get the default annotation, i.e.
    # strided layouts.
    strided_layout = layouts.to_layout_attr(
        mgpu.WGStridedFragLayout.from_shaped_type(ab_type)
    )
    carry_layouts = [strided_layout, strided_layout]
    self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
    self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts)
    self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts)
  else:
    # Non-splat layouts propagate unchanged to both carries.
    carry_layouts = [layouts.to_layout_attr(layout)] * 2
    self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
    self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts)
    self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts)
def test_infer_layout_from_body_op_to_yield_op_to_for_op(self):
  """Checks layout propagation body op -> yield op -> for op.

  The loop body runs a wgmma whose accumulator is a loop carry; inference
  should stamp the WGMMA fragmented layout on the yield op's vector operand
  and on the for op's carries.
  """
  for_op = yield_op = None
  shape = (64, 64)

  def body(lower_bound, upper_bound, step, a, b, c):
    # Capture the ops via nonlocals so assertions below can inspect them.
    nonlocal for_op
    for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b, c])
    with ir.InsertionPoint(for_op.body):
      nonlocal yield_op
      [loop_a, loop_b, loop_c] = list(for_op.inner_iter_args)
      # loop_c is the vector accumulator; loop_a/loop_b are memref operands.
      new_loop_c = mgpu.dialect.wgmma(loop_c, loop_a, loop_b)
      yield_op = scf.YieldOp([loop_a, loop_b, new_loop_c])

  with ir.InsertionPoint(self.module.body):
    c_ty = ir.VectorType.get(shape, ir.BF16Type.get())
    ab_ty = ir.MemRefType.get(shape, ir.BF16Type.get())
    i32 = ir.IntegerType.get_signless(32)
    func.FuncOp.from_py_func(i32, i32, i32, ab_ty, ab_ty, c_ty)(body)

  mgpu.infer_layout(self.module)

  # Only the single vector operand/carry gets a layout; the memref carries
  # contribute nothing, and the yield op itself has no results.
  wgmma_layout = layouts.to_layout_attr(mgpu.WGMMAFragLayout())
  self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout])
  self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
  self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout])
  self.assertSequenceEqual(for_op.attributes["out_layouts"], [wgmma_layout])
def test_infer_layout_has_no_layout_for_non_vector_types(self):
shape = (32, 4)
elt_ty = ir.BF16Type.get()