[Mosaic GPU] Add transform inference rule for mgpu.slice_smem.

PiperOrigin-RevId: 737957778
This commit is contained in:
Benjamin Chetioui 2025-03-18 04:53:00 -07:00 committed by jax authors
parent d4bd2570ae
commit ba2f7c9ad9
2 changed files with 100 additions and 1 deletion

View File

@ -43,7 +43,8 @@ _transform_inference_rules: dict[str, TransformInferenceRule] = {}
def _add_transform_inference_rule(
    op: type[ir.OpView], rule: TransformInferenceRule
) -> TransformInferenceRule:
  """Registers `rule` as the transform inference rule for `op`.

  The rule is stored in `_transform_inference_rules` under the operation's
  `OPERATION_NAME`.

  Args:
    op: The operation class to register the rule for. May be `None`, in which
      case nothing is registered; this lets callers pass the result of a
      `getattr(..., None)` lookup for an op class that may be missing.
    rule: The inference rule to register.

  Returns:
    `rule`, unchanged, so that this function can be used as a decorator (e.g.
    through `functools.partial`).
  """
  # The registration must stay behind the `None` check: reading
  # `OPERATION_NAME` off a missing op class would raise `AttributeError`.
  if op is not None:
    _transform_inference_rules[op.OPERATION_NAME] = rule  # pytype: disable=attribute-error
  return rule
@ -169,6 +170,32 @@ def _infer_vector_load_store_transforms(
return None
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
# `getattr` with a `None` default keeps this lookup from failing when `mgpu`
# does not expose `SliceSMEMOp` (presumably with older jaxlib builds); in that
# case `SliceSMEMOp` is `None`.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
@partial(_add_transform_inference_rule, SliceSMEMOp)
def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
  """Derives the transforms of a slice_smem op from the users of its result.

  Every use of `op.result` is queried through
  `inference_utils.in_transforms_for_operand`. Users that report `None` are
  ignored; all remaining users must agree on the same transforms.

  Args:
    op: The operation whose result uses are inspected.

  Returns:
    `None` if no user reports transforms for the sliced reference. Otherwise a
    pair made of an empty list and a single-element list holding the agreed
    transforms (presumably the op's in- and out-transforms, respectively).

  Raises:
    NotImplementedError: If two users report different transforms.
  """
  transforms = None
  for use in cast(ir.OpResult, op.result).uses:
    user_op = use.owner
    # The operand of the user through which it consumes `op.result`.
    op_user = user_op.operands[use.operand_number]
    out_transforms = inference_utils.in_transforms_for_operand(user_op, op_user)

    if out_transforms is None:
      # This user has nothing to contribute.
      continue

    if transforms is None:
      # First user that reports transforms: adopt them.
      transforms = out_transforms
      continue

    if transforms != out_transforms:
      raise NotImplementedError(
          f"Conflicting transforms for {op_user} in {op}: "
          f"{transforms} != {out_transforms}."
      )

  if transforms is None:
    return None
  return [], [transforms]
def _should_have_transforms(op: ir.OpView) -> bool:
"""Returns 'True' if the operation should be assigned in/out transforms."""
return any(

View File

@ -346,6 +346,78 @@ class TransformInferenceTest(parameterized.TestCase):
with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
mgpu.infer_transforms(self.module)
def test_infer_transforms_for_slice_smem_op_derives_from_user(self):
  # A slice_smem op with a single vector.load user is expected to get the
  # tile + swizzle transforms below as its out_transforms, and no
  # in_transforms, once inference has run.
  shape = (64, 64)
  elt_ty = ir.BF16Type.get()
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  # Ops created while tracing `body`, kept so they can be inspected afterwards.
  ops = {}

  def body(offset):
    ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
    ops["slice_smem"] = mgpu.dialect.SliceSMEMOp(ref_ty, offset)
    zero = arith.constant(ir.IntegerType.get_signless(32), 0)
    ops["load"] = vector.LoadOp(
        ir.VectorType.get(shape, elt_ty),
        ops["slice_smem"].result,
        [zero for _ in shape],
    )

  with ir.InsertionPoint(self.module.body):
    i32 = ir.IntegerType.get_signless(32)
    func.FuncOp.from_py_func(i32)(body)

  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
  ops["load"].attributes["out_layouts"] = ir.ArrayAttr.get([wgmma_layout])

  mgpu.infer_transforms(self.module)

  tile = mgpu.dialect.TileTransformAttr.get((8, 64))
  swizzle = mgpu.dialect.SwizzleTransformAttr.get(128)
  expected_transforms = ir.ArrayAttr.get([tile, swizzle])

  slice_smem_op = ops["slice_smem"]
  self.assertEmpty(inference_utils.in_transforms(slice_smem_op))
  self.assertSequenceEqual(
      inference_utils.out_transforms(slice_smem_op), [expected_transforms]
  )
def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self):
  # Two vector.load users of the same slice_smem result are annotated
  # differently (the second one carries an explicit transpose in its
  # in_transforms); inference is expected to reject the module.
  shape = (64, 64)
  elt_ty = ir.BF16Type.get()
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  # Ops created while tracing `body`, kept so they can be annotated afterwards.
  ops = {}

  def body(offset):
    ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
    ops["slice_smem"] = mgpu.dialect.SliceSMEMOp(ref_ty, offset)
    zero = arith.constant(ir.IntegerType.get_signless(32), 0)
    load_offsets = [zero for _ in shape]
    ops["load1"] = vector.LoadOp(
        ir.VectorType.get(shape, elt_ty),
        ops["slice_smem"].result,
        load_offsets,
    )
    ops["load2"] = vector.LoadOp(
        ir.VectorType.get(shape, elt_ty),
        ops["slice_smem"].result,
        load_offsets,
    )

  with ir.InsertionPoint(self.module.body):
    i32 = ir.IntegerType.get_signless(32)
    func.FuncOp.from_py_func(i32)(body)

  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
  ops["load1"].attributes["out_layouts"] = ir.ArrayAttr.get([wgmma_layout])

  strided_layout = layouts_lib.to_layout_attr(
      fa.WGStridedFragLayout(shape, vec_size=4)
  )
  ops["load2"].attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout])

  transpose = mgpu.dialect.TransposeTransformAttr.get((1, 0))
  ops["load2"].attributes["in_transforms"] = ir.ArrayAttr.get(
      [ir.ArrayAttr.get([transpose])]
  )

  with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
    mgpu.infer_transforms(self.module)
# When run as a script, execute the tests through absltest using JAX's test
# loader.
if __name__ == "__main__":
  parameterized.absltest.main(testLoader=jtu.JaxTestLoader())