From 3b305c6617edf6cf0dba3a1f9db6027b7dd96d61 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Sun, 2 Mar 2025 03:16:30 -0800
Subject: [PATCH] [Mosaic GPU] Infer layouts (transforms) on memrefs that
 directly feed into the dialect wgmma op.

This change detects a situation where a gmem_memref is read via `async_load`
and directly used in a wgmma. In such cases, we insert a cast before the load
to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
---
 .../mosaic/gpu/dialect_lowering.py      |  12 +-
 .../mosaic/gpu/layout_inference.py      | 124 +++++++++++++++++-
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.td |   7 +-
 tests/mosaic/gpu_test.py                |  38 +++++-
 4 files changed, 168 insertions(+), 13 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index b99c7d385..3ca7a8571 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -546,6 +546,9 @@ def _mgpu_wgmma_op_lowering_rule(
 
   b_layout = ir.MemRefType(wgmma_op.b.type).layout
   b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
+  b_operand = transform_memref(wgmma_op.b, b_transforms)
+  if wgmma_op.transpose_b:
+    b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
 
   if ir.VectorType.isinstance(wgmma_op.a.type):
     a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
@@ -558,13 +561,10 @@
           f" {b_swizzle}"
       )
     a_operand = transform_memref(wgmma_op.a, a_transforms)
+    if wgmma_op.transpose_a:
+      a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
 
-  new_acc = wgmma.wgmma(
-      acc,
-      a_operand,
-      transform_memref(wgmma_op.b, b_transforms),
-      swizzle=b_swizzle,
-  )
+  new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
 
   return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]
 
diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index d5afeb69a..5971cfb85 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -14,7 +14,8 @@
 
 """Layout inference pass for the MLIR Mosaic GPU dialect."""
 
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
+import dataclasses
 import enum
 from functools import partial
 from typing import cast
@@ -23,6 +24,7 @@
 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import scf
+from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import vector
 from . import fragmented_array as fa
@@ -45,8 +47,8 @@ def _set_layout_attributes(
     in_layouts: list[ir.Attribute],
     out_layouts: list[ir.Attribute],
 ):
-  op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
-  op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
+  op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
+  op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
 
 
 def _choose_representative_layout(
@@ -311,6 +313,114 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
   return [layout], [layout]
 
 
+@dataclasses.dataclass()
+class WGMMATransforms:
+  swizzle: mgpu.SwizzlingMode
+  a_tile: tuple[int, ...]
+  a_transpose: bool
+  b_tile: tuple[int, ...]
+  b_transpose: bool
+
+
+def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
+  a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
+  k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
+  bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width
+
+  # Try tiling with all swizzling modes starting from the largest one.
+  for swizzle in [
+      mgpu.SwizzlingMode.k128ByteSwizzle,
+      mgpu.SwizzlingMode.k64ByteSwizzle,
+      mgpu.SwizzlingMode.k32ByteSwizzle,
+  ]:
+    s = swizzle * 8 // bitwidth
+    if k % s == 0:
+      return WGMMATransforms(
+          swizzle=swizzle,
+          a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
+          a_transpose=wgmma_op.transpose_a,
+          b_tile=(s, s),
+          b_transpose=wgmma_op.transpose_b,
+      )
+  raise ValueError(
+      "Could not infer layouts for memref feeding into WGMMA. The "
+      f"contracting dimension ({k}) must be a multiple of "
+      "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
+      f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
+      "`a` and `b`."
+  )
+
+
+def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
+  wgmma_use = None
+  uses = cast(ir.OpResult, view_op.result).uses
+  for use in uses:
+    user = use.owner
+    if isinstance(user, memref.CastOp):
+      # This memref is already cast, so we don't need to do anything.
+      return None
+    if isinstance(user, mgpu.WGMMAOp):
+      if wgmma_use is not None:
+        raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
+      wgmma_use = use
+      break
+    if (
+        not isinstance(user, mgpu.AsyncLoadOp)
+        and not isinstance(user, mgpu.AsyncStoreOp)
+        and not isinstance(user, vector.LoadOp)
+        and not isinstance(user, vector.StoreOp)
+    ):
+      raise NotImplementedError(f"Unsupported user {user} of {view_op}.")
+
+  if wgmma_use is None:
+    # This memref is not used by a WGMMA operation, so we don't need to do
+    # anything.
+    return None
+
+  transforms = infer_wgmma_transforms(wgmma_use.owner)
+  if wgmma_use.operand_number == 1:
+    tile = transforms.a_tile
+    transpose = transforms.a_transpose
+  else:
+    tile = transforms.b_tile
+    transpose = transforms.b_transpose
+  transpose_attr = (
+      [mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
+  )
+
+  layout = mgpu.LayoutAttr.get(
+      2,
+      [mgpu.TileTransformAttr.get(tile)]
+      + transpose_attr
+      + [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
+  )
+
+  return layout
+
+
+def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
+  owners = [use.owner for use in uses]
+  for region in regions:
+    for block in region:
+      for op in block:
+        if op in owners:
+          return op
+  raise ValueError("None of the uses are in the given block.")
+
+
+def _insert_memref_layout_cast(layout: ir.Attribute, view_op: memref.ViewOp):
+  mem_ref_type = ir.MemRefType(view_op.result.type)
+  memref_new_type = ir.MemRefType.get(
+      mem_ref_type.shape,
+      mem_ref_type.element_type,
+      layout,
+      mem_ref_type.memory_space,
+  )
+  uses = list(view_op.result.uses)
+  with ir.InsertionPoint(_earliest_use(view_op.parent.regions, uses)):
+    cast_op = memref.cast(memref_new_type, view_op.result)
+  for use in uses:
+    use.owner.operands[use.operand_number] = cast_op
+
+
 class TraversalOrder(enum.Enum):
   """Traversal orders with respect to the data flow for IR."""
 
@@ -402,3 +512,11 @@
 
   for op in module.body:
     traverse_op(op, set_default_layout)
+
+  def infer_memref_layouts_and_insert_casts(op: ir.OpView):
+    if op.name == "memref.view":
+      if layout := _layout_for_memref_view(op):
+        _insert_memref_layout_cast(layout, op)
+
+  for op in module.body:
+    traverse_op(op, infer_memref_layouts_and_insert_casts)
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 48c6a0464..964039d82 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -401,6 +401,8 @@ def MosaicGPU_WGMMAOp : Op {
     Where:
     - `s == swizzle/element_bytediwth` (for `kNoSwizzle`, `swizzle` is 16.)
       and the tilings are [64, s] for `a` and [s, s] for `b`.
+    - `a` and/or `b` may be transposed if the corresponding attribute is set
+      to `true`.
 
     The output has an identical shape and type as the input accumulator.
 
@@ -423,7 +425,10 @@ def MosaicGPU_WGMMAOp : Op {
     AnyTypeOf<[
       MemRefOf<[MosaicGPU_WGMMASupportedType]>,
       VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
-    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b
+    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,
+
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_a,
+    DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_b
   );
 
   let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>);
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index dd6161bd4..a7148aa57 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -2701,6 +2701,8 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
       shape_res: tuple[int, ...] = ()
       transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
       transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
+      transpose_a: bool = False
+      transpose_b: bool = False
 
     result = []
     for swizzle in [
@@ -2714,6 +2716,17 @@
       groups_n = 1
       groups_k = 1
       result.extend([
+          TestCaseInput(
+              shape_a=[groups_m * 64, groups_k * k],
+              shape_b=[groups_k * k, groups_n * k],
+              shape_res=[groups_m * 64, groups_n * k],
+          ),
+          TestCaseInput(
+              shape_a=[groups_m * 64, groups_k * k],
+              shape_b=[groups_n * k, groups_k * k],
+              shape_res=[groups_m * 64, groups_n * k],
+              transpose_b=True,
+          ),
           TestCaseInput(
               shape_a=[groups_m * 64, groups_k * k],
               shape_b=[groups_k * k, groups_n * k],
               shape_res=[groups_m * 64, groups_n * k],
               transforms_a=[Tile([64, k]), Swizzle(swizzle)],
               transforms_b=[Tile([k, k]), Swizzle(swizzle)],
           ),
       ])
+      # The below only works for 128-byte swizzling. Regardless of transposing,
+      # TMA needs the size of the last dimension to be compatible with the
+      # swizzle.
+      if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
+        result.append(
+            TestCaseInput(
+                shape_a=[groups_k * k, groups_m * 64],
+                shape_b=[groups_k * k, groups_n * k],
+                shape_res=[groups_m * 64, groups_n * k],
+                transpose_a=True,
+            )
+        )
     return result
 
   @parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
@@ -2781,7 +2806,13 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
      accumulator = vector.splat(
          ir.VectorType.get(shape_result, result_elt_type), zero_acc
      )
-      result = mgpu_dialect.wgmma(accumulator, a_smem_ref, b_smem_ref)
+      result = mgpu_dialect.wgmma(
+          accumulator,
+          a_smem_ref,
+          b_smem_ref,
+          transpose_a=test_case.transpose_a,
+          transpose_b=test_case.transpose_b,
+      )
 
      nvvm.wgmma_commit_group_sync_aligned()
      nvvm.wgmma_wait_group_sync_aligned(0)
@@ -2822,11 +2853,12 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
 
    x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
    y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
+    transpose = lambda x, t: x.T if t else x
    self.assertArraysAllClose(
        jax.jit(kernel)(x, y),
        np.matmul(
-            x.reshape(test_case.shape_a),
-            y.reshape(test_case.shape_b),
+            transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
+            transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
        ),
        atol=1e-5,
        rtol=1e-5,
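
The tile and swizzle selection in `infer_wgmma_transforms` above boils down to simple integer arithmetic. The following standalone Python sketch restates it outside of MLIR; `pick_swizzle_and_tile` is an illustrative name, and swizzles are written as plain byte counts rather than `mgpu.SwizzlingMode` values.

# Standalone sketch of the tiling arithmetic in infer_wgmma_transforms.
# Assumptions: swizzles are expressed in bytes (128/64/32) and `bitwidth`
# is the element size in bits of `a` and `b`.
def pick_swizzle_and_tile(k: int, bitwidth: int) -> tuple[int, int]:
  """Returns (swizzle_bytes, s) for the largest swizzle whose tile divides k.

  s = swizzle_bytes * 8 // bitwidth is the tile size along the contracting
  dimension; `a` is tiled (64, s) (or (s, 64) when transposed), `b` is (s, s).
  """
  for swizzle_bytes in (128, 64, 32):  # try the largest swizzle first
    s = swizzle_bytes * 8 // bitwidth
    if k % s == 0:
      return swizzle_bytes, s
  raise ValueError(f"k={k} does not divide evenly for any supported swizzle")

# bf16 (16-bit) operands with k == 64: s = 128 * 8 // 16 = 64 divides k,
# so 128-byte swizzling is selected.
assert pick_swizzle_and_tile(64, 16) == (128, 64)
# k == 48 only divides evenly for 32-byte swizzling (s = 16).
assert pick_swizzle_and_tile(48, 16) == (32, 16)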
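
The new `transpose_a`/`transpose_b` attributes are checked in the test against a plain transposed matmul. A minimal NumPy restatement of that reference check, with illustrative shapes and float32 instead of the test's bfloat16:

import numpy as np

# Reference semantics of transpose_a / transpose_b: the operand is stored
# transposed in memory and logically transposed back before the matmul.
m, k, n = 64, 128, 128
a = np.random.uniform(-1, 1, (k, m)).astype(np.float32)  # stored as (k, m)
b = np.random.uniform(-1, 1, (k, n)).astype(np.float32)  # stored as (k, n)

transpose = lambda x, t: x.T if t else x
expected = np.matmul(transpose(a, True), transpose(b, False))  # (m, n) result
assert expected.shape == (m, n)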