[Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.

This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
Dimitar (Mitko) Asenov 2025-03-02 03:16:30 -08:00 committed by jax authors
parent c60ef5a2a1
commit 3b305c6617
4 changed files with 168 additions and 13 deletions
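As an editorial aside (not part of the commit), here is a minimal sketch of the kind of layout attribute the new pass attaches via an inserted `memref.cast`, built with the attribute constructors that appear in the diff below. The (64, 64) tile and 128-byte swizzle are illustrative assumptions for a non-transposed bf16 operand, and `example_wgmma_operand_layout` is a hypothetical helper name.

# Editorial sketch only: builds the layout attribute that the inference pass
# would attach (via memref.cast) to a memref.view feeding a WGMMA operand.
# Assumes an active ir.Context with the Mosaic GPU dialect registered; the
# (64, 64) tile and 128-byte swizzle correspond to a non-transposed bf16 `a`.
from jax._src.lib import mosaic_gpu_dialect as mgpu

def example_wgmma_operand_layout():
  transforms = [
      mgpu.TileTransformAttr.get((64, 64)),
      # For a transposed operand, a TransposeTransformAttr.get([1, 0, 2, 3])
      # would sit between the tile and swizzle transforms.
      mgpu.SwizzleTransformAttr.get(mgpu.SwizzlingMode.k128ByteSwizzle),
  ]
  return mgpu.LayoutAttr.get(2, transforms)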

View File

@ -546,6 +546,9 @@ def _mgpu_wgmma_op_lowering_rule(
b_layout = ir.MemRefType(wgmma_op.b.type).layout
b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
b_operand = transform_memref(wgmma_op.b, b_transforms)
if wgmma_op.transpose_b:
b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
if ir.VectorType.isinstance(wgmma_op.a.type):
a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
@ -558,13 +561,10 @@ def _mgpu_wgmma_op_lowering_rule(
f" {b_swizzle}"
)
a_operand = transform_memref(wgmma_op.a, a_transforms)
if wgmma_op.transpose_a:
a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
new_acc = wgmma.wgmma(
acc,
a_operand,
transform_memref(wgmma_op.b, b_transforms),
swizzle=b_swizzle,
)
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]

View File

@ -14,7 +14,8 @@
"""Layout inference pass for the MLIR Mosaic GPU dialect."""
from collections.abc import Callable
from collections.abc import Callable, Sequence
import dataclasses
import enum
from functools import partial
from typing import cast
@ -23,6 +24,7 @@ from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import vector
from . import fragmented_array as fa
@ -45,8 +47,8 @@ def _set_layout_attributes(
in_layouts: list[ir.Attribute],
out_layouts: list[ir.Attribute],
):
op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
def _choose_representative_layout(
@ -311,6 +313,114 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
return [layout], [layout]
@dataclasses.dataclass()
class WGMMATransforms:
swizzle: mgpu.SwizzlingMode
a_tile: tuple[int, ...]
a_transpose: bool
b_tile: tuple[int, ...]
b_transpose: bool
def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width
# Try tiling with all swizzling modes starting from the largest one.
for swizzle in [
mgpu.SwizzlingMode.k128ByteSwizzle,
mgpu.SwizzlingMode.k64ByteSwizzle,
mgpu.SwizzlingMode.k32ByteSwizzle,
]:
s = swizzle * 8 // bitwidth
if k % s == 0:
return WGMMATransforms(
swizzle=swizzle,
a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
a_transpose=wgmma_op.transpose_a,
b_tile=(s, s),
b_transpose=wgmma_op.transpose_b,
)
raise ValueError(
"Could not infer layouts for memref feeding into WGMMA. The "
"non-contracting dimension ({k}) must be a multiple of "
"s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
"`a` and `b`."
)
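For intuition, a worked instance of the loop above (editorial, not part of the commit): with bf16 operands (bitwidth 16), 128-byte swizzling gives s = 128 * 8 // 16 = 64 and is selected whenever the contracting dimension k is a multiple of 64; otherwise the loop falls back to 64-byte swizzling (s = 32) and then 32-byte swizzling (s = 16). The same arithmetic in plain Python:

# Editorial sketch of the swizzle-selection arithmetic above (no MLIR needed).
# Swizzle sizes are in bytes; `s` is the tile extent along the contracting dim.
def pick_swizzle_and_s(k: int, bitwidth: int) -> tuple[int, int]:
  for swizzle_bytes in (128, 64, 32):
    s = swizzle_bytes * 8 // bitwidth
    if k % s == 0:
      return swizzle_bytes, s
  raise ValueError(f"k={k} is not a multiple of any supported tile size")

assert pick_swizzle_and_s(k=256, bitwidth=16) == (128, 64)
assert pick_swizzle_and_s(k=96, bitwidth=16) == (64, 32)  # 96 % 64 != 0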
def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
wgmma_use = None
uses = cast(ir.OpResult, view_op.result).uses
for use in uses:
user = use.owner
if isinstance(user, memref.CastOp):
# This memref is already cast, so we don't need to do anything.
return None
if isinstance(user, mgpu.WGMMAOp):
if wgmma_use is not None:
raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
wgmma_use = use
break
if (
not isinstance(user, mgpu.AsyncLoadOp)
and not isinstance(user, mgpu.AsyncStoreOp)
and not isinstance(user, vector.LoadOp)
and not isinstance(user, vector.StoreOp)
):
raise NotImplementedError(f"Unsupported user {user} of {view_op}.")
if wgmma_use is None:
# This memref is not used by a WGMMA operation, so we don't need to do
# anything.
return None
transforms = infer_wgmma_transforms(wgmma_use.owner)
if wgmma_use.operand_number == 1:
tile = transforms.a_tile
transpose = transforms.a_transpose
else:
tile = transforms.b_tile
transpose = transforms.b_transpose
transpose_attr = (
[mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
)
layout = mgpu.LayoutAttr.get(
2,
[mgpu.TileTransformAttr.get(tile)]
+ transpose_attr
+ [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
)
return layout
def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
owners = [use.owner for use in uses]
for region in regions:
for block in region:
for op in block:
if op in owners:
return op
raise ValueError("None of uses are in the given block")
def _insert_memref_layout_cast(layout: ir.Attribute, view_op: memref.ViewOp):
mem_ref_type = ir.MemRefType(view_op.result.type)
memref_new_type = ir.MemRefType.get(
mem_ref_type.shape,
mem_ref_type.element_type,
layout,
mem_ref_type.memory_space,
)
uses = list(view_op.result.uses)
with ir.InsertionPoint(_earliest_use(view_op.parent.regions, uses)):
cast_op = memref.cast(memref_new_type, view_op.result)
for use in uses:
use.owner.operands[use.operand_number] = cast_op
class TraversalOrder(enum.Enum):
"""Traversal orders with respect to the data flow for IR."""
@ -402,3 +512,11 @@ def infer_layout(module: ir.Module):
for op in module.body:
traverse_op(op, set_default_layout)
def infer_memref_layouts_and_insert_casts(op: ir.OpView):
if op.name == "memref.view":
if layout := _layout_for_memref_view(op):
_insert_memref_layout_cast(layout, op)
for op in module.body:
traverse_op(op, infer_memref_layouts_and_insert_casts)

View File

@ -401,6 +401,8 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
Where:
- `s == swizzle/element_bytewidth` (for `kNoSwizzle`, `swizzle` is 16),
and the tilings are [64, s] for `a` and [s, s] for `b`.
- `a` and/or `b` may be transposed if the corresponding attribute is set
to `true`.
The output has an identical shape and type as the input accumulator.
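As an editorial illustration of the tiling described above (assuming bf16, i.e. 2-byte elements, and 128-byte swizzling, so s = 128 / 2 = 64), a logical (128, 256) `a` and a (256, 128) `b` become 4-D tiled memrefs with the following shapes:

# Editorial sketch: shapes of the tiled operands described above, assuming
# bf16 elements (2 bytes) and 128-byte swizzling, so s = 128 // 2 = 64.
def tiled_shape(shape, tile):
  # (rows, cols) tiled by (t0, t1) -> (rows // t0, cols // t1, t0, t1).
  return (shape[0] // tile[0], shape[1] // tile[1], *tile)

s = 64
assert tiled_shape((128, 256), (64, s)) == (2, 4, 64, 64)  # operand `a`
assert tiled_shape((256, 128), (s, s)) == (4, 2, 64, 64)   # operand `b`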
@ -423,7 +425,10 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
AnyTypeOf<[
MemRefOf<[MosaicGPU_WGMMASupportedType]>,
VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b
MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,
DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_b
);
let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>);

View File

@ -2701,6 +2701,8 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
shape_res: tuple[int, ...] = ()
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
transpose_a: bool = False
transpose_b: bool = False
result = []
for swizzle in [
@ -2714,6 +2716,17 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
groups_n = 1
groups_k = 1
result.extend([
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_n * k, groups_k * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_b=True,
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
@ -2722,6 +2735,18 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
),
])
# The below only works for 128-byte swizzling. Regardless of transposing,
# TMA needs the size of the last dimension to be compatible with the
# swizzle.
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
result.append(
TestCaseInput(
shape_a=[groups_k * k, groups_m * 64],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_a=True,
)
)
return result
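An editorial note on the 128-byte-swizzle restriction mentioned in the comment above: for a transposed bf16 `a`, the tile is (s, 64), so its minor-most dimension is fixed at 64 elements, i.e. 64 * 2 = 128 bytes. Assuming TMA requires that minor dimension (in bytes) to match the swizzle size, only the 128-byte mode qualifies:

# Editorial arithmetic for the 128-byte-swizzle restriction above (assumption:
# the minor-most tile dimension in bytes must match the swizzle size).
bf16_bytes = 2
minor_dim_elements = 64  # minor dim of the (s, 64) tile for transposed `a`
assert minor_dim_elements * bf16_bytes == 128  # only 128-byte swizzle matches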
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
@ -2781,7 +2806,13 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
accumulator = vector.splat(
ir.VectorType.get(shape_result, result_elt_type), zero_acc
)
result = mgpu_dialect.wgmma(accumulator, a_smem_ref, b_smem_ref)
result = mgpu_dialect.wgmma(
accumulator,
a_smem_ref,
b_smem_ref,
transpose_a=test_case.transpose_a,
transpose_b=test_case.transpose_b,
)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
@ -2822,11 +2853,12 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
transpose = lambda x, t: x.T if t else x
self.assertArraysAllClose(
jax.jit(kernel)(x, y),
np.matmul(
x.reshape(test_case.shape_a),
y.reshape(test_case.shape_b),
transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
),
atol=1e-5,
rtol=1e-5,