Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Add layout inference for scf.ForOp and scf.YieldOp.

PiperOrigin-RevId: 730873769
parent 7acd60c867
commit 5024ef213f
@@ -22,7 +22,7 @@ from typing import cast
 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
-from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
 
 from . import fragmented_array as fa
@@ -161,17 +161,17 @@ def _value_layout(value: ir.Value) -> ir.Attribute | None:
     ].index(value)
     return layouts_lib.out_layouts(owner)[value_result_number]
 
-  # Function block case, useful when attempting to derive layouts for ops
-  # depending on function parameters.
-  if isinstance(owner, ir.Block) and isinstance(owner.owner, func.FuncOp):
-    func_op = owner.owner
+  # Block case, useful when attempting to derive layouts for ops
+  # depending on function parameters, or loop block arguments.
+  if isinstance(owner, ir.Block):
+    owner_op = owner.owner
     block = cast(ir.Block, owner)
-    if not layouts_lib.has_in_layouts_set(func_op):
+    if not layouts_lib.has_in_layouts_set(owner_op):
       return None
     value_arg_number = [
         r for r in block.arguments if ir.VectorType.isinstance(r.type)
     ].index(value)
-    return layouts_lib.in_layouts(func_op)[value_arg_number]
+    return layouts_lib.in_layouts(owner_op)[value_arg_number]
 
   raise NotImplementedError(
       f"{owner} is not a function block nor an operation.")
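For context on the generalization in this hunk: the iter_args of an scf.ForOp surface inside the loop body as block arguments, so value.owner is an ir.Block whose own owner is the scf.ForOp that carries the "in_layouts" attribute the branch now reads. Below is a minimal sketch of that ownership chain; it is an illustration only, not part of the commit, and uses the same bindings this file already imports.

# Sketch only: shows why a vector value's owner can be an ir.Block.
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func, scf

with ir.Context() as ctx, ir.Location.unknown():
  # scf/func are upstream dialects; loading them eagerly keeps the sketch
  # self-contained (assumed available in these MLIR Python bindings).
  ctx.load_all_available_dialects()
  module = ir.Module.create()
  i32 = ir.IntegerType.get_signless(32)
  vec_ty = ir.VectorType.get((64, 32), ir.BF16Type.get())

  def body(lower_bound, upper_bound, step, a, b):
    for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b])
    loop_a, loop_b = list(for_op.inner_iter_args)
    # The loop carries are block arguments: their owner is an ir.Block,
    # not an operation, so the old FuncOp-only branch would have fallen
    # through to the NotImplementedError above.
    assert isinstance(loop_a.owner, ir.Block)
    # The block's owner is the scf.ForOp itself, whose "in_layouts"
    # attribute the generalized branch reads once it has been set.
    assert isinstance(loop_a.owner.owner, scf.ForOp)
    with ir.InsertionPoint(for_op.body):
      scf.YieldOp([loop_a, loop_b])

  with ir.InsertionPoint(module.body):
    func.FuncOp.from_py_func(i32, i32, i32, vec_ty, vec_ty)(body)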
@@ -303,6 +303,41 @@ def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts:
   return [], [layout]
 
 
+@partial(_add_layout_inference_rule, scf.YieldOp)
+def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts:
+  layouts = []
+  for result in op.results_:
+    if not ir.VectorType.isinstance(result.type):
+      continue
+    if (layout := _value_layout(result)) is not None:
+      layouts.append(layout)
+    else:
+      # Not all layouts could be inferred for vector ops. Return for now.
+      return None
+
+  return (layouts, [])
+
+
+@partial(_add_layout_inference_rule, scf.ForOp)
+def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts:
+  yield_op = op.body.operations[len(op.body.operations) - 1]
+  assert isinstance(yield_op, scf.YieldOp)
+
+  if layouts_lib.has_in_layouts_set(yield_op):
+    yield_layouts = list(layouts_lib.in_layouts(yield_op))
+    if any(
+        layouts_lib.is_splat_fragmented_layout(layout)
+        for layout in yield_layouts
+    ):
+      return None
+    return (yield_layouts, yield_layouts)
+
+  # TODO(bchetioui): we don't attempt to propagate from outside for the moment.
+  # For the existing kernels, propagating from the YieldOp should be enough.
+
+  return None
+
+
 @partial(_add_layout_inference_rule, vector.SplatOp)
 def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
   layout = layouts_lib.to_splat_fragmented_layout_attr(
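Both new rules follow the registration protocol used throughout this file: a rule receives an op and returns None when the layouts cannot be determined yet, or a pair of in_layouts/out_layouts lists covering the op's vector operands and results. As an illustration of that protocol only, here is a hypothetical rule for a unary elementwise op; it is not part of the commit (elementwise ops appear to be covered by an existing pointwise rule, as the first test below suggests).

@partial(_add_layout_inference_rule, arith.NegFOp)  # hypothetical example
def _infer_negf_op_layout(op: arith.NegFOp) -> OptionalLayouts:
  # Unary elementwise: the result must use the operand's layout, if known.
  layout = _value_layout(op.operand)
  if layout is None:
    # The operand's layout is not determined yet; try again in a later pass.
    return None
  return [layout], [layout]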
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Layout inference tests for the Mosaic GPU MLIR dialect."""
 
+# pylint: disable=g-complex-comprehension
+
 from absl.testing import parameterized
 import jax
 from jax._src import config
@@ -196,6 +198,86 @@ class LayoutInferenceTest(parameterized.TestCase):
     self.assertIn("in_layouts", add.attributes)
     self.assertIn("out_layouts", add.attributes)
 
+  @parameterized.parameters(
+      (shape, layout)
+      for shape in [(64, 32)]
+      for layout in [
+          mgpu.WGSplatFragLayout(shape),
+          mgpu.WGStridedFragLayout(shape, vec_size=4),
+          mgpu.WGMMAFragLayout(),
+      ]
+  )
+  def test_infer_layout_from_yield_op_in_layouts_for_for_op(
+      self, shape, layout
+  ):
+    add_op = for_op = yield_op = None
+
+    def body(lower_bound, upper_bound, step, a, b):
+      nonlocal for_op
+      for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b])
+      [loop_a, loop_b] = list(for_op.inner_iter_args)
+      with ir.InsertionPoint(for_op.body):
+        nonlocal add_op, yield_op
+        add_op = arith.AddFOp(loop_a, loop_b)
+        yield_op = scf.YieldOp([add_op.result, add_op.result])
+
+    with ir.InsertionPoint(self.module.body):
+      ab_type = ir.VectorType.get(shape, ir.BF16Type.get())
+      i32 = ir.IntegerType.get_signless(32)
+      func.FuncOp.from_py_func(i32, i32, i32, ab_type, ab_type)(body)
+
+    add_op.attributes["out_layouts"] = ir.ArrayAttr.get(
+        [layouts.to_layout_attr(layout)]
+    )
+
+    mgpu.infer_layout(self.module)
+
+    if isinstance(layout, mgpu.WGSplatFragLayout):
+      # In the splat case, we should not propagate the splat layout from the
+      # yield op. That is because we can not convert other layouts to a splat
+      # layout---which could cause trouble if the initial carries have a
+      # different layout. Instead, we should get the default annotation, i.e.
+      # strided layouts.
+      strided_layout = layouts.to_layout_attr(
+          mgpu.WGStridedFragLayout.from_shaped_type(ab_type)
+      )
+      carry_layouts = [strided_layout, strided_layout]
+      self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
+      self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts)
+      self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts)
+    else:
+      carry_layouts = [layouts.to_layout_attr(layout)] * 2
+      self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
+      self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts)
+      self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts)
+
+  def test_infer_layout_from_body_op_to_yield_op_to_for_op(self):
+    for_op = yield_op = None
+    shape = (64, 64)
+
+    def body(lower_bound, upper_bound, step, a, b, c):
+      nonlocal for_op
+      for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b, c])
+      with ir.InsertionPoint(for_op.body):
+        nonlocal yield_op
+        [loop_a, loop_b, loop_c] = list(for_op.inner_iter_args)
+        new_loop_c = mgpu.dialect.wgmma(loop_c, loop_a, loop_b)
+        yield_op = scf.YieldOp([loop_a, loop_b, new_loop_c])
+
+    with ir.InsertionPoint(self.module.body):
+      c_ty = ir.VectorType.get(shape, ir.BF16Type.get())
+      ab_ty = ir.MemRefType.get(shape, ir.BF16Type.get())
+      i32 = ir.IntegerType.get_signless(32)
+      func.FuncOp.from_py_func(i32, i32, i32, ab_ty, ab_ty, c_ty)(body)
+
+    mgpu.infer_layout(self.module)
+
+    wgmma_layout = layouts.to_layout_attr(mgpu.WGMMAFragLayout())
+    self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout])
+    self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
+    self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout])
+    self.assertSequenceEqual(for_op.attributes["out_layouts"], [wgmma_layout])
+
   def test_infer_layout_has_no_layout_for_non_vector_types(self):
     shape = (32, 4)
     elt_ty = ir.BF16Type.get()
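These tests read back the annotations that mgpu.infer_layout attaches to each rewritten op. Here is a small hedged sketch of doing the same through the regular MLIR attribute API; the helper itself is hypothetical, while the attribute names come from the tests above.

def dump_inferred_layouts(op: ir.OpView) -> None:
  # Hypothetical helper: print the "in_layouts"/"out_layouts" array
  # attributes attached by mgpu.infer_layout, if present on the op.
  for name in ("in_layouts", "out_layouts"):
    if name in op.attributes:
      attr = op.attributes[name]
      print(name, [str(attr[i]) for i in range(len(attr))])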