mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Allow querying layouts from a FuncOp
's block arguments if set.
The motivation behind this change is twofold: 1. it simplifies test writing (no need to produce arbitrary, manual, non-splat constants to produce arguments with a strided layout); 2. it'll allow running layout inference on different `FuncOp`s in isolation, before inlining. While the primary motivation is to simplify test writing for upcoming changes, `2.` is useful if we ever intend to call functions whose body's layout we have inferred from other functions. It's not clear to me that we have a use case for that, but the theoretical benefit is worth pointing out. Crucially, layout inference does not set default layouts for `FuncOp`s, since the caller may choose a different layout for its arguments. As a result, there is also no layout inference rule for `func.FuncOp`. PiperOrigin-RevId: 716158516
This commit is contained in:
parent
4221f109d1
commit
bc7204f003
@ -20,6 +20,7 @@ from typing import cast
|
||||
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import func
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
|
||||
from .fragmented_array import WGSplatFragLayout, WGStridedFragLayout
|
||||
@ -123,27 +124,38 @@ def _in_layout_for_operand(
|
||||
return layouts_lib.in_layouts(op)[operand_number]
|
||||
|
||||
|
||||
def _out_layout_for_result(
|
||||
op: ir.OpView,
|
||||
result: ir.Value,
|
||||
) -> ir.Attribute | None:
|
||||
"""Returns the layout for a specific result of the given operation if it is set.
|
||||
def _value_layout(value: ir.Value) -> ir.Attribute | None:
|
||||
"""Returns the layout for a given value as defined by its owner.
|
||||
|
||||
Raises:
|
||||
ValueError: If `result` is not a result of `op`, or if `result` is not a
|
||||
Vector.
|
||||
ValueError: If `result` is not a Vector.
|
||||
"""
|
||||
if not ir.VectorType.isinstance(result.type):
|
||||
raise ValueError(f"{result} is not a vector.")
|
||||
if not ir.VectorType.isinstance(value.type):
|
||||
raise ValueError(f"{value} is not a vector.")
|
||||
|
||||
result_number = [
|
||||
r for r in op.results if ir.VectorType.isinstance(r.type)
|
||||
].index(result)
|
||||
owner = value.owner
|
||||
if isinstance(owner, ir.Operation):
|
||||
if not layouts_lib.has_out_layouts_set(owner):
|
||||
return None
|
||||
value_result_number = [
|
||||
r for r in owner.results if ir.VectorType.isinstance(r.type)
|
||||
].index(value)
|
||||
return layouts_lib.out_layouts(owner)[value_result_number]
|
||||
|
||||
if not layouts_lib.has_out_layouts_set(op):
|
||||
return None
|
||||
# Function block case, useful when attempting to derive layouts for ops
|
||||
# depending on function parameters.
|
||||
if isinstance(owner, ir.Block) and isinstance(owner.owner, func.FuncOp):
|
||||
func_op = owner.owner
|
||||
block = cast(ir.Block, owner)
|
||||
if not layouts_lib.has_in_layouts_set(func_op):
|
||||
return None
|
||||
value_arg_number = [
|
||||
r for r in block.arguments if ir.VectorType.isinstance(r.type)
|
||||
].index(value)
|
||||
return layouts_lib.in_layouts(func_op)[value_arg_number]
|
||||
|
||||
return layouts_lib.out_layouts(op)[result_number]
|
||||
raise NotImplementedError(
|
||||
f"{owner} is not a function block nor an operation.")
|
||||
|
||||
|
||||
def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
|
||||
@ -175,11 +187,9 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
|
||||
# far down as possible, until since we may be able to propagate splat layouts
|
||||
# further down before requiring a relayout in that way.
|
||||
for operand in op.operands:
|
||||
if not isinstance(
|
||||
operand.owner, ir.Operation
|
||||
) or not ir.VectorType.isinstance(operand.type):
|
||||
if not ir.VectorType.isinstance(operand.type):
|
||||
continue
|
||||
if (layout := _out_layout_for_result(operand.owner, operand)) is not None:
|
||||
if (layout := _value_layout(operand)) is not None:
|
||||
layouts.add(layout)
|
||||
|
||||
# We only look at consumers if we haven't found a possible layout yet. This is
|
||||
|
@ -290,6 +290,44 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout])
|
||||
self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout])
|
||||
|
||||
def test_infer_layout_propagates_func_layouts_to_ops(self):
|
||||
add = None
|
||||
|
||||
def body(lhs, rhs):
|
||||
nonlocal add
|
||||
add = arith.AddFOp(lhs, rhs)
|
||||
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
shape = (32, 4)
|
||||
ty = ir.VectorType.get(shape, ir.BF16Type.get())
|
||||
func.FuncOp.from_py_func(ty, ty)(body)
|
||||
|
||||
[f] = self.module.body.operations
|
||||
splat_layout = mgpu.to_splat_fragmented_layout_attr(
|
||||
mgpu.WGSplatFragLayout(shape)
|
||||
)
|
||||
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
|
||||
mgpu.infer_layout(self.module)
|
||||
|
||||
self.assertSequenceEqual(
|
||||
add.attributes["in_layouts"], [splat_layout, splat_layout])
|
||||
self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout])
|
||||
|
||||
def test_infer_layout_does_not_assign_default_layouts_to_func(self):
|
||||
|
||||
def body(lhs, rhs):
|
||||
arith.AddFOp(lhs, rhs)
|
||||
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
shape = (32, 4)
|
||||
ty = ir.VectorType.get(shape, ir.BF16Type.get())
|
||||
func.FuncOp.from_py_func(ty, ty)(body)
|
||||
|
||||
[f] = self.module.body.operations
|
||||
mgpu.infer_layout(self.module)
|
||||
self.assertNotIn("in_layouts", f.attributes)
|
||||
self.assertNotIn("out_layouts", f.attributes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user