Mirror of https://github.com/ROCm/jax.git
[Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering

PiperOrigin-RevId: 731253431

Parent: 3251b55ef2
Commit: 1de2f839d5
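In short: _fragmented_array_from_ir now takes the layout inferred for the consuming op, reconstructs the FragmentedArray in the layout recorded by its producer, and converts between the two with to_layout when they differ. The toy sketch below models that contract in plain Python; ToyFragmentedArray and its string "layouts" are illustrative stand-ins, not the real Mosaic GPU API.

# Toy model of the relayout-on-mismatch contract introduced by this commit.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFragmentedArray:
  registers: tuple[int, ...]
  layout: str

  def to_layout(self, new_layout: str) -> "ToyFragmentedArray":
    # A real relayout redistributes registers across threads; the toy
    # version only needs to show where the conversion hook lives.
    if new_layout == self.layout:
      return self
    return ToyFragmentedArray(self.registers, new_layout)


def from_ir(producer_layout: str, inferred_layout: str) -> ToyFragmentedArray:
  # Mirrors the new _fragmented_array_from_ir: rebuild the array in the
  # layout its producer annotated, then convert to the inferred layout
  # instead of assuming the two already match.
  arr = ToyFragmentedArray(registers=(0, 1, 2, 3), layout=producer_layout)
  return arr.to_layout(inferred_layout)


print(from_ir("splat", "strided").layout)  # -> "strided"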
@@ -107,9 +107,9 @@ def _fragmented_array_to_ir(
   return conversion_cast.result


 # TODO(bchetioui): add code that verifies the layout is as inferred.
 def _fragmented_array_from_ir(
     fragmented_array_as_ir: ir.Value,
+    layout: ir.Attribute,
     is_signed: bool | None = None,
 ) -> fa.FragmentedArray:

@@ -135,14 +135,14 @@ def _fragmented_array_from_ir(
   registers = np.array(list(converted_outputs)).reshape(
       [attr.value for attr in conversion_cast.attributes["registers_shape"]]
   )
-  layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
+  producer_layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])

   if ir.IntegerType.isinstance(conversion_cast.outputs[0].type.element_type):
     is_signed = False if is_signed is None else is_signed

   return fa.FragmentedArray(
-      _registers=registers, _layout=layout, _is_signed=is_signed
-  )
+      _registers=registers, _layout=producer_layout, _is_signed=is_signed
+  ).to_layout(layouts.from_layout_attr(layout))


 # TODO(dasenov): Remove this when minimum jaxlib version >= 0.5.1.
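The hunks that follow all repeat the same calling convention: read the inferred layout off the op's in_layouts attribute and pass it to _fragmented_array_from_ir. A commented sketch of that pattern, using names taken from this diff (not runnable outside the lowering pass):

# Sketch of the updated calling convention used by the lowering rules below.
#
#   [to_store_layout] = layouts.in_layouts(vector_store_op)
#   arr = _fragmented_array_from_ir(vector_store_op.valueToStore, to_store_layout)
#
# _fragmented_array_from_ir now owns the producer-to-consumer relayout, so
# individual rules no longer need to reason about layout mismatches themselves.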
@@ -277,7 +277,10 @@ def _vector_store_op_lowering_rule(
         f"for {vector_store_op}"
     )

-  fragmented_array = _fragmented_array_from_ir(vector_store_op.valueToStore)
+  [to_store_layout] = layouts.in_layouts(vector_store_op)
+  fragmented_array = _fragmented_array_from_ir(
+      vector_store_op.valueToStore, to_store_layout
+  )

   # TODO(dasenov): This is not efficient for WGMMA layouts
   fragmented_array.store_untiled(vector_store_op.base)
@@ -434,8 +437,12 @@ def _binary_op_lowering_rule(
         [fa.FragmentedArray, fa.FragmentedArray], fa.FragmentedArray
     ],
 ) -> Sequence[ir.Value]:
-  lhs = _fragmented_array_from_ir(op.lhs, is_signed)
-  rhs = _fragmented_array_from_ir(op.rhs, is_signed)
+  in_layouts = layouts.in_layouts(op)
+  [layout] = layouts.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
+  lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
+  rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]

@@ -485,9 +492,13 @@ CMPI_IMPLS = {
 def _cmpi_op_lowering_rule(
     _: LoweringContext, op: arith.CmpIOp
 ) -> Sequence[ir.Value]:
+  in_layouts = layouts.in_layouts(op)
+  [layout] = layouts.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
   impl, is_signed = CMPI_IMPLS[op.predicate.value]
-  lhs = _fragmented_array_from_ir(op.lhs, is_signed)
-  rhs = _fragmented_array_from_ir(op.rhs, is_signed)
+  lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
+  rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]

@@ -505,9 +516,13 @@ CMPF_IMPLS = {
 def _cmpf_op_lowering_rule(
     _: LoweringContext, op: arith.CmpFOp
 ) -> Sequence[ir.Value]:
+  in_layouts = layouts.in_layouts(op)
+  [layout] = layouts.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
   impl = CMPF_IMPLS[op.predicate.value]
-  lhs = _fragmented_array_from_ir(op.lhs)
-  rhs = _fragmented_array_from_ir(op.rhs)
+  lhs = _fragmented_array_from_ir(op.lhs, layout)
+  rhs = _fragmented_array_from_ir(op.rhs, layout)
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]

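The three pointwise rules above share one guard: every inferred input layout must equal the op's single output layout, since any producer-side mismatch has already been absorbed inside _fragmented_array_from_ir. A self-contained sketch of that check, with plain strings standing in for MLIR layout attributes:

# Standalone sketch of the layout guard shared by the pointwise rules above.
def check_pointwise_layouts(in_layouts: list[str], out_layouts: list[str]) -> str:
  # A pointwise op must consume and produce one common layout; the relayout
  # needed to establish this happens when operands are converted from IR.
  [layout] = out_layouts  # exactly one result layout is expected
  if any(in_layout != layout for in_layout in in_layouts):
    raise ValueError("Layout mismatch")
  return layout


print(check_pointwise_layouts(["strided", "strided"], ["strided"]))  # ok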
@@ -515,11 +530,15 @@ def _cmpf_op_lowering_rule(
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
 ) -> Sequence[ir.Value]:
+  fa_layouts = (*layouts.in_layouts(wgmma_op), *layouts.out_layouts(wgmma_op))
+  if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
+    raise ValueError("Layout mismatch")
+  wgmma_layout = fa_layouts[0]
+
   # TODO(dasenov): Move the value -> accumulator conversion outside of wgmma.
   # The associated fence could be a little expensive and is not needed if the
   # result of a wgmma feeds into another wgmma (even in another loop step).
-  acc_in = _fragmented_array_from_ir(wgmma_op.accumulator)
+  acc_in = _fragmented_array_from_ir(wgmma_op.accumulator, wgmma_layout)
   regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
   acc = wgmma.WGMMAAccumulator.from_registers(regs)

@@ -527,7 +546,7 @@ def _mgpu_wgmma_op_lowering_rule(
   b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)

   if ir.VectorType.isinstance(wgmma_op.a.type):
-    a_operand = _fragmented_array_from_ir(wgmma_op.a)
+    a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
   else:
     a_layout = ir.MemRefType(wgmma_op.a.type).layout
     a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
@@ -205,11 +205,14 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
   # We first look at producers; this enables e.g. propagating splat layouts as
   # far down as possible, since we may be able to propagate splat layouts
   # further down before requiring a relayout in that way.
+  all_inputs_have_layout = True
   for operand in op.operands:
     if not ir.VectorType.isinstance(operand.type):
       continue
     if (layout := _value_layout(operand)) is not None:
       layouts.add(layout)
+    else:
+      all_inputs_have_layout = False

   # We only look at consumers if we haven't found a possible layout yet. This is
   # to avoid propagating more complicated layouts up, to e.g. preserve splat

@@ -232,6 +235,10 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
   # This is left for a future change, and currently we only do "down
   # propagation".
   layout = _choose_representative_layout(layouts)
+  # It is unsafe to conclude that this op produces a splat if not all inputs
+  # have been inferred: some of them might turn out not to be splats!
+  if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout:
+    return None
   if layout is None:
     return None
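The new guard is worth spelling out: a splat conclusion drawn from a subset of inputs can be invalidated once a missing input turns out not to be a splat. A toy model of the rule, with hypothetical names rather than the real inference pass:

# Toy model of the splat guard: only commit to a splat layout once every
# input layout is known. Names here are illustrative.
def infer_pointwise(input_layouts: list[str | None]) -> str | None:
  known = [l for l in input_layouts if l is not None]
  all_inputs_have_layout = len(known) == len(input_layouts)
  if not known:
    return None
  layout = "splat" if all(l == "splat" for l in known) else "strided"
  # An uninferred input may still turn out to be non-splat, so a splat
  # conclusion is only safe when every input has been inferred.
  if layout == "splat" and not all_inputs_have_layout:
    return None
  return layout


print(infer_pointwise(["splat", None]))       # None: too early to decide
print(infer_pointwise(["splat", "splat"]))    # splat
print(infer_pointwise(["splat", "strided"]))  # strided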
@@ -123,7 +123,8 @@ class LayoutInferenceTest(parameterized.TestCase):
     self.assertEmpty(c.attributes["in_layouts"])
     self.assertSequenceEqual(c.attributes["out_layouts"], [layout])

-  def test_infer_splat_layout_for_vector_splat(self):
+  @parameterized.parameters(True, False)
+  def test_infer_splat_layout_for_vector_splat(self, rhs_splat):
     add = splat = None

     def body(lhs, rhs):

@@ -135,19 +136,24 @@ class LayoutInferenceTest(parameterized.TestCase):
     shape = (16, 8)
     elt_type = ir.BF16Type.get()
     ty = ir.VectorType.get(shape, elt_type)
-    func.FuncOp.from_py_func(elt_type, ty)(body)
-
-    # Not setting any layouts on the module should default in all ops having a
-    # splat fragmented layout.
-    mgpu.infer_layout(self.module)
+    func_op = func.FuncOp.from_py_func(elt_type, ty)(body).func_op

     layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
+    if rhs_splat:
+      func_op.attributes["in_layouts"] = ir.ArrayAttr.get([layout])
+    mgpu.infer_layout(self.module)

     self.assertEmpty(splat.attributes["in_layouts"])
     self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])

-    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
-    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
+    add_layout = layout
+    if not rhs_splat:
+      add_layout = layouts.to_layout_attr(
+          mgpu.WGStridedFragLayout.from_shaped_type(ty)
+      )
+
+    self.assertSequenceEqual(add.attributes["in_layouts"], [add_layout, add_layout])
+    self.assertSequenceEqual(add.attributes["out_layouts"], [add_layout])

   @parameterized.parameters(
       mgpu.WGSplatFragLayout(shape=(32, 4)),
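The parameterization covers both paths: with rhs_splat=True everything stays splat, while rhs_splat=False seeds a non-splat function input so the add is forced to a strided layout. A minimal, self-contained illustration of the absl parameterization mechanism the test relies on (the test class and assertion are illustrative, not from the diff):

# Each value passed to @parameterized.parameters becomes its own test case.
from absl.testing import absltest, parameterized


class ToyLayoutTest(parameterized.TestCase):

  @parameterized.parameters(True, False)
  def test_runs_once_per_value(self, rhs_splat):
    expected = "splat" if rhs_splat else "strided"
    self.assertIn(expected, ("splat", "strided"))


if __name__ == "__main__":
  absltest.main()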