Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00.
[Mosaic GPU] Add a transform inference rule for mgpu.slice_smem.

PiperOrigin-RevId: 737957778
This commit is contained in:
parent
d4bd2570ae
commit
ba2f7c9ad9
@ -43,7 +43,8 @@ _transform_inference_rules: dict[str, TransformInferenceRule] = {}
|
||||
def _add_transform_inference_rule(
    op: type[ir.OpView], rule: TransformInferenceRule
):
  """Registers `rule` as the transform inference rule for `op`.

  Args:
    op: the op class to register the rule for. May be `None` (e.g. when the op
      is not available in the jaxlib version in use), in which case no rule is
      registered.
    rule: the inference rule to associate with `op`.

  Returns:
    `rule`, unchanged, so this function can be used as a decorator (typically
    via `functools.partial`).
  """
  # NOTE(review): the diffed source carried both the old unguarded registration
  # and the new `None`-guarded one; only the guarded form is kept, since an
  # unconditional access to `op.OPERATION_NAME` would raise for `op is None`.
  if op is not None:
    _transform_inference_rules[op.OPERATION_NAME] = rule  # pytype: disable=attribute-error
  return rule
|
||||
|
||||
@ -169,6 +170,32 @@ def _infer_vector_load_store_transforms(
|
||||
return None
|
||||
|
||||
|
||||
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
# Older jaxlib releases do not expose `SliceSMEMOp` on the dialect module, so
# fall back to `None`; the registration helper skips `None` ops, making the
# rule below a no-op on those versions.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
||||
@partial(_add_transform_inference_rule, SliceSMEMOp)
def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
  """Derives out transforms for `slice_smem` from the transforms of its users.

  Every user that carries `in_transforms` for the sliced memref must agree on
  them; conflicting requirements raise `NotImplementedError`.

  Returns:
    `None` when no user constrains the result, otherwise a pair of
    (in transforms, out transforms) where the in transforms are empty.
  """
  transforms = None
  result_uses = cast(ir.OpResult, op.result).uses

  for use in result_uses:
    consumer = use.owner
    op_user = consumer.operands[use.operand_number]
    out_transforms = inference_utils.in_transforms_for_operand(
        consumer, op_user
    )
    # Users without transform requirements don't constrain the result.
    if out_transforms is None:
      continue
    if transforms is None:
      transforms = out_transforms
    elif transforms != out_transforms:
      raise NotImplementedError(
          f"Conflicting transforms for {op_user} in {op}: "
          f"{transforms} != {out_transforms}."
      )

  return None if transforms is None else ([], [transforms])
||||
|
||||
|
||||
def _should_have_transforms(op: ir.OpView) -> bool:
|
||||
"""Returns 'True' if the operation should be assigned in/out transforms."""
|
||||
return any(
|
||||
|
@ -346,6 +346,78 @@ class TransformInferenceTest(parameterized.TestCase):
|
||||
with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
|
||||
mgpu.infer_transforms(self.module)
|
||||
|
||||
def test_infer_transforms_for_slice_smem_op_derives_from_user(self):
  """slice_smem's out transforms are inherited from its consumer's demands."""
  slice_smem_op = vector_load_op = None
  shape = (64, 64)
  element_type = ir.BF16Type.get()
  smem_space = ir.Attribute.parse("#gpu.address_space<workgroup>")

  def body(offset):
    nonlocal slice_smem_op, vector_load_op
    slice_smem_op = mgpu.dialect.SliceSMEMOp(
        ir.MemRefType.get(shape, element_type, memory_space=smem_space),
        offset,
    )
    i32 = ir.IntegerType.get_signless(32)
    zero_index = arith.constant(i32, 0)
    vector_load_op = vector.LoadOp(
        ir.VectorType.get(shape, element_type),
        slice_smem_op.result,
        [zero_index] * len(shape),
    )

  with ir.InsertionPoint(self.module.body):
    func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)

  # Pinning a WGMMA layout on the load forces transform inference to derive
  # tiling/swizzling transforms for the memref produced by slice_smem.
  vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
      [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
  )

  mgpu.infer_transforms(self.module)

  expected_transforms = ir.ArrayAttr.get([
      mgpu.dialect.TileTransformAttr.get((8, 64)),
      mgpu.dialect.SwizzleTransformAttr.get(128),
  ])

  self.assertEmpty(inference_utils.in_transforms(slice_smem_op))
  self.assertSequenceEqual(
      inference_utils.out_transforms(slice_smem_op), [expected_transforms]
  )
||||
def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self):
  """Conflicting transform demands from two users must be rejected."""
  slice_smem_op = vector_load_op1 = vector_load_op2 = None
  shape = (64, 64)
  element_type = ir.BF16Type.get()
  smem_space = ir.Attribute.parse("#gpu.address_space<workgroup>")

  def body(offset):
    nonlocal slice_smem_op, vector_load_op1, vector_load_op2
    slice_smem_op = mgpu.dialect.SliceSMEMOp(
        ir.MemRefType.get(shape, element_type, memory_space=smem_space),
        offset,
    )
    i32 = ir.IntegerType.get_signless(32)
    zero_index = arith.constant(i32, 0)
    load_offsets = [zero_index] * len(shape)
    load_type = ir.VectorType.get(shape, element_type)
    # Two loads of the same sliced memref, annotated below with incompatible
    # transform requirements.
    vector_load_op1 = vector.LoadOp(
        load_type, slice_smem_op.result, load_offsets
    )
    vector_load_op2 = vector.LoadOp(
        load_type, slice_smem_op.result, load_offsets
    )

  with ir.InsertionPoint(self.module.body):
    func.FuncOp.from_py_func(ir.IntegerType.get_signless(32))(body)

  vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get(
      [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
  )
  vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get(
      [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
  )
  vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get(
      [ir.ArrayAttr.get([mgpu.dialect.TransposeTransformAttr.get((1, 0))])]
  )

  with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
    mgpu.infer_transforms(self.module)
||||
if __name__ == "__main__":
  # Run through the JAX test loader so standard jtu test selection applies.
  parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user