From 1de2f839d5a9b84792c12e6d3eaa98cd97ce70e5 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Wed, 26 Feb 2025 04:02:28 -0800
Subject: [PATCH] [Mosaic GPU] Make sure to relayout FAs when their layouts
 mismatch in MGPU lowering

PiperOrigin-RevId: 731253431
---
 .../mosaic/gpu/dialect_lowering.py        | 45 +++++++++++++------
 .../mosaic/gpu/layout_inference.py        |  7 +++
 tests/mosaic/gpu_layout_inference_test.py | 22 +++++----
 3 files changed, 53 insertions(+), 21 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index ac443f012..1c7265fd6 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -107,9 +107,9 @@ def _fragmented_array_to_ir(
   return conversion_cast.result
 
 
-# TODO(bchetioui): add code that verifies the layout is as inferred.
 def _fragmented_array_from_ir(
     fragmented_array_as_ir: ir.Value,
+    layout: ir.Attribute,
     is_signed: bool | None = None,
 ) -> fa.FragmentedArray:
 
@@ -135,14 +135,14 @@ def _fragmented_array_from_ir(
   registers = np.array(list(converted_outputs)).reshape(
       [attr.value for attr in conversion_cast.attributes["registers_shape"]]
   )
-  layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
+  producer_layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])
 
   if ir.IntegerType.isinstance(conversion_cast.outputs[0].type.element_type):
     is_signed = False if is_signed is None else is_signed
 
   return fa.FragmentedArray(
-      _registers=registers, _layout=layout, _is_signed=is_signed
-  )
+      _registers=registers, _layout=producer_layout, _is_signed=is_signed
+  ).to_layout(layouts.from_layout_attr(layout))
 
 
 # TODO(dasenov): Remove this when minimum jaxlib version >= 0.5.1.
@@ -277,7 +277,10 @@ def _vector_store_op_lowering_rule(
         f"for {vector_store_op}"
     )
 
-  fragmented_array = _fragmented_array_from_ir(vector_store_op.valueToStore)
+  [to_store_layout] = layouts.in_layouts(vector_store_op)
+  fragmented_array = _fragmented_array_from_ir(
+      vector_store_op.valueToStore, to_store_layout
+  )
 
   # TODO(dasenov): This is not efficient for WGMMA layouts
   fragmented_array.store_untiled(vector_store_op.base)
@@ -434,8 +437,12 @@ def _binary_op_lowering_rule(
         [fa.FragmentedArray, fa.FragmentedArray], fa.FragmentedArray
     ],
 ) -> Sequence[ir.Value]:
-  lhs = _fragmented_array_from_ir(op.lhs, is_signed)
-  rhs = _fragmented_array_from_ir(op.rhs, is_signed)
+  in_layouts = layouts.in_layouts(op)
+  [layout] = layouts.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
+  lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
+  rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
 
 
@@ -485,9 +492,13 @@ CMPI_IMPLS = {
 def _cmpi_op_lowering_rule(
     _: LoweringContext, op: arith.CmpIOp
 ) -> Sequence[ir.Value]:
+  in_layouts = layouts.in_layouts(op)
+  [layout] = layouts.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
   impl, is_signed = CMPI_IMPLS[op.predicate.value]
-  lhs = _fragmented_array_from_ir(op.lhs, is_signed)
-  rhs = _fragmented_array_from_ir(op.rhs, is_signed)
+  lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
+  rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
 
 
@@ -505,9 +516,13 @@ CMPF_IMPLS = {
 def _cmpf_op_lowering_rule(
     _: LoweringContext, op: arith.CmpFOp
 ) -> Sequence[ir.Value]:
+  in_layouts = layouts.in_layouts(op)
+  [layout] = layouts.out_layouts(op)
+  if any(in_layout != layout for in_layout in in_layouts):
+    raise ValueError("Layout mismatch")
   impl = CMPF_IMPLS[op.predicate.value]
-  lhs = _fragmented_array_from_ir(op.lhs)
-  rhs = _fragmented_array_from_ir(op.rhs)
+  lhs = _fragmented_array_from_ir(op.lhs, layout)
+  rhs = _fragmented_array_from_ir(op.rhs, layout)
   return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
 
 
@@ -515,11 +530,15 @@ def _cmpf_op_lowering_rule(
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
 ) -> Sequence[ir.Value]:
+  fa_layouts = (*layouts.in_layouts(wgmma_op), *layouts.out_layouts(wgmma_op))
+  if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
+    raise ValueError("Layout mismatch")
+  wgmma_layout = fa_layouts[0]
   # TODO(dasenov): Move the value -> accumulator conversion outside of wgmma.
   # The associated fence could be a little expensive and is not needed if the
   # result of a wgmma feeds into another wgmma (even in another loop step).
-  acc_in = _fragmented_array_from_ir(wgmma_op.accumulator)
+  acc_in = _fragmented_array_from_ir(wgmma_op.accumulator, wgmma_layout)
   regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
   acc = wgmma.WGMMAAccumulator.from_registers(regs)
 
@@ -527,7 +546,7 @@ def _mgpu_wgmma_op_lowering_rule(
   b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
 
   if ir.VectorType.isinstance(wgmma_op.a.type):
-    a_operand = _fragmented_array_from_ir(wgmma_op.a)
+    a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
   else:
     a_layout = ir.MemRefType(wgmma_op.a.type).layout
     a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
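For readers skimming the patch: the crux of the dialect_lowering.py change is
that _fragmented_array_from_ir now receives the layout inferred for the
consuming op and relayouts the array whenever the producer disagrees. The
sketch below restates that contract outside the lowering machinery; the helper
name _relayout_for_consumer is invented for illustration, and the imports
assume the module paths shown in the diff headers above.

    from jax.experimental.mosaic.gpu import fragmented_array as fa
    from jax.experimental.mosaic.gpu import layouts

    def _relayout_for_consumer(
        registers, producer_layout_attr, consumer_layout_attr, is_signed
    ):
      # Rebuild the array exactly as its producer laid it out...
      array = fa.FragmentedArray(
          _registers=registers,
          _layout=layouts.from_layout_attr(producer_layout_attr),
          _is_signed=is_signed,
      )
      # ...then convert it to the layout inferred for the consumer. The
      # conversion is a no-op when the two layouts already agree.
      return array.to_layout(layouts.from_layout_attr(consumer_layout_attr))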
diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index 2f44ab649..520e83456 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -205,11 +205,14 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
   # We first look at producers; this enables e.g. propagating splat layouts as
   # far down as possible, since we may be able to propagate splat layouts
   # further down before requiring a relayout in that way.
+  all_inputs_have_layout = True
   for operand in op.operands:
     if not ir.VectorType.isinstance(operand.type):
       continue
     if (layout := _value_layout(operand)) is not None:
       layouts.add(layout)
+    else:
+      all_inputs_have_layout = False
 
   # We only look at consumers if we haven't found a possible layout yet. This is
   # to avoid propagating more complicated layouts up, to e.g. preserve splat
@@ -232,6 +235,10 @@ def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
   # This is left for a future change, and currently we only do "down
   # propagation".
   layout = _choose_representative_layout(layouts)
+  # It is unsafe to conclude that this op produces a splat if not all inputs
+  # have been inferred: some of them might turn out not to be splats!
+  if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout:
+    return None
   if layout is None:
     return None
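The new early return in _infer_pointwise_op_layouts is easier to see in
isolation. The toy below is illustrative only: the function name is invented
and layouts are plain strings rather than MLIR attributes, but it mirrors the
rule the patch adds, i.e. an op may only be concluded to produce a splat once
the layout of every vector input is known.

    def infer_pointwise_layout(input_layouts: list[str | None]) -> str | None:
      # Toy stand-in for _infer_pointwise_op_layouts; not library code.
      known = [l for l in input_layouts if l is not None]
      all_inputs_have_layout = len(known) == len(input_layouts)
      if not known:
        return None
      # Crude stand-in for _choose_representative_layout: any non-splat
      # input forces a non-splat result.
      candidate = "splat" if all(l == "splat" for l in known) else "strided"
      # The guard from the patch: it is unsafe to conclude "splat" while
      # some input layout is unknown, since that input may not be a splat.
      if candidate == "splat" and not all_inputs_have_layout:
        return None
      return candidate

    assert infer_pointwise_layout(["splat", "splat"]) == "splat"
    assert infer_pointwise_layout(["splat", None]) is None  # the new guard
    assert infer_pointwise_layout(["strided", None]) == "strided"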
diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py
index 1466c6f4b..91debfe57 100644
--- a/tests/mosaic/gpu_layout_inference_test.py
+++ b/tests/mosaic/gpu_layout_inference_test.py
@@ -123,7 +123,8 @@ class LayoutInferenceTest(parameterized.TestCase):
     self.assertEmpty(c.attributes["in_layouts"])
     self.assertSequenceEqual(c.attributes["out_layouts"], [layout])
 
-  def test_infer_splat_layout_for_vector_splat(self):
+  @parameterized.parameters(True, False)
+  def test_infer_splat_layout_for_vector_splat(self, rhs_splat):
     add = splat = None
 
     def body(lhs, rhs):
@@ -135,19 +136,24 @@ class LayoutInferenceTest(parameterized.TestCase):
       shape = (16, 8)
       elt_type = ir.BF16Type.get()
       ty = ir.VectorType.get(shape, elt_type)
-      func.FuncOp.from_py_func(elt_type, ty)(body)
-
-    # Not setting any layouts on the module should default in all ops having a
-    # splat fragmented layout.
-    mgpu.infer_layout(self.module)
+      func_op = func.FuncOp.from_py_func(elt_type, ty)(body).func_op
 
     layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
+    if rhs_splat:
+      func_op.attributes["in_layouts"] = ir.ArrayAttr.get([layout])
+    mgpu.infer_layout(self.module)
 
     self.assertEmpty(splat.attributes["in_layouts"])
     self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])
 
-    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
-    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
+    add_layout = layout
+    if not rhs_splat:
+      add_layout = layouts.to_layout_attr(
+          mgpu.WGStridedFragLayout.from_shaped_type(ty)
+      )
+
+    self.assertSequenceEqual(add.attributes["in_layouts"], [add_layout, add_layout])
+    self.assertSequenceEqual(add.attributes["out_layouts"], [add_layout])
 
   @parameterized.parameters(
       mgpu.WGSplatFragLayout(shape=(32, 4)),
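Finally, the same in/out layout agreement check recurs verbatim in
_binary_op_lowering_rule, _cmpi_op_lowering_rule, and _cmpf_op_lowering_rule
above. It distills to the sketch below; the helper name _single_agreed_layout
is hypothetical (the patch deliberately repeats the check inline), and the
import assumes the module path shown in the diff headers.

    from jax.experimental.mosaic.gpu import layouts

    def _single_agreed_layout(op):
      # Each of these rules has exactly one result layout, and every vector
      # input must already share it before the operands are loaded with it.
      [layout] = layouts.out_layouts(op)
      if any(in_layout != layout for in_layout in layouts.in_layouts(op)):
        raise ValueError("Layout mismatch")
      return layout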