[Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering

PiperOrigin-RevId: 731253431
Adam Paszke 2025-02-26 04:02:28 -08:00 committed by jax authors
parent 3251b55ef2
commit 1de2f839d5
3 changed files with 53 additions and 21 deletions
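
The change follows one pattern throughout the lowering rules: read the layouts recorded by layout inference, check that the operand layouts agree with the result layout, and rebuild each operand with _fragmented_array_from_ir, which now relayouts the value whenever its producer emitted it in a different layout. Below is a minimal sketch of that pattern; the rule name is hypothetical, while layouts.in_layouts, layouts.out_layouts, _fragmented_array_from_ir and _fragmented_array_to_ir are the helpers that appear in the diffs that follow.

def _relayouting_lowering_rule(ctx, op, impl):
  # `impl` combines two fa.FragmentedArrays, as in _binary_op_lowering_rule.
  in_layouts = layouts.in_layouts(op)
  [layout] = layouts.out_layouts(op)
  # Layout inference should assign every vector operand the same layout as the
  # single result; the rules below treat any disagreement as an error.
  if any(in_layout != layout for in_layout in in_layouts):
    raise ValueError("Layout mismatch")
  # _fragmented_array_from_ir reconstructs each operand in its producer's
  # layout and then converts it to `layout` via FragmentedArray.to_layout.
  lhs = _fragmented_array_from_ir(op.lhs, layout)
  rhs = _fragmented_array_from_ir(op.rhs, layout)
  return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]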

@@ -107,9 +107,9 @@ def _fragmented_array_to_ir(
return conversion_cast.result
# TODO(bchetioui): add code that verifies the layout is as inferred.
def _fragmented_array_from_ir(
fragmented_array_as_ir: ir.Value,
layout: ir.Attribute,
is_signed: bool | None = None,
) -> fa.FragmentedArray:
@@ -135,14 +135,14 @@ def _fragmented_array_from_ir(
registers = np.array(list(converted_outputs)).reshape(
[attr.value for attr in conversion_cast.attributes["registers_shape"]]
)
layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
producer_layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
if ir.IntegerType.isinstance(conversion_cast.outputs[0].type.element_type):
is_signed = False if is_signed is None else is_signed
return fa.FragmentedArray(
_registers=registers, _layout=layout, _is_signed=is_signed
)
_registers=registers, _layout=producer_layout, _is_signed=is_signed
).to_layout(layouts.from_layout_attr(layout))
# TODO(dasenov): Remove this when minimum jaxlib version >= 0.5.1.
@@ -277,7 +277,10 @@ def _vector_store_op_lowering_rule(
f"for {vector_store_op}"
)
fragmented_array = _fragmented_array_from_ir(vector_store_op.valueToStore)
[to_store_layout] = layouts.in_layouts(vector_store_op)
fragmented_array = _fragmented_array_from_ir(
vector_store_op.valueToStore, to_store_layout
)
# TODO(dasenov): This is not efficient for WGMMA layouts
fragmented_array.store_untiled(vector_store_op.base)
@@ -434,8 +437,12 @@ def _binary_op_lowering_rule(
[fa.FragmentedArray, fa.FragmentedArray], fa.FragmentedArray
],
) -> Sequence[ir.Value]:
lhs = _fragmented_array_from_ir(op.lhs, is_signed)
rhs = _fragmented_array_from_ir(op.rhs, is_signed)
in_layouts = layouts.in_layouts(op)
[layout] = layouts.out_layouts(op)
if any(in_layout != layout for in_layout in in_layouts):
raise ValueError("Layout mismatch")
lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
@@ -485,9 +492,13 @@ CMPI_IMPLS = {
def _cmpi_op_lowering_rule(
_: LoweringContext, op: arith.CmpIOp
) -> Sequence[ir.Value]:
in_layouts = layouts.in_layouts(op)
[layout] = layouts.out_layouts(op)
if any(in_layout != layout for in_layout in in_layouts):
raise ValueError("Layout mismatch")
impl, is_signed = CMPI_IMPLS[op.predicate.value]
lhs = _fragmented_array_from_ir(op.lhs, is_signed)
rhs = _fragmented_array_from_ir(op.rhs, is_signed)
lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
@@ -505,9 +516,13 @@ CMPF_IMPLS = {
def _cmpf_op_lowering_rule(
_: LoweringContext, op: arith.CmpFOp
) -> Sequence[ir.Value]:
in_layouts = layouts.in_layouts(op)
[layout] = layouts.out_layouts(op)
if any(in_layout != layout for in_layout in in_layouts):
raise ValueError("Layout mismatch")
impl = CMPF_IMPLS[op.predicate.value]
lhs = _fragmented_array_from_ir(op.lhs)
rhs = _fragmented_array_from_ir(op.rhs)
lhs = _fragmented_array_from_ir(op.lhs, layout)
rhs = _fragmented_array_from_ir(op.rhs, layout)
return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
@@ -515,11 +530,15 @@ def _cmpf_op_lowering_rule(
def _mgpu_wgmma_op_lowering_rule(
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
) -> Sequence[ir.Value]:
fa_layouts = (*layouts.in_layouts(wgmma_op), *layouts.out_layouts(wgmma_op))
if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
raise ValueError("Layout mismatch")
wgmma_layout = fa_layouts[0]
# TODO(dasenov): Move the value -> accumulator conversion outside of wgmma.
# The associated fence could be a little expensive and is not needed if the
# result of a wgmma feeds into another wgmma (even in another loop step).
acc_in = _fragmented_array_from_ir(wgmma_op.accumulator)
acc_in = _fragmented_array_from_ir(wgmma_op.accumulator, wgmma_layout)
regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
acc = wgmma.WGMMAAccumulator.from_registers(regs)
@@ -527,7 +546,7 @@ def _mgpu_wgmma_op_lowering_rule(
b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
if ir.VectorType.isinstance(wgmma_op.a.type):
a_operand = _fragmented_array_from_ir(wgmma_op.a)
a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
else:
a_layout = ir.MemRefType(wgmma_op.a.type).layout
a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
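
The key change in this file is the new layout parameter of _fragmented_array_from_ir: the registers are first reassembled in the layout recorded by the producing conversion cast, and only then converted to the layout the consumer expects. A condensed sketch of the function's new tail, using the names from the hunks above (the comment about to_layout being a no-op on matching layouts is an assumption, not something this diff shows):

producer_layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
array = fa.FragmentedArray(
    _registers=registers, _layout=producer_layout, _is_signed=is_signed
)
# Presumably a no-op when the two layouts already agree; otherwise this is
# where the actual relayout (register reshuffling) happens.
return array.to_layout(layouts.from_layout_attr(layout))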

@@ -205,11 +205,14 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
# We first look at producers; this enables e.g. propagating splat layouts as
# far down as possible, since we may be able to propagate a splat layout
# further down before a relayout is required.
all_inputs_have_layout = True
for operand in op.operands:
if not ir.VectorType.isinstance(operand.type):
continue
if (layout := _value_layout(operand)) is not None:
layouts.add(layout)
else:
all_inputs_have_layout = False
# We only look at consumers if we haven't found a possible layout yet. This is
# to avoid propagating more complicated layouts up, to e.g. preserve splat
@@ -232,6 +235,10 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
# This is left for a future change, and currently we only do "down
# propagation".
layout = _choose_representative_layout(layouts)
# It is unsafe to conclude that this op produces a splat if not all inputs
# have been inferred: some of them might turn out not to be splats!
if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout:
return None
if layout is None:
return None
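
The guard above only matters while some operand layouts are still unknown: the representative layout is chosen from the layouts inferred so far, and a splat candidate may be contradicted by an operand that is inferred later. A self-contained toy model of the decision, where the function name and the "splat"/"strided" strings are hypothetical stand-ins for _choose_representative_layout's result and the real MLIR layout attributes:

def pick_pointwise_layout(candidate, all_inputs_known):
  # `candidate` plays the role of _choose_representative_layout's result.
  if candidate == "splat" and not all_inputs_known:
    # An input whose layout has not been inferred yet might turn out not to be
    # a splat, so committing to a splat output layout here would be premature.
    return None
  return candidate

assert pick_pointwise_layout("splat", all_inputs_known=False) is None
assert pick_pointwise_layout("splat", all_inputs_known=True) == "splat"
assert pick_pointwise_layout("strided", all_inputs_known=False) == "strided"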

@@ -123,7 +123,8 @@ class LayoutInferenceTest(parameterized.TestCase):
self.assertEmpty(c.attributes["in_layouts"])
self.assertSequenceEqual(c.attributes["out_layouts"], [layout])
def test_infer_splat_layout_for_vector_splat(self):
@parameterized.parameters(True, False)
def test_infer_splat_layout_for_vector_splat(self, rhs_splat):
add = splat = None
def body(lhs, rhs):
@@ -135,19 +136,24 @@ class LayoutInferenceTest(parameterized.TestCase):
shape = (16, 8)
elt_type = ir.BF16Type.get()
ty = ir.VectorType.get(shape, elt_type)
func.FuncOp.from_py_func(elt_type, ty)(body)
# Not setting any layouts on the module should result in all ops having a
# splat fragmented layout.
mgpu.infer_layout(self.module)
func_op = func.FuncOp.from_py_func(elt_type, ty)(body).func_op
layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
if rhs_splat:
func_op.attributes["in_layouts"] = ir.ArrayAttr.get([layout])
mgpu.infer_layout(self.module)
self.assertEmpty(splat.attributes["in_layouts"])
self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])
self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
add_layout = layout
if not rhs_splat:
add_layout = layouts.to_layout_attr(
mgpu.WGStridedFragLayout.from_shaped_type(ty)
)
self.assertSequenceEqual(add.attributes["in_layouts"], [add_layout, add_layout])
self.assertSequenceEqual(add.attributes["out_layouts"], [add_layout])
@parameterized.parameters(
mgpu.WGSplatFragLayout(shape=(32, 4)),