[Mosaic GPU] Allow querying layouts from a FuncOp's block arguments if set.

The motivation behind this change is twofold:

1. it simplifies test writing (no need to produce arbitrary, manual, non-splat
   constants to produce arguments with a strided layout);
2. it'll allow running layout inference on different `FuncOp`s in isolation,
   before inlining.

While the primary motivation is to simplify test writing for upcoming changes,
`2.` is useful if we ever intend to call functions whose body's layout we have
inferred from other functions. It's not clear to me that we have a use case for
that, but the theoretical benefit is worth pointing out.

Crucially, layout inference does not set default layouts for `FuncOp`s, since
the caller may choose a different layout for its arguments. As a result, there
is also no layout inference rule for `func.FuncOp`.

PiperOrigin-RevId: 716158516
This commit is contained in:
Benjamin Chetioui 2025-01-16 03:04:55 -08:00 committed by jax authors
parent 4221f109d1
commit bc7204f003
2 changed files with 67 additions and 19 deletions

View File

@ -20,6 +20,7 @@ from typing import cast
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import vector
from .fragmented_array import WGSplatFragLayout, WGStridedFragLayout
@ -123,27 +124,38 @@ def _in_layout_for_operand(
return layouts_lib.in_layouts(op)[operand_number]
def _out_layout_for_result(
op: ir.OpView,
result: ir.Value,
) -> ir.Attribute | None:
"""Returns the layout for a specific result of the given operation if it is set.
def _value_layout(value: ir.Value) -> ir.Attribute | None:
"""Returns the layout for a given value as defined by its owner.
Raises:
ValueError: If `result` is not a result of `op`, or if `result` is not a
Vector.
ValueError: If `result` is not a Vector.
"""
if not ir.VectorType.isinstance(result.type):
raise ValueError(f"{result} is not a vector.")
if not ir.VectorType.isinstance(value.type):
raise ValueError(f"{value} is not a vector.")
result_number = [
r for r in op.results if ir.VectorType.isinstance(r.type)
].index(result)
owner = value.owner
if isinstance(owner, ir.Operation):
if not layouts_lib.has_out_layouts_set(owner):
return None
value_result_number = [
r for r in owner.results if ir.VectorType.isinstance(r.type)
].index(value)
return layouts_lib.out_layouts(owner)[value_result_number]
if not layouts_lib.has_out_layouts_set(op):
return None
# Function block case, useful when attempting to derive layouts for ops
# depending on function parameters.
if isinstance(owner, ir.Block) and isinstance(owner.owner, func.FuncOp):
func_op = owner.owner
block = cast(ir.Block, owner)
if not layouts_lib.has_in_layouts_set(func_op):
return None
value_arg_number = [
r for r in block.arguments if ir.VectorType.isinstance(r.type)
].index(value)
return layouts_lib.in_layouts(func_op)[value_arg_number]
return layouts_lib.out_layouts(op)[result_number]
raise NotImplementedError(
f"{owner} is not a function block nor an operation.")
def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
@ -175,11 +187,9 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
# far down as possible, until since we may be able to propagate splat layouts
# further down before requiring a relayout in that way.
for operand in op.operands:
if not isinstance(
operand.owner, ir.Operation
) or not ir.VectorType.isinstance(operand.type):
if not ir.VectorType.isinstance(operand.type):
continue
if (layout := _out_layout_for_result(operand.owner, operand)) is not None:
if (layout := _value_layout(operand)) is not None:
layouts.add(layout)
# We only look at consumers if we haven't found a possible layout yet. This is

View File

@ -290,6 +290,44 @@ class LayoutInferenceTest(parameterized.TestCase):
self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout])
self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout])
def test_infer_layout_propagates_func_layouts_to_ops(self):
add = None
def body(lhs, rhs):
nonlocal add
add = arith.AddFOp(lhs, rhs)
with ir.InsertionPoint(self.module.body):
shape = (32, 4)
ty = ir.VectorType.get(shape, ir.BF16Type.get())
func.FuncOp.from_py_func(ty, ty)(body)
[f] = self.module.body.operations
splat_layout = mgpu.to_splat_fragmented_layout_attr(
mgpu.WGSplatFragLayout(shape)
)
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
mgpu.infer_layout(self.module)
self.assertSequenceEqual(
add.attributes["in_layouts"], [splat_layout, splat_layout])
self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout])
def test_infer_layout_does_not_assign_default_layouts_to_func(self):
def body(lhs, rhs):
arith.AddFOp(lhs, rhs)
with ir.InsertionPoint(self.module.body):
shape = (32, 4)
ty = ir.VectorType.get(shape, ir.BF16Type.get())
func.FuncOp.from_py_func(ty, ty)(body)
[f] = self.module.body.operations
mgpu.infer_layout(self.module)
self.assertNotIn("in_layouts", f.attributes)
self.assertNotIn("out_layouts", f.attributes)
if __name__ == "__main__":
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())