From 875099b25dd6dea5353cf6a38b8ccd0ebaf0a0ed Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 18 Mar 2025 11:50:58 -0700 Subject: [PATCH] [Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering. A couple of dummy transform inference rules needed to be added in order to contend with parts of the lowering that do not use the dialect yet, along with a transform inference rule for `memref.view`. PiperOrigin-RevId: 738089782 --- jax/experimental/mosaic/gpu/core.py | 22 +- .../mosaic/gpu/dialect_lowering.py | 194 ++++++++--- .../mosaic/gpu/layout_inference.py | 91 ------ .../mosaic/gpu/transform_inference.py | 59 +++- tests/mosaic/gpu_test.py | 306 +++++++----------- 5 files changed, 324 insertions(+), 348 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 66e19bb5f..b255893e2 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm import numpy as np -if dialect is not None: - from . import dialect_lowering - from . import layout_inference -else: - dialect_lowering = None - layout_inference = None - -from . import profiler -from . import utils -from . import launch_context -from . import tcgen05 - # mypy: ignore-errors +from . import dialect_lowering +from . import launch_context +from . import layout_inference +from . import profiler +from . import tcgen05 +from . import transform_inference +from . import utils + # MLIR can't find libdevice unless we point it to the CUDA path # TODO(apaszke): Unify with jax._src.lib.cuda_path CUDA_ROOT = "/usr/local/cuda" @@ -584,6 +580,7 @@ def as_gpu_kernel( # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error + transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error _initialize_scratch(launch_ctx, scratch_arr) @@ -666,6 +663,7 @@ def as_torch_gpu_kernel( # Run Python lowering passes. 
The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error + transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error _initialize_scratch(launch_ctx, scratch_arr) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 8098d14f0..fedde5a00 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -17,6 +17,7 @@ from collections.abc import Callable import dataclasses import functools +import itertools import operator from typing import Any, Sequence, Type, cast @@ -58,7 +59,7 @@ class LoweringContext: if not _should_lower(op): return - if (name := op.OPERATION_NAME) not in _lowerings: + if (name := op.OPERATION_NAME) not in _lowerings: # pytype: disable=attribute-error raise NotImplementedError(f"Missing lowering rule for {op}") lowering_rule = _lowerings[name] @@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule( ] +def _check_transforms_and_swizzle_are_supported( + ref_ty: ir.MemRefType, + transforms: Sequence[launch_context.MemRefTransform], + swizzle: mgpu.SwizzlingMode, + minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle, +): + """Checks that the list of provided transforms and swizzle are supported. + + Currently, we allow the following: + - any swizzle that is larger than or equal to `minimum_swizzle`; + - optionally, a single tile transform (with rank equal to the rank of the + memref being annotated); + - optionally, a single transpose transform. + """ + if swizzle < minimum_swizzle: + raise NotImplementedError( + f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}." + ) + + partitioned_transforms = { + k: list(v) + for k, v in itertools.groupby( + transforms, lambda t: isinstance(t, launch_context.TileTransform) + ) + } + + tile_transforms = partitioned_transforms.get(True, []) + other_transforms = partitioned_transforms.get(False, []) + + if len(tile_transforms) > 1: + raise NotImplementedError( + f"{tile_transforms} contains more than one tile transform." + ) + + if len(tile_transforms) == 1: + if len(tile_transforms[0].tiling) != len(ref_ty.shape): + raise NotImplementedError( + f"Only tile transforms with rank equal to the rank of the memref " + f"being annotated are supported but got {tile_transforms[0]} for " + f"{ref_ty}." + ) + + if len(other_transforms) > 1: + raise NotImplementedError( + f"{other_transforms} contains more than one transform." + ) + + if len(other_transforms) == 1: + if not isinstance(other_transforms[0], launch_context.TransposeTransform): + raise NotImplementedError( + f"{other_transforms[0]} is not a transpose transform." 
+ ) + + @_register_lowering(vector.LoadOp) def _vector_load_op_lowering_rule( _: LoweringContext, vector_load_op: vector.LoadOp @@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule( vec_size=strided_layout.vec_size, ) elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT: - layout = ir.MemRefType(vector_load_op.base.type).layout - swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + inference_utils.in_transforms(vector_load_op)[0] + ) + ref_ty = ir.MemRefType(vector_load_op.base.type) + _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) transformed_ref = transform_memref(vector_load_op.base, transforms) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, @@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule( vector_store_op.valueToStore, to_store_layout ) - # TODO(dasenov): This is not efficient for WGMMA layouts - fragmented_array.store_untiled(vector_store_op.base) + if fragmented_array.layout == fa.WGMMA_LAYOUT: + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + inference_utils.in_transforms(vector_store_op)[0] + ) + ref_ty = ir.MemRefType(vector_store_op.base.type) + _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) + fragmented_array.store_tiled( + transform_memref(vector_store_op.base, transforms), swizzle + ) + elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or + isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): + fragmented_array.store_untiled(vector_store_op.base) + else: + raise ValueError( + f"{vector_store_op} has an unsupported layout: {to_store_layout}" + ) return [] @@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule( return [_fragmented_array_to_ir(result, op.result.type)] -def memref_layout_to_swizzle_and_transforms( - layout: ir.Attribute, +def swizzle_and_transforms_from_transforms_attr( + transforms: ir.ArrayAttr, ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]: - """Returns the swizzle and transforms that are encoded in the given layout. + """Returns the swizzle and MemrefTransforms for the given transforms. - If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the - transforms are empty. Otherwise, the layout may have at most one swizzle - transform and any combination of tiling and transpose transforms. + Args: + transforms: a list of transform attributes. + + Returns: + A tuple containing the swizzle mode and MemRefTransforms corresponding to + the parameter transforms. If `transforms` is empty, or does not contain + any swizzling transform, the swizzle mode is assumed to be kNoSwizzle. + Raises: + ValueError: if a swizzling transform is followed by any transform. """ swizzle = None gmem_transforms: list[launch_context.MemRefTransform] = [] - if mgpu.LayoutAttr.isinstance(layout): - transforms_attr = mgpu.LayoutAttr(layout).transforms - for transform in transforms_attr: - if swizzle is not None: - raise ValueError(f"{layout} contains more transforms after the initial swizzle.") - if mgpu.SwizzleTransformAttr.isinstance(transform): - # TODO(dasenov): Swizzling can change if the ref is sliced in certain - # ways. We might want to enforce some restrictions here. 
- swizzle = mgpu.SwizzleTransformAttr(transform).swizzle - elif mgpu.TileTransformAttr.isinstance(transform): - tiling = mgpu.TileTransformAttr(transform).tiling - tiling_transform = launch_context.TileTransform(tuple(tiling)) - gmem_transforms.append(tiling_transform) - elif mgpu.TransposeTransformAttr.isinstance(transform): - permutation = mgpu.TransposeTransformAttr(transform).permutation - transpose_transform = launch_context.TransposeTransform( - tuple(permutation) - ) - gmem_transforms.append(transpose_transform) - else: - raise ValueError(f"{layout} has an unsupported transform: {transform}") + for transform in transforms: + if swizzle is not None: + raise ValueError(f"{transforms} contain more transforms after swizzle.") + if mgpu.SwizzleTransformAttr.isinstance(transform): + # TODO(dasenov): Swizzling can change if the ref is sliced in certain + # ways. We might want to enforce some restrictions here. + swizzle = mgpu.SwizzleTransformAttr(transform).swizzle + elif mgpu.TileTransformAttr.isinstance(transform): + tiling = mgpu.TileTransformAttr(transform).tiling + tiling_transform = launch_context.TileTransform(tuple(tiling)) + gmem_transforms.append(tiling_transform) + elif mgpu.TransposeTransformAttr.isinstance(transform): + permutation = mgpu.TransposeTransformAttr(transform).permutation + transpose_transform = launch_context.TransposeTransform( + tuple(permutation) + ) + gmem_transforms.append(transpose_transform) + else: + raise ValueError("Unknown transform: {transform}") return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) @@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule( assert ctx.launch_context is not None barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier) - dst_layout = ir.MemRefType(load_op.destination.type).layout - swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout) + if inference_utils.has_in_transforms_set(load_op): + [transforms] = inference_utils.in_transforms(load_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms + ) + else: + swizzle = mgpu.SwizzlingMode.kNoSwizzle + transforms = () gmem_slice = [] for idx_i32, size in zip(load_op.indices, load_op.slice_lengths): @@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule( ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - src_layout = ir.MemRefType(store_op.source.type).layout - swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout) + if inference_utils.has_in_transforms_set(store_op): + [transforms] = inference_utils.in_transforms(store_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms + ) + else: + swizzle = mgpu.SwizzlingMode.kNoSwizzle + transforms = () gmem_slice = [] for idx_i32, size in zip(store_op.indices, store_op.slice_lengths): @@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: + if wgmma_op.transpose_a or wgmma_op.transpose_b: + raise ValueError("Transpose arguments are to be deleted.") + fa_layouts = ( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), @@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule( regs = acc_in.to_layout(fa.WGMMA_LAYOUT) acc = wgmma.WGMMAAccumulator.from_registers(regs) - b_layout = ir.MemRefType(wgmma_op.b.type).layout - b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout) + if ir.VectorType.isinstance(wgmma_op.a.type): + a_transforms = None + 
b_transforms = inference_utils.in_transforms(wgmma_op)[0] + else: + a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op) + + b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr( + b_transforms + ) + minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle + ref_ty = ir.MemRefType(wgmma_op.b.type) + _check_transforms_and_swizzle_are_supported( + ref_ty, b_transforms, b_swizzle, minimum_swizzle + ) b_operand = transform_memref(wgmma_op.b, b_transforms) - if wgmma_op.transpose_b: - b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2)) if ir.VectorType.isinstance(wgmma_op.a.type): a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) else: - a_layout = ir.MemRefType(wgmma_op.a.type).layout - a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout) + a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr( + a_transforms + ) + ref_ty = ir.MemRefType(wgmma_op.a.type) + _check_transforms_and_swizzle_are_supported( + ref_ty, a_transforms, a_swizzle, minimum_swizzle + ) if a_swizzle != b_swizzle: raise ValueError( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) a_operand = transform_memref(wgmma_op.a, a_transforms) - if wgmma_op.transpose_a: - a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2)) new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) @@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]: def _should_lower(op: ir.OpView) -> bool: """Returns 'true' if the operation should be lowered.""" return ( - op.OPERATION_NAME.startswith("mosaic_gpu.") + op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error or inference_utils.should_have_layout(op) or any(bool(b) for r in op.regions for b in r) # Does it have subblocks? ) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index c9479e0f1..0d2811bb5 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: return [layout], [layout] -@dataclasses.dataclass() -class WGMMATransforms: - swizzle: mgpu.SwizzlingMode - a_tile: tuple[int, ...] - a_transpose: bool - b_tile: tuple[int, ...] - b_transpose: bool - - -def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms: - a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape - k = a_shape[0] if wgmma_op.transpose_a else a_shape[1] - bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width - - # Try tiling with all swizzling modes starting from the largest one. - for swizzle in [ - mgpu.SwizzlingMode.k128ByteSwizzle, - mgpu.SwizzlingMode.k64ByteSwizzle, - mgpu.SwizzlingMode.k32ByteSwizzle, - ]: - s = swizzle * 8 // bitwidth - if k % s == 0: - return WGMMATransforms( - swizzle=swizzle, - a_tile=(s, 64) if wgmma_op.transpose_a else (64, s), - a_transpose=wgmma_op.transpose_a, - b_tile=(s, s), - b_transpose=wgmma_op.transpose_b, - ) - raise ValueError( - "Could not infer layouts for memref feeding into WGMMA. The " - "non-contracting dimension ({k}) must be a multiple of " - "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle " - f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of " - "`a` and `b`." 
- ) - -def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None: - wgmma_use = None - uses = cast(ir.OpResult, view_op.result).uses - for use in uses: - user = use.owner - if isinstance(user, memref.CastOp): - # This memref is already cast, so we don't need to do anything. - return None - if isinstance(user, mgpu.WGMMAOp): - if wgmma_use is not None: - raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.") - wgmma_use = use - break - if ( - not isinstance(user, mgpu.AsyncLoadOp) - and not isinstance(user, mgpu.AsyncStoreOp) - and not isinstance(user, vector.LoadOp) - and not isinstance(user, vector.StoreOp) - ): - raise NotImplementedError(f"Unsupported user {user} of {view_op}.") - - if wgmma_use is None: - # This memref is not used by a WGMMA operation, so we don't need to do - # anything. - return None - - transforms = infer_wgmma_transforms(wgmma_use.owner) - if wgmma_use.operand_number == 1: - tile = transforms.a_tile - transpose = transforms.a_transpose - else: - tile = transforms.b_tile - transpose = transforms.b_transpose - transpose_attr = ( - [mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else [] - ) - - layout = mgpu.LayoutAttr.get( - 2, - [mgpu.TileTransformAttr.get(tile)] - + transpose_attr - + [mgpu.SwizzleTransformAttr.get(transforms.swizzle)], - ) - - return layout - def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView: owners = [use.owner for use in uses] @@ -607,11 +524,3 @@ def infer_layout(module: ir.Module): for op in module.body: traverse_op(op, set_default_layout) - - def infer_memref_layouts_and_insert_casts(op: ir.OpView): - if op.name == "memref.view": - if layout := _layout_for_memref_view(op): - _insert_memref_layout_cast(layout, op) - - for op in module.body: - traverse_op(op, infer_memref_layouts_and_insert_casts) diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index be3f2c381..ef2d36616 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -26,6 +26,9 @@ from typing import cast from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import builtin +from jax._src.lib.mlir.dialects import gpu +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector from . import fragmented_array as fa @@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms( return None - # TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) @@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: return None if transforms is None else ([], [transforms]) +# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use +# the dialect in all cases. +# The rule is necessary in order to handle the lowering of `utils.memref_ptr` +# which is used in `_construct_smem_reftree`. 
+@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp) +def _infer_unrealized_conversion_cast_transforms( + _: builtin.UnrealizedConversionCastOp, +) -> OptionalTransforms: + return None + + +@partial(_add_transform_inference_rule, memref.ViewOp) +def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: + if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp): + raise NotImplementedError( + "Memref view transforms are only inferred when the op is a direct user " + f"of a DynamicSharedMemoryOp but got {op}." + ) + transforms = inference_utils.value_transforms(op.source) + if transforms is not None: + raise NotImplementedError( + "memref view with in_transforms aren't yet supported" + ) + uses = cast(ir.OpResult, op.result).uses + + for op_operand_use in uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + out_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + if transforms is not None and out_transforms is not None: + if transforms != out_transforms: + raise ValueError( + f"Conflicting transforms for {op_user} in {op}: " + f"{transforms} != {out_transforms}." + ) + elif out_transforms is not None: + transforms = out_transforms + + # TODO(bchetioui): do we actually need to assign a transform to the input of + # the view op? Presumably, it'll only be used to access scratch memory. + return None if transforms is None else ([], [transforms]) + + +# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use +# the dialect in all cases. +@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp) +def _infer_dynamic_smem_transforms( + _: gpu.DynamicSharedMemoryOp, +) -> OptionalTransforms: + return None + + def _should_have_transforms(op: ir.OpView) -> bool: """Returns 'True' if the operation should be assigned in/out transforms.""" return any( @@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module): specified. We error out if two distinct sets of transforms are competing to annotate the same memref. 
""" - def inference_step(op: ir.Operation): if not _should_have_transforms(op): return diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index bc56f21d0..1fcd68641 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -55,6 +55,7 @@ else: from jax.experimental.mosaic.gpu import launch_context from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler + from jax.experimental.mosaic.gpu import inference_utils from jax.experimental.mosaic.gpu.utils import * # noqa: F403 from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm @@ -2405,25 +2406,21 @@ class Swizzle: return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle) -def memref_with_transforms( - mem_ref: ir.Value, - transforms: Sequence[Tile | Transpose | Swizzle], -) -> ir.Value: - """Casts the memref to one that has a layout with the given transforms.""" - mem_ref_type = ir.MemRefType(mem_ref.type) +def set_in_transforms( + op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]], +) -> None: + """Annotates an op with in_transforms.""" + if not transforms: + return - transform_attr = [t.attr() for t in transforms] - if not transform_attr: - return mem_ref + in_transforms = [] + smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable + for _, result_transforms in jax.util.safe_zip(smem_refs, transforms): + in_transforms.append( + ir.ArrayAttr.get([t.attr() for t in result_transforms]) + ) - layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr) - memref_new_type = ir.MemRefType.get( - mem_ref_type.shape, - mem_ref_type.element_type, - layout, - mem_ref_type.memory_space, - ) - return memref.cast(memref_new_type, mem_ref) + op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms) class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): @@ -2556,7 +2553,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): ): del ctx smem_ref, tma_barrier = smem - smem_ref = memref_with_transforms(smem_ref, test_case.transforms) dialect_barrier = tma_barrier.as_dialect_barrier_memref() elt_type = ir.MemRefType(in_gmem_ref.type).element_type @@ -2571,7 +2567,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices] # GMEM -> SMEM - mgpu_dialect.async_load( + load_op = mgpu_dialect.AsyncLoadOp( source=in_gmem_ref, destination=smem_ref, barrier=dialect_barrier, @@ -2579,6 +2575,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): slice_lengths=test_case.slice_lengths, collective=ir.ArrayAttr.get([]), ) + set_in_transforms(load_op, [test_case.transforms]) parities = memref.load(tma_barrier.phases, []) parity, _ = tma_barrier.update_parities(parities) @@ -2623,58 +2620,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): (x[input_slice]).reshape(test_case.shape_sliced), ) - @staticmethod - def pointwise_kernel_with_tma_cases(dtype: jnp.dtype): - @dataclasses.dataclass(frozen=True) - class TestCaseInput: - shape: tuple[int, ...] - transforms: tuple[Tile | Transpose | Swizzle, ...] = () - - result = [] - for swizzle in mgpu_dialect.SwizzlingMode: - n = swizzle * 8 // jnp.finfo(dtype).bits - if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: - # We need at least one case with no transforms, as this is handled - # differently. 
- result.append(TestCaseInput(shape=[128, n])) - result.extend([ - TestCaseInput( - shape=[128, n], - transforms=[Swizzle(swizzle)], - ), - TestCaseInput( - shape=[2, 3, 64, n], - transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)], - ), - TestCaseInput( - shape=[2, 3, 64, n], - transforms=[ - Transpose([1, 0, 2, 3]), - Transpose([1, 0, 2, 3]), - Swizzle(swizzle), - ], - ), - TestCaseInput( - shape=[2, 3, 64, n], - transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)], - ), - TestCaseInput( - shape=[128, n], - transforms=[Tile([64, n]), Swizzle(swizzle)], - ), - TestCaseInput( - shape=[2 * 64, 3 * n], - transforms=[ - Tile([64, n]), - Transpose([1, 0, 2, 3]), - Swizzle(swizzle), - ], - ), - ]) - return result - - @parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16)) - def test_pointwise_kernel_with_tma(self, test_case): + def test_pointwise_kernel_with_tma(self): def add( ctx: launch_context.LaunchContext, a_gmem_ref: ir.Value, @@ -2701,9 +2647,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): # GMEM -> SMEM mgpu_dialect.async_load( source=a_gmem_ref, - destination=memref_with_transforms( - a_smem_ref, test_case.transforms - ), + destination=a_smem_ref, barrier=dialect_barrier, indices=zero_slice_indices, slice_lengths=shape, @@ -2711,9 +2655,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): ) mgpu_dialect.async_load( source=b_gmem_ref, - destination=memref_with_transforms( - b_smem_ref, test_case.transforms - ), + destination=b_smem_ref, barrier=dialect_barrier, indices=zero_slice_indices, slice_lengths=shape, @@ -2740,9 +2682,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): # SMEM -> GMEM mgpu_dialect.async_store( - source=memref_with_transforms( - result_smem_ref, test_case.transforms - ), + source=result_smem_ref, destination=result_gmem_ref, indices=zero_slice_indices, slice_lengths=shape, @@ -2752,114 +2692,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): dtype = jnp.bfloat16 - jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype) + spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype) kernel = mgpu.as_gpu_kernel( add, grid=(1, 1, 1), block=(128, 1, 1), - in_shape=(jax_shape, jax_shape), - out_shape=jax_shape, + in_shape=(spec, spec), + out_shape=spec, smem_scratch_shape=[ - jax_shape, - jax_shape, - jax_shape, + spec, + spec, + spec, core.TMABarrier(1), ], thread_semantics=mgpu.ThreadSemantics.Warpgroup, ) - x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) - y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) + x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) + y = self.prng.uniform(-1, 1, spec.shape).astype(dtype) self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y) class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): - @staticmethod - def wgmma_kernel_with_tma_cases(abtype: jnp.dtype): - @dataclasses.dataclass(frozen=True) - class TestCaseInput: - shape_a: tuple[int, ...] = () - shape_b: tuple[int, ...] = () - shape_res: tuple[int, ...] = () - transforms_a: tuple[Tile | Transpose | Swizzle, ...] = () - transforms_b: tuple[Tile | Transpose | Swizzle, ...] 
= () - transpose_a: bool = False - transpose_b: bool = False - load_a_in_registers: bool = False + @parameterized.named_parameters( + ( + f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}", + swizzle, + transpose_lhs, + transpose_rhs, + lhs_in_registers, + ) + for swizzle in mgpu_dialect.SwizzlingMode + for transpose_lhs in [False, True] + for transpose_rhs in [False, True] + for lhs_in_registers in [False, True] + ) + def test_wgmma_kernel_with_tma( + self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers + ): + if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: + self.skipTest("No swizzle is not supported by wgmma") - result = [] - for swizzle in [ - # TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes. - mgpu_dialect.SwizzlingMode.k32ByteSwizzle, - mgpu_dialect.SwizzlingMode.k64ByteSwizzle, - mgpu_dialect.SwizzlingMode.k128ByteSwizzle, - ]: - k = swizzle // np.dtype(abtype).itemsize - groups_m = 4 - groups_n = 1 - groups_k = 1 - result.extend([ - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - ), - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_n * k, groups_k * k], - shape_res=[groups_m * 64, groups_n * k], - transpose_b=True, - ), - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - transforms_a=[Tile([64, k]), Swizzle(swizzle)], - transforms_b=[Tile([k, k]), Swizzle(swizzle)], - ), - TestCaseInput( - shape_a=[groups_m * 64, groups_k * k], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - transforms_a=[Tile([64, k]), Swizzle(swizzle)], - load_a_in_registers=True, - ), - ]) - # The below only works for 128-byte swizzling. Regardless of transposing, - # TMA needs the size of the last dimension to be compatible with the - # swizzle. 
- if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle: - result.append( - TestCaseInput( - shape_a=[groups_k * k, groups_m * 64], - shape_b=[groups_k * k, groups_n * k], - shape_res=[groups_m * 64, groups_n * k], - transpose_a=True, - ) - ) - return result + if transpose_lhs or transpose_rhs: + self.skipTest("Transposes are not supported by transform inference yet.") - @parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16)) - def test_wgmma_kernel_with_tma(self, test_case): + swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize + tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems + + groups_m, groups_n, groups_k = 4, 1, 1 + m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k + + lhs_shape = (k, m) if transpose_lhs else (m, k) + rhs_shape = (n, k) if transpose_rhs else (k, n) + out_shape = (m, n) def matmul( ctx: launch_context.LaunchContext, - a_gmem_ref: ir.Value, - b_gmem_ref: ir.Value, + lhs_gmem_ref: ir.Value, + rhs_gmem_ref: ir.Value, result_gmem_ref: ir.Value, smem: list[ir.Value], ): del ctx - a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem - a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a) - b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b) + lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem dialect_barrier = tma_barrier.as_dialect_barrier_memref() - ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type - bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a) - bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b) + operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type + bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape) + bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape) mgpu_dialect.arrive_expect_tx( barrier=dialect_barrier, @@ -2869,19 +2771,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) # GMEM -> SMEM mgpu_dialect.async_load( - source=a_gmem_ref, - destination=a_smem_ref, + source=lhs_gmem_ref, + destination=lhs_smem_ref, barrier=dialect_barrier, - indices=[zero_i32] * len(test_case.shape_a), - slice_lengths=test_case.shape_a, + indices=[zero_i32] * len(lhs_shape), + slice_lengths=lhs_shape, collective=ir.ArrayAttr.get([]), ) mgpu_dialect.async_load( - source=b_gmem_ref, - destination=b_smem_ref, + source=rhs_gmem_ref, + destination=rhs_smem_ref, barrier=dialect_barrier, - indices=[zero_i32] * len(test_case.shape_b), - slice_lengths=test_case.shape_b, + indices=[zero_i32] * len(rhs_shape), + slice_lengths=rhs_shape, collective=ir.ArrayAttr.get([]), ) @@ -2889,29 +2791,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) - # SMEM -> Registers - a_operand = a_smem_ref - zero_index = arith.constant(ir.IndexType.get(), 0) - if test_case.load_a_in_registers: - a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type) - zero_vector_indices = [zero_index] * len(test_case.shape_a) - a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices) - # Computation shape_result = ir.MemRefType(result_gmem_ref.type).shape result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type + acc_elt_type = ir.F32Type.get() + acc_type = ir.VectorType.get(shape_result, acc_elt_type) zero_acc = arith.constant( - result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0) - ) - 
accumulator = vector.splat( - ir.VectorType.get(shape_result, result_elt_type), zero_acc + result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0) ) + accumulator = vector.splat(acc_type, zero_acc) + + if transpose_lhs: + lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0)) + if transpose_rhs: + rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0)) + + zero_index = arith.constant(ir.IndexType.get(), 0) + if load_a_in_registers: + # SMEM -> Registers + lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type) + zero_vector_indices = [zero_index] * len(lhs_shape) + lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices) + else: + lhs_operand = lhs_smem_ref + result = mgpu_dialect.wgmma( accumulator, - a_operand, - b_smem_ref, - transpose_a=test_case.transpose_a, - transpose_b=test_case.transpose_b, + lhs_operand, + rhs_smem_ref, ) nvvm.wgmma_commit_group_sync_aligned() @@ -2929,38 +2836,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): ) nvvm.cp_async_bulk_wait_group(0) - abtype = jnp.bfloat16 + operand_type = jnp.bfloat16 acctype = jnp.float32 - a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype) - b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype) - result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype) + lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type) + rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type) + result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype) kernel = mgpu.as_gpu_kernel( matmul, grid=(1, 1, 1), block=(128, 1, 1), - in_shape=(a_jax_shape, b_jax_shape), + in_shape=(lhs_jax_shape, rhs_jax_shape), out_shape=result_jax_shape, smem_scratch_shape=[ - a_jax_shape, - b_jax_shape, + lhs_jax_shape, + rhs_jax_shape, result_jax_shape, core.TMABarrier(1), ], thread_semantics=mgpu.ThreadSemantics.Warpgroup, ) - x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype) - y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype) + prng_key = jax.random.key(1234) + k0, k1 = jax.random.split(prng_key, 2) + + x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type) + y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type) transpose = lambda x, t: x.T if t else x self.assertArraysAllClose( jax.jit(kernel)(x, y), np.matmul( - transpose(x.reshape(test_case.shape_a), test_case.transpose_a), - transpose(y.reshape(test_case.shape_b), test_case.transpose_b), + transpose(x, transpose_lhs), + transpose(y, transpose_rhs) ), - atol=1e-5, - rtol=1e-5, + atol=0, + rtol=0, )
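
For context, a minimal sketch of the Python lowering pipeline that `as_gpu_kernel` (and `as_torch_gpu_kernel`) runs after this change follows. The three pass calls and their ordering are taken from the `core.py` hunks above; the top-level import paths and the wrapper function `run_python_lowering_passes` are illustrative assumptions, not part of the patch.

# Minimal sketch of the post-patch lowering pipeline (assumed import paths;
# the wrapper name is hypothetical).
from jax.experimental.mosaic.gpu import dialect_lowering
from jax.experimental.mosaic.gpu import layout_inference
from jax.experimental.mosaic.gpu import transform_inference


def run_python_lowering_passes(module, launch_ctx):
  # 1. Infer in/out layouts for vector-level operations.
  layout_inference.infer_layout(module)
  # 2. New in this patch: infer in/out SMEM transforms (tile, transpose,
  #    swizzle) as attributes on the annotated ops, replacing the
  #    memref-layout-based encoding that layout inference previously inserted
  #    via `memref.cast`.
  transform_inference.infer_transforms(module)
  # 3. Lower the Mosaic GPU dialect; lowering rules now read the inferred
  #    transforms through `swizzle_and_transforms_from_transforms_attr`
  #    instead of inspecting memref layouts.
  dialect_lowering.lower_mgpu_dialect(module, launch_ctx)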