From ba2f7c9ad96c77a88c8cc7eb2d0fd859f517a43a Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Mar 2025 04:53:00 -0700 Subject: [PATCH] [Mosaic GPU] Add transform inference rule for `mgpu.slice_smem`. PiperOrigin-RevId: 737957778 --- .../mosaic/gpu/transform_inference.py | 29 +++++++- tests/mosaic/gpu_transform_inference_test.py | 72 +++++++++++++++++++ 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index a3919ea1d..be3f2c381 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -43,7 +43,8 @@ _transform_inference_rules: dict[str, TransformInferenceRule] = {} def _add_transform_inference_rule( op: type[ir.OpView], rule: TransformInferenceRule ): - _transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + if op is not None: + _transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error return rule @@ -169,6 +170,32 @@ def _infer_vector_load_store_transforms( return None +# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. +SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) + +@partial(_add_transform_inference_rule, SliceSMEMOp) +def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: + transforms = None + uses = cast(ir.OpResult, op.result).uses + + for op_operand_use in uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + out_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + if transforms is not None and out_transforms is not None: + if transforms != out_transforms: + raise NotImplementedError( + f"Conflicting transforms for {op_user} in {op}: " + f"{transforms} != {out_transforms}." + ) + elif out_transforms is not None: + transforms = out_transforms + + return None if transforms is None else ([], [transforms]) + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index 2618c22ac..b7cd146df 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -346,6 +346,78 @@ class TransformInferenceTest(parameterized.TestCase): with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): mgpu.infer_transforms(self.module) + def test_infer_transforms_for_slice_smem_op_derives_from_user(self): + slice_smem_op = vector_load_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + def body(offset): + nonlocal slice_smem_op, vector_load_op + slice_smem_op = mgpu.dialect.SliceSMEMOp( + ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset + ) + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + load_offsets = [zero] * len(shape) + vector_load_op = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets + ) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body) + + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + + mgpu.infer_transforms(self.module) + + expected_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + + self.assertEmpty(inference_utils.in_transforms(slice_smem_op)) + self.assertSequenceEqual( + inference_utils.out_transforms(slice_smem_op), [expected_transforms] + ) + + def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self): + slice_smem_op = vector_load_op1 = vector_load_op2 = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + def body(offset): + nonlocal slice_smem_op, vector_load_op1, vector_load_op2 + slice_smem_op = mgpu.dialect.SliceSMEMOp( + ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset + ) + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + load_offsets = [zero] * len(shape) + vector_load_op1 = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets + ) + vector_load_op2 = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), slice_smem_op.result, load_offsets + ) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body) + + vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] + ) + vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get( + [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] + ) + vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get( + [ir.ArrayAttr.get([mgpu.dialect.TransposeTransformAttr.get((1, 0))])] + ) + + with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): + mgpu.infer_transforms(self.module) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader())