Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
[Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.

This change detects the situation where a GMEM memref is read via `async_load` and used directly in a `wgmma` op. In such cases, we insert a cast before the load that adds the tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
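Schematically, the inferred layout is a composition of tile, transpose, and swizzle transform attributes attached to the memref type, and every user of the view is rerouted through the inserted cast. Below is a condensed sketch using the attribute constructors that appear in the diff; the rank, tile sizes, and swizzle mode are illustrative values for a non-transposed bf16 `b` operand, not output of the actual pass:

    # Tile first, then swizzle; a transpose transform would sit in between.
    layout = mgpu.LayoutAttr.get(
        2,  # rank of the logical memref
        [
            mgpu.TileTransformAttr.get((64, 64)),  # b_tile = (s, s), s = 64
            mgpu.SwizzleTransformAttr.get(mgpu.SwizzlingMode.k128ByteSwizzle),
        ],
    )
    # The view's result is re-typed with this layout and its uses are
    # redirected through a memref.cast (see _insert_memref_layout_cast below).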
parent c60ef5a2a1
commit 3b305c6617
@@ -546,6 +546,9 @@ def _mgpu_wgmma_op_lowering_rule(
   b_layout = ir.MemRefType(wgmma_op.b.type).layout
   b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
+  b_operand = transform_memref(wgmma_op.b, b_transforms)
+  if wgmma_op.transpose_b:
+    b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
 
   if ir.VectorType.isinstance(wgmma_op.a.type):
     a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
@@ -558,13 +561,10 @@ def _mgpu_wgmma_op_lowering_rule(
         f" {b_swizzle}"
     )
     a_operand = transform_memref(wgmma_op.a, a_transforms)
+    if wgmma_op.transpose_a:
+      a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
 
-  new_acc = wgmma.wgmma(
-      acc,
-      a_operand,
-      transform_memref(wgmma_op.b, b_transforms),
-      swizzle=b_swizzle,
-  )
+  new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
 
   return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]
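Two different permutations are in play here. The lowering above transposes within each tile with (0, 1, 3, 2), while the layout attribute added by the inference pass uses a (1, 0, 2, 3) transpose transform that permutes the tile grid (see below). A minimal NumPy stand-in for the 4D tiled shape, with illustrative tile sizes:

    import numpy as np

    # A (128, 128) operand tiled into (64, 64) tiles has the 4D shape
    # (row_tiles, col_tiles, tile_rows, tile_cols) = (2, 2, 64, 64).
    x = np.zeros((2, 2, 64, 64))

    within_tiles = x.transpose(0, 1, 3, 2)  # swaps dims inside each tile
    tile_grid = x.transpose(1, 0, 2, 3)     # swaps the grid of tiles itself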
@@ -14,7 +14,8 @@
 """Layout inference pass for the MLIR Mosaic GPU dialect."""
 
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+import dataclasses
 import enum
 from functools import partial
 from typing import cast
@@ -23,6 +24,7 @@ from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import scf
+from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import vector
 
 from . import fragmented_array as fa
@@ -45,8 +47,8 @@ def _set_layout_attributes(
     in_layouts: list[ir.Attribute],
     out_layouts: list[ir.Attribute],
 ):
-  op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
-  op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
+  op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
+  op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
 
 
 def _choose_representative_layout(
@@ -311,6 +313,114 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
 
   return [layout], [layout]
+
+
+@dataclasses.dataclass()
+class WGMMATransforms:
+  swizzle: mgpu.SwizzlingMode
+  a_tile: tuple[int, ...]
+  a_transpose: bool
+  b_tile: tuple[int, ...]
+  b_transpose: bool
+
+
+def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
+  a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
+  k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
+  bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width
+
+  # Try tiling with all swizzling modes starting from the largest one.
+  for swizzle in [
+      mgpu.SwizzlingMode.k128ByteSwizzle,
+      mgpu.SwizzlingMode.k64ByteSwizzle,
+      mgpu.SwizzlingMode.k32ByteSwizzle,
+  ]:
+    s = swizzle * 8 // bitwidth
+    if k % s == 0:
+      return WGMMATransforms(
+          swizzle=swizzle,
+          a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
+          a_transpose=wgmma_op.transpose_a,
+          b_tile=(s, s),
+          b_transpose=wgmma_op.transpose_b,
+      )
+  raise ValueError(
+      "Could not infer layouts for memref feeding into WGMMA. The "
+      f"contracting dimension ({k}) must be a multiple of "
+      "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
+      f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
+      "`a` and `b`."
+  )
+
+
+def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
+  wgmma_use = None
+  uses = cast(ir.OpResult, view_op.result).uses
+  for use in uses:
+    user = use.owner
+    if isinstance(user, memref.CastOp):
+      # This memref is already cast, so we don't need to do anything.
+      return None
+    if isinstance(user, mgpu.WGMMAOp):
+      if wgmma_use is not None:
+        raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
+      wgmma_use = use
+      break
+    if (
+        not isinstance(user, mgpu.AsyncLoadOp)
+        and not isinstance(user, mgpu.AsyncStoreOp)
+        and not isinstance(user, vector.LoadOp)
+        and not isinstance(user, vector.StoreOp)
+    ):
+      raise NotImplementedError(f"Unsupported user {user} of {view_op}.")
+
+  if wgmma_use is None:
+    # This memref is not used by a WGMMA operation, so we don't need to do
+    # anything.
+    return None
+
+  transforms = infer_wgmma_transforms(wgmma_use.owner)
+  if wgmma_use.operand_number == 1:
+    tile = transforms.a_tile
+    transpose = transforms.a_transpose
+  else:
+    tile = transforms.b_tile
+    transpose = transforms.b_transpose
+  transpose_attr = (
+      [mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
+  )
+
+  layout = mgpu.LayoutAttr.get(
+      2,
+      [mgpu.TileTransformAttr.get(tile)]
+      + transpose_attr
+      + [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
+  )
+
+  return layout
+
+
+def _earliest_use(
+    regions: list[ir.Region], uses: Sequence[ir.OpOperand]
+) -> ir.OpView:
+  owners = [use.owner for use in uses]
+  for region in regions:
+    for block in region:
+      for op in block:
+        if op in owners:
+          return op
+  raise ValueError("None of the uses are in the given regions.")
+
+
+def _insert_memref_layout_cast(layout: ir.Attribute, view_op: memref.ViewOp):
+  mem_ref_type = ir.MemRefType(view_op.result.type)
+  memref_new_type = ir.MemRefType.get(
+      mem_ref_type.shape,
+      mem_ref_type.element_type,
+      layout,
+      mem_ref_type.memory_space,
+  )
+  uses = list(view_op.result.uses)
+  with ir.InsertionPoint(_earliest_use(view_op.parent.regions, uses)):
+    cast_op = memref.cast(memref_new_type, view_op.result)
+  for use in uses:
+    use.owner.operands[use.operand_number] = cast_op
 
 
 class TraversalOrder(enum.Enum):
   """Traversal orders with respect to the data flow for IR."""
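To make the swizzle selection in infer_wgmma_transforms concrete: for 16-bit elements, one swizzle atom spans swizzle * 8 / 16 elements, and the loop takes the largest mode whose span divides the contracting dimension. A standalone re-derivation of that arithmetic with a hypothetical helper (not part of the pass):

    def pick_swizzle_bytes(k: int, bitwidth: int) -> tuple[int, int]:
      """Returns (swizzle_bytes, s) for the largest swizzle whose span divides k."""
      for swizzle_bytes in (128, 64, 32):   # largest first, as in the pass
        s = swizzle_bytes * 8 // bitwidth   # elements spanned by one swizzle atom
        if k % s == 0:
          return swizzle_bytes, s
      raise ValueError(f"k={k} is not a multiple of any valid s")

    # bf16, k = 96: the 128-byte mode gives s = 64 and 96 % 64 != 0, so the
    # pass falls back to the 64-byte mode with s = 32, and 96 % 32 == 0.
    assert pick_swizzle_bytes(96, 16) == (64, 32)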
@@ -402,3 +512,11 @@ def infer_layout(module: ir.Module):
 
   for op in module.body:
     traverse_op(op, set_default_layout)
+
+  def infer_memref_layouts_and_insert_casts(op: ir.OpView):
+    if op.name == "memref.view":
+      if layout := _layout_for_memref_view(op):
+        _insert_memref_layout_cast(layout, op)
+
+  for op in module.body:
+    traverse_op(op, infer_memref_layouts_and_insert_casts)
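`traverse_op` is defined earlier in this file and is not part of the diff; below is a minimal sketch of the kind of recursive walk it performs, ignoring the configurable TraversalOrder and assuming nested ops are visited before their parent:

    def traverse_op_sketch(op: ir.OpView, callback) -> None:
      # Recurse into every op nested in the regions of `op`, then visit `op`.
      for region in op.operation.regions:
        for block in region:
          for nested_op in block:
            traverse_op_sketch(nested_op, callback)
      callback(op)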
@@ -401,6 +401,8 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
     Where:
     - `s == swizzle/element_bytewidth` (for `kNoSwizzle`, `swizzle` is 16)
       and the tilings are [64, s] for `a` and [s, s] for `b`.
+    - `a` and/or `b` may be transposed if the corresponding attribute is set
+      to `true`.
 
     The output has an identical shape and type to the input accumulator.
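A worked example of the formula above, wrapped in a hypothetical helper for clarity:

    def tile_minor_dim(swizzle_bytes: int, element_bytewidth: int) -> int:
      return swizzle_bytes // element_bytewidth  # s = swizzle / element_bytewidth

    assert tile_minor_dim(128, 2) == 64  # bf16 with 128-byte swizzle
    assert tile_minor_dim(64, 2) == 32   # bf16 with 64-byte swizzle
    assert tile_minor_dim(16, 2) == 8    # kNoSwizzle counts as swizzle == 16
    # `a` is then tiled [64, s] and `b` is tiled [s, s].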
@@ -423,7 +425,10 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
     AnyTypeOf<[
       MemRefOf<[MosaicGPU_WGMMASupportedType]>,
       VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
-    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b
+    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,
+
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_a,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_b
   );
   let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>);
@@ -2701,6 +2701,8 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
     shape_res: tuple[int, ...] = ()
     transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
     transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
+    transpose_a: bool = False
+    transpose_b: bool = False
 
   result = []
   for swizzle in [
@@ -2714,6 +2716,17 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
     groups_n = 1
     groups_k = 1
     result.extend([
         TestCaseInput(
             shape_a=[groups_m * 64, groups_k * k],
             shape_b=[groups_k * k, groups_n * k],
             shape_res=[groups_m * 64, groups_n * k],
         ),
+        TestCaseInput(
+            shape_a=[groups_m * 64, groups_k * k],
+            shape_b=[groups_n * k, groups_k * k],
+            shape_res=[groups_m * 64, groups_n * k],
+            transpose_b=True,
+        ),
         TestCaseInput(
             shape_a=[groups_m * 64, groups_k * k],
             shape_b=[groups_k * k, groups_n * k],
@@ -2722,6 +2735,18 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
             transforms_b=[Tile([k, k]), Swizzle(swizzle)],
         ),
     ])
+    # The below only works for 128-byte swizzling. Regardless of transposing,
+    # TMA needs the size of the last dimension to be compatible with the
+    # swizzle.
+    if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
+      result.append(
+          TestCaseInput(
+              shape_a=[groups_k * k, groups_m * 64],
+              shape_b=[groups_k * k, groups_n * k],
+              shape_res=[groups_m * 64, groups_n * k],
+              transpose_a=True,
+          )
+      )
   return result
 
   @parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
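The arithmetic behind that comment, spelled out for bf16 and assuming TMA's constraint that a tile's minor dimension in bytes may not exceed the swizzle width (an illustrative check, not test code):

    element_bytes = 2                    # bf16
    minor_dim_elements = 64              # transposed `a` tiles are (s, 64)
    minor_dim_bytes = minor_dim_elements * element_bytes  # 128 bytes

    assert minor_dim_bytes <= 128  # fits the 128-byte swizzle mode
    assert minor_dim_bytes > 64    # too wide for the 64- and 32-byte modes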
@@ -2781,7 +2806,13 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
       accumulator = vector.splat(
           ir.VectorType.get(shape_result, result_elt_type), zero_acc
       )
-      result = mgpu_dialect.wgmma(accumulator, a_smem_ref, b_smem_ref)
+      result = mgpu_dialect.wgmma(
+          accumulator,
+          a_smem_ref,
+          b_smem_ref,
+          transpose_a=test_case.transpose_a,
+          transpose_b=test_case.transpose_b,
+      )
 
       nvvm.wgmma_commit_group_sync_aligned()
       nvvm.wgmma_wait_group_sync_aligned(0)
@@ -2822,11 +2853,12 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
     x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
     y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
 
+    transpose = lambda x, t: x.T if t else x
     self.assertArraysAllClose(
         jax.jit(kernel)(x, y),
         np.matmul(
-            x.reshape(test_case.shape_a),
-            y.reshape(test_case.shape_b),
+            transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
+            transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
         ),
         atol=1e-5,
         rtol=1e-5,