
[Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering.

A couple of dummy transform inference rules had to be added in order to
handle parts of the lowering that do not yet use the dialect, along with
a transform inference rule for `memref.view`.

PiperOrigin-RevId: 738089782
Benjamin Chetioui 2025-03-18 11:50:58 -07:00, committed by jax authors
parent 547d602760
commit 875099b25d
5 changed files with 324 additions and 348 deletions
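
Note (added for orientation, not part of the commit message): the change wires
the new pass between layout inference and dialect lowering. A minimal sketch of
the resulting Python-side pipeline, assuming `module` and `launch_ctx` are
built as in `as_gpu_kernel` below:

    layout_inference.infer_layout(module)
    transform_inference.infer_transforms(module)  # the new pass
    dialect_lowering.lower_mgpu_dialect(module, launch_ctx)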

@@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref
 from jaxlib.mlir.dialects import nvvm
 import numpy as np

-if dialect is not None:
-  from . import dialect_lowering
-  from . import layout_inference
-else:
-  dialect_lowering = None
-  layout_inference = None
-from . import profiler
-from . import utils
-from . import launch_context
-from . import tcgen05
 # mypy: ignore-errors
+from . import dialect_lowering
+from . import launch_context
+from . import layout_inference
+from . import profiler
+from . import tcgen05
+from . import transform_inference
+from . import utils

 # MLIR can't find libdevice unless we point it to the CUDA path
 # TODO(apaszke): Unify with jax._src.lib.cuda_path
 CUDA_ROOT = "/usr/local/cuda"
@@ -584,6 +580,7 @@ def as_gpu_kernel(
   # Run Python lowering passes. The remaining passes will be run in C++ in
   # jax/jaxlib/mosaic/gpu/custom_call.cc
   layout_inference.infer_layout(module)  # pytype: disable=attribute-error
+  transform_inference.infer_transforms(module)  # pytype: disable=attribute-error
   dialect_lowering.lower_mgpu_dialect(module, launch_ctx)  # pytype: disable=attribute-error

   _initialize_scratch(launch_ctx, scratch_arr)
@@ -666,6 +663,7 @@ def as_torch_gpu_kernel(
   # Run Python lowering passes. The remaining passes will be run in C++ in
   # jax/jaxlib/mosaic/gpu/custom_call.cc
   layout_inference.infer_layout(module)  # pytype: disable=attribute-error
+  transform_inference.infer_transforms(module)  # pytype: disable=attribute-error
   dialect_lowering.lower_mgpu_dialect(module, launch_ctx)  # pytype: disable=attribute-error

   _initialize_scratch(launch_ctx, scratch_arr)

@@ -17,6 +17,7 @@
 from collections.abc import Callable
 import dataclasses
 import functools
+import itertools
 import operator
 from typing import Any, Sequence, Type, cast
@@ -58,7 +59,7 @@ class LoweringContext:
     if not _should_lower(op):
       return

-    if (name := op.OPERATION_NAME) not in _lowerings:
+    if (name := op.OPERATION_NAME) not in _lowerings:  # pytype: disable=attribute-error
       raise NotImplementedError(f"Missing lowering rule for {op}")

     lowering_rule = _lowerings[name]
@@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule(
   ]

+
+def _check_transforms_and_swizzle_are_supported(
+    ref_ty: ir.MemRefType,
+    transforms: Sequence[launch_context.MemRefTransform],
+    swizzle: mgpu.SwizzlingMode,
+    minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle,
+):
+  """Checks that the list of provided transforms and swizzle are supported.
+
+  Currently, we allow the following:
+    - any swizzle that is larger than or equal to `minimum_swizzle`;
+    - optionally, a single tile transform (with rank equal to the rank of the
+      memref being annotated);
+    - optionally, a single transpose transform.
+  """
+  if swizzle < minimum_swizzle:
+    raise NotImplementedError(
+        f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}."
+    )
+  partitioned_transforms = {
+      k: list(v)
+      for k, v in itertools.groupby(
+          transforms, lambda t: isinstance(t, launch_context.TileTransform)
+      )
+  }
+  tile_transforms = partitioned_transforms.get(True, [])
+  other_transforms = partitioned_transforms.get(False, [])
+  if len(tile_transforms) > 1:
+    raise NotImplementedError(
+        f"{tile_transforms} contains more than one tile transform."
+    )
+  if len(tile_transforms) == 1:
+    if len(tile_transforms[0].tiling) != len(ref_ty.shape):
+      raise NotImplementedError(
+          f"Only tile transforms with rank equal to the rank of the memref "
+          f"being annotated are supported but got {tile_transforms[0]} for "
+          f"{ref_ty}."
+      )
+  if len(other_transforms) > 1:
+    raise NotImplementedError(
+        f"{other_transforms} contains more than one transform."
+    )
+  if len(other_transforms) == 1:
+    if not isinstance(other_transforms[0], launch_context.TransposeTransform):
+      raise NotImplementedError(
+          f"{other_transforms[0]} is not a transpose transform."
+      )
+
+
 @_register_lowering(vector.LoadOp)
 def _vector_load_op_lowering_rule(
     _: LoweringContext, vector_load_op: vector.LoadOp
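
Note (illustrative sketch, not part of the commit): what the new checker
accepts and rejects, using hypothetical rank-2 inputs and assuming an active
MLIR context plus the module's own imports:

    # `ref_ty` is assumed to be a rank-2 ir.MemRefType annotated elsewhere.
    tile = launch_context.TileTransform((64, 64))
    transpose = launch_context.TransposeTransform((1, 0))

    # OK: one tile transform whose rank matches the memref, one transpose,
    # and a swizzle >= minimum_swizzle.
    _check_transforms_and_swizzle_are_supported(
        ref_ty, [tile, transpose], mgpu.SwizzlingMode.k128ByteSwizzle
    )

    # Raises NotImplementedError: more than one tile transform.
    _check_transforms_and_swizzle_are_supported(
        ref_ty, [tile, tile], mgpu.SwizzlingMode.k128ByteSwizzle
    )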
@@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule(
         vec_size=strided_layout.vec_size,
     )
   elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
-    layout = ir.MemRefType(vector_load_op.base.type).layout
-    swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
+    swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
+        inference_utils.in_transforms(vector_load_op)[0]
+    )
+    ref_ty = ir.MemRefType(vector_load_op.base.type)
+    _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
     transformed_ref = transform_memref(vector_load_op.base, transforms)
     fragmented_array = fa.FragmentedArray.load_tiled(
         transformed_ref,
@@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule(
       vector_store_op.valueToStore, to_store_layout
   )

-  # TODO(dasenov): This is not efficient for WGMMA layouts
-  fragmented_array.store_untiled(vector_store_op.base)
+  if fragmented_array.layout == fa.WGMMA_LAYOUT:
+    swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
+        inference_utils.in_transforms(vector_store_op)[0]
+    )
+    ref_ty = ir.MemRefType(vector_store_op.base.type)
+    _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
+    fragmented_array.store_tiled(
+        transform_memref(vector_store_op.base, transforms), swizzle
+    )
+  elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or
+        isinstance(fragmented_array.layout, fa.WGSplatFragLayout)):
+    fragmented_array.store_untiled(vector_store_op.base)
+  else:
+    raise ValueError(
+        f"{vector_store_op} has an unsupported layout: {to_store_layout}"
+    )

   return []
@@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule(
   return [_fragmented_array_to_ir(result, op.result.type)]


-def memref_layout_to_swizzle_and_transforms(
-    layout: ir.Attribute,
-) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
-  """Returns the swizzle and transforms that are encoded in the given layout.
-
-  If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the
-  transforms are empty. Otherwise, the layout may have at most one swizzle
-  transform and any combination of tiling and transpose transforms.
-  """
-  swizzle = None
-  gmem_transforms: list[launch_context.MemRefTransform] = []
-
-  if mgpu.LayoutAttr.isinstance(layout):
-    transforms_attr = mgpu.LayoutAttr(layout).transforms
-    for transform in transforms_attr:
-      if swizzle is not None:
-        raise ValueError(f"{layout} contains more transforms after the initial swizzle.")
-      if mgpu.SwizzleTransformAttr.isinstance(transform):
-        # TODO(dasenov): Swizzling can change if the ref is sliced in certain
-        # ways. We might want to enforce some restrictions here.
-        swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
-      elif mgpu.TileTransformAttr.isinstance(transform):
-        tiling = mgpu.TileTransformAttr(transform).tiling
-        tiling_transform = launch_context.TileTransform(tuple(tiling))
-        gmem_transforms.append(tiling_transform)
-      elif mgpu.TransposeTransformAttr.isinstance(transform):
-        permutation = mgpu.TransposeTransformAttr(transform).permutation
-        transpose_transform = launch_context.TransposeTransform(
-            tuple(permutation)
-        )
-        gmem_transforms.append(transpose_transform)
-      else:
-        raise ValueError(f"{layout} has an unsupported transform: {transform}")
+def swizzle_and_transforms_from_transforms_attr(
+    transforms: ir.ArrayAttr,
+) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
+  """Returns the swizzle and MemrefTransforms for the given transforms.
+
+  Args:
+    transforms: a list of transform attributes.
+
+  Returns:
+    A tuple containing the swizzle mode and MemRefTransforms corresponding to
+    the parameter transforms. If `transforms` is empty, or does not contain
+    any swizzling transform, the swizzle mode is assumed to be kNoSwizzle.
+
+  Raises:
+    ValueError: if a swizzling transform is followed by any transform.
+  """
+  swizzle = None
+  gmem_transforms: list[launch_context.MemRefTransform] = []
+
+  for transform in transforms:
+    if swizzle is not None:
+      raise ValueError(f"{transforms} contain more transforms after swizzle.")
+    if mgpu.SwizzleTransformAttr.isinstance(transform):
+      # TODO(dasenov): Swizzling can change if the ref is sliced in certain
+      # ways. We might want to enforce some restrictions here.
+      swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
+    elif mgpu.TileTransformAttr.isinstance(transform):
+      tiling = mgpu.TileTransformAttr(transform).tiling
+      tiling_transform = launch_context.TileTransform(tuple(tiling))
+      gmem_transforms.append(tiling_transform)
+    elif mgpu.TransposeTransformAttr.isinstance(transform):
+      permutation = mgpu.TransposeTransformAttr(transform).permutation
+      transpose_transform = launch_context.TransposeTransform(
+          tuple(permutation)
+      )
+      gmem_transforms.append(transpose_transform)
+    else:
+      raise ValueError(f"Unknown transform: {transform}")

   return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)
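
Note (illustrative sketch, not part of the commit): the rewritten helper now
decodes a dialect transforms array rather than a memref layout. A hedged
round-trip example, assuming an active MLIR context and using only attribute
constructors that appear in this diff:

    transforms_attr = ir.ArrayAttr.get([
        mgpu.TileTransformAttr.get([64, 64]),
        mgpu.SwizzleTransformAttr.get(mgpu.SwizzlingMode.k128ByteSwizzle),
    ])
    swizzle, gmem_transforms = swizzle_and_transforms_from_transforms_attr(
        transforms_attr
    )
    # swizzle == mgpu.SwizzlingMode.k128ByteSwizzle, and gmem_transforms holds
    # a single launch_context.TileTransform((64, 64)). Note that the swizzle
    # must come last: any transform after it raises ValueError.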
@@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule(
   assert ctx.launch_context is not None
   barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)

-  dst_layout = ir.MemRefType(load_op.destination.type).layout
-  swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)
+  if inference_utils.has_in_transforms_set(load_op):
+    [transforms] = inference_utils.in_transforms(load_op)
+    swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
+        transforms
+    )
+  else:
+    swizzle = mgpu.SwizzlingMode.kNoSwizzle
+    transforms = ()

   gmem_slice = []
   for idx_i32, size in zip(load_op.indices, load_op.slice_lengths):
@@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule(
 ) -> Sequence[ir.Value]:
   assert ctx.launch_context is not None

-  src_layout = ir.MemRefType(store_op.source.type).layout
-  swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)
+  if inference_utils.has_in_transforms_set(store_op):
+    [transforms] = inference_utils.in_transforms(store_op)
+    swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
+        transforms
+    )
+  else:
+    swizzle = mgpu.SwizzlingMode.kNoSwizzle
+    transforms = ()

   gmem_slice = []
   for idx_i32, size in zip(store_op.indices, store_op.slice_lengths):
@@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule(
 def _mgpu_wgmma_op_lowering_rule(
     _: LoweringContext, wgmma_op: mgpu.WGMMAOp
 ) -> Sequence[ir.Value]:
+  if wgmma_op.transpose_a or wgmma_op.transpose_b:
+    raise ValueError("Transpose arguments are to be deleted.")
+
   fa_layouts = (
       *inference_utils.in_layouts(wgmma_op),
       *inference_utils.out_layouts(wgmma_op),
@@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule(
   regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
   acc = wgmma.WGMMAAccumulator.from_registers(regs)

-  b_layout = ir.MemRefType(wgmma_op.b.type).layout
-  b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
+  if ir.VectorType.isinstance(wgmma_op.a.type):
+    a_transforms = None
+    b_transforms = inference_utils.in_transforms(wgmma_op)[0]
+  else:
+    a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op)
+
+  b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr(
+      b_transforms
+  )
+  minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle
+  ref_ty = ir.MemRefType(wgmma_op.b.type)
+  _check_transforms_and_swizzle_are_supported(
+      ref_ty, b_transforms, b_swizzle, minimum_swizzle
+  )
   b_operand = transform_memref(wgmma_op.b, b_transforms)
-  if wgmma_op.transpose_b:
-    b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))

   if ir.VectorType.isinstance(wgmma_op.a.type):
     a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
   else:
-    a_layout = ir.MemRefType(wgmma_op.a.type).layout
-    a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
+    a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
+        a_transforms
+    )
+    ref_ty = ir.MemRefType(wgmma_op.a.type)
+    _check_transforms_and_swizzle_are_supported(
+        ref_ty, a_transforms, a_swizzle, minimum_swizzle
+    )
     if a_swizzle != b_swizzle:
       raise ValueError(
           f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !="
           f" {b_swizzle}"
       )
     a_operand = transform_memref(wgmma_op.a, a_transforms)
-    if wgmma_op.transpose_a:
-      a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))

   new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
@@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
 def _should_lower(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be lowered."""
   return (
-      op.OPERATION_NAME.startswith("mosaic_gpu.")
+      op.OPERATION_NAME.startswith("mosaic_gpu.")  # pytype: disable=attribute-error
       or inference_utils.should_have_layout(op)
      or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
   )

@@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
   return [layout], [layout]


-@dataclasses.dataclass()
-class WGMMATransforms:
-  swizzle: mgpu.SwizzlingMode
-  a_tile: tuple[int, ...]
-  a_transpose: bool
-  b_tile: tuple[int, ...]
-  b_transpose: bool
-
-
-def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
-  a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
-  k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
-  bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width
-
-  # Try tiling with all swizzling modes starting from the largest one.
-  for swizzle in [
-      mgpu.SwizzlingMode.k128ByteSwizzle,
-      mgpu.SwizzlingMode.k64ByteSwizzle,
-      mgpu.SwizzlingMode.k32ByteSwizzle,
-  ]:
-    s = swizzle * 8 // bitwidth
-    if k % s == 0:
-      return WGMMATransforms(
-          swizzle=swizzle,
-          a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
-          a_transpose=wgmma_op.transpose_a,
-          b_tile=(s, s),
-          b_transpose=wgmma_op.transpose_b,
-      )
-  raise ValueError(
-      "Could not infer layouts for memref feeding into WGMMA. The "
-      "non-contracting dimension ({k}) must be a multiple of "
-      "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
-      f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
-      "`a` and `b`."
-  )
-
-
-def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
-  wgmma_use = None
-  uses = cast(ir.OpResult, view_op.result).uses
-  for use in uses:
-    user = use.owner
-    if isinstance(user, memref.CastOp):
-      # This memref is already cast, so we don't need to do anything.
-      return None
-    if isinstance(user, mgpu.WGMMAOp):
-      if wgmma_use is not None:
-        raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
-      wgmma_use = use
-      break
-    if (
-        not isinstance(user, mgpu.AsyncLoadOp)
-        and not isinstance(user, mgpu.AsyncStoreOp)
-        and not isinstance(user, vector.LoadOp)
-        and not isinstance(user, vector.StoreOp)
-    ):
-      raise NotImplementedError(f"Unsupported user {user} of {view_op}.")
-
-  if wgmma_use is None:
-    # This memref is not used by a WGMMA operation, so we don't need to do
-    # anything.
-    return None
-
-  transforms = infer_wgmma_transforms(wgmma_use.owner)
-  if wgmma_use.operand_number == 1:
-    tile = transforms.a_tile
-    transpose = transforms.a_transpose
-  else:
-    tile = transforms.b_tile
-    transpose = transforms.b_transpose
-  transpose_attr = (
-      [mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
-  )
-
-  layout = mgpu.LayoutAttr.get(
-      2,
-      [mgpu.TileTransformAttr.get(tile)]
-      + transpose_attr
-      + [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
-  )
-
-  return layout
-
-
 def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
   owners = [use.owner for use in uses]
@@ -607,11 +524,3 @@ def infer_layout(module: ir.Module):

   for op in module.body:
     traverse_op(op, set_default_layout)
-
-  def infer_memref_layouts_and_insert_casts(op: ir.OpView):
-    if op.name == "memref.view":
-      if layout := _layout_for_memref_view(op):
-        _insert_memref_layout_cast(layout, op)
-
-  for op in module.body:
-    traverse_op(op, infer_memref_layouts_and_insert_casts)

@@ -26,6 +26,9 @@ from typing import cast
 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
+from jax._src.lib.mlir.dialects import builtin
+from jax._src.lib.mlir.dialects import gpu
+from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import vector

 from . import fragmented_array as fa
@@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms(
   return None

-
 # TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
 SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
@@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
   return None if transforms is None else ([], [transforms])


+# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't
+# use the dialect in all cases. The rule is necessary in order to handle the
+# lowering of `utils.memref_ptr`, which is used in `_construct_smem_reftree`.
+@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp)
+def _infer_unrealized_conversion_cast_transforms(
+    _: builtin.UnrealizedConversionCastOp,
+) -> OptionalTransforms:
+  return None
+
+
+@partial(_add_transform_inference_rule, memref.ViewOp)
+def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms:
+  if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp):
+    raise NotImplementedError(
+        "Memref view transforms are only inferred when the op is a direct "
+        f"user of a DynamicSharedMemoryOp, but got {op}."
+    )
+  transforms = inference_utils.value_transforms(op.source)
+  if transforms is not None:
+    raise NotImplementedError(
+        "memref views with in_transforms aren't yet supported."
+    )
+  uses = cast(ir.OpResult, op.result).uses
+
+  for op_operand_use in uses:
+    consumer = op_operand_use.owner
+    op_user = consumer.operands[op_operand_use.operand_number]
+    out_transforms = inference_utils.in_transforms_for_operand(
+        consumer, op_user
+    )
+    if transforms is not None and out_transforms is not None:
+      if transforms != out_transforms:
+        raise ValueError(
+            f"Conflicting transforms for {op_user} in {op}: "
+            f"{transforms} != {out_transforms}."
+        )
+    elif out_transforms is not None:
+      transforms = out_transforms
+
+  # TODO(bchetioui): do we actually need to assign a transform to the input of
+  # the view op? Presumably, it'll only be used to access scratch memory.
+  return None if transforms is None else ([], [transforms])
+
+
+# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't
+# use the dialect in all cases.
+@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp)
+def _infer_dynamic_smem_transforms(
+    _: gpu.DynamicSharedMemoryOp,
+) -> OptionalTransforms:
+  return None
+
+
 def _should_have_transforms(op: ir.OpView) -> bool:
   """Returns 'True' if the operation should be assigned in/out transforms."""
   return any(
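
Note (illustrative sketch, not part of the commit): inferred transforms end up
as an `in_transforms` array attribute on the consuming op, with one inner
array per transformable SMEM operand. This mirrors what the
`set_in_transforms` test helper below constructs; assuming an active MLIR
context and an op with a single SMEM memref operand:

    op.attributes["in_transforms"] = ir.ArrayAttr.get([
        ir.ArrayAttr.get([
            mgpu.TileTransformAttr.get([64, 64]),
            mgpu.SwizzleTransformAttr.get(mgpu.SwizzlingMode.k128ByteSwizzle),
        ]),
    ])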
@@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module):
   specified. We error out if two distinct sets of transforms are competing to
   annotate the same memref.
   """
-
   def inference_step(op: ir.Operation):
     if not _should_have_transforms(op):
       return

@@ -55,6 +55,7 @@ else:
   from jax.experimental.mosaic.gpu import launch_context
   from jax.experimental.mosaic.gpu import utils as utils
   from jax.experimental.mosaic.gpu import profiler
+  from jax.experimental.mosaic.gpu import inference_utils
   from jax.experimental.mosaic.gpu.utils import *  # noqa: F403
 from jax._src.lib.mlir.dialects import gpu
 from jax._src.lib.mlir.dialects import llvm
@@ -2405,25 +2406,21 @@ class Swizzle:
     return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)


-def memref_with_transforms(
-    mem_ref: ir.Value,
-    transforms: Sequence[Tile | Transpose | Swizzle],
-) -> ir.Value:
-  """Casts the memref to one that has a layout with the given transforms."""
-  mem_ref_type = ir.MemRefType(mem_ref.type)
-  transform_attr = [t.attr() for t in transforms]
-  if not transform_attr:
-    return mem_ref
-
-  layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr)
-  memref_new_type = ir.MemRefType.get(
-      mem_ref_type.shape,
-      mem_ref_type.element_type,
-      layout,
-      mem_ref_type.memory_space,
-  )
-  return memref.cast(memref_new_type, mem_ref)
+def set_in_transforms(
+    op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]],
+) -> None:
+  """Annotates an op with in_transforms."""
+  if not transforms:
+    return
+
+  in_transforms = []
+  smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands)  # pylint: disable=undefined-variable
+  for _, result_transforms in jax.util.safe_zip(smem_refs, transforms):
+    in_transforms.append(
+        ir.ArrayAttr.get([t.attr() for t in result_transforms])
+    )
+
+  op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms)


 class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
@@ -2556,7 +2553,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
     ):
       del ctx
       smem_ref, tma_barrier = smem
-      smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
       dialect_barrier = tma_barrier.as_dialect_barrier_memref()

       elt_type = ir.MemRefType(in_gmem_ref.type).element_type
@@ -2571,7 +2567,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
       slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]

       # GMEM -> SMEM
-      mgpu_dialect.async_load(
+      load_op = mgpu_dialect.AsyncLoadOp(
           source=in_gmem_ref,
           destination=smem_ref,
           barrier=dialect_barrier,
@@ -2579,6 +2575,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
           slice_lengths=test_case.slice_lengths,
           collective=ir.ArrayAttr.get([]),
       )
+      set_in_transforms(load_op, [test_case.transforms])

       parities = memref.load(tma_barrier.phases, [])
       parity, _ = tma_barrier.update_parities(parities)
@@ -2623,58 +2620,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
         (x[input_slice]).reshape(test_case.shape_sliced),
     )

-  @staticmethod
-  def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
-    @dataclasses.dataclass(frozen=True)
-    class TestCaseInput:
-      shape: tuple[int, ...]
-      transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
-
-    result = []
-    for swizzle in mgpu_dialect.SwizzlingMode:
-      n = swizzle * 8 // jnp.finfo(dtype).bits
-      if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
-        # We need at least one case with no transforms, as this is handled
-        # differently.
-        result.append(TestCaseInput(shape=[128, n]))
-      result.extend([
-          TestCaseInput(
-              shape=[128, n],
-              transforms=[Swizzle(swizzle)],
-          ),
-          TestCaseInput(
-              shape=[2, 3, 64, n],
-              transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
-          ),
-          TestCaseInput(
-              shape=[2, 3, 64, n],
-              transforms=[
-                  Transpose([1, 0, 2, 3]),
-                  Transpose([1, 0, 2, 3]),
-                  Swizzle(swizzle),
-              ],
-          ),
-          TestCaseInput(
-              shape=[2, 3, 64, n],
-              transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
-          ),
-          TestCaseInput(
-              shape=[128, n],
-              transforms=[Tile([64, n]), Swizzle(swizzle)],
-          ),
-          TestCaseInput(
-              shape=[2 * 64, 3 * n],
-              transforms=[
-                  Tile([64, n]),
-                  Transpose([1, 0, 2, 3]),
-                  Swizzle(swizzle),
-              ],
-          ),
-      ])
-    return result
-
-  @parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
-  def test_pointwise_kernel_with_tma(self, test_case):
+  def test_pointwise_kernel_with_tma(self):
     def add(
         ctx: launch_context.LaunchContext,
         a_gmem_ref: ir.Value,
@@ -2701,9 +2647,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
       # GMEM -> SMEM
       mgpu_dialect.async_load(
           source=a_gmem_ref,
-          destination=memref_with_transforms(
-              a_smem_ref, test_case.transforms
-          ),
+          destination=a_smem_ref,
           barrier=dialect_barrier,
           indices=zero_slice_indices,
           slice_lengths=shape,
@@ -2711,9 +2655,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
       )
       mgpu_dialect.async_load(
           source=b_gmem_ref,
-          destination=memref_with_transforms(
-              b_smem_ref, test_case.transforms
-          ),
+          destination=b_smem_ref,
           barrier=dialect_barrier,
           indices=zero_slice_indices,
           slice_lengths=shape,
@@ -2740,9 +2682,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
       # SMEM -> GMEM
       mgpu_dialect.async_store(
-          source=memref_with_transforms(
-              result_smem_ref, test_case.transforms
-          ),
+          source=result_smem_ref,
           destination=result_gmem_ref,
           indices=zero_slice_indices,
           slice_lengths=shape,
@@ -2752,114 +2692,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
     dtype = jnp.bfloat16
-    jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
+    spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype)
     kernel = mgpu.as_gpu_kernel(
         add,
         grid=(1, 1, 1),
         block=(128, 1, 1),
-        in_shape=(jax_shape, jax_shape),
-        out_shape=jax_shape,
+        in_shape=(spec, spec),
+        out_shape=spec,
         smem_scratch_shape=[
-            jax_shape,
-            jax_shape,
-            jax_shape,
+            spec,
+            spec,
+            spec,
             core.TMABarrier(1),
         ],
         thread_semantics=mgpu.ThreadSemantics.Warpgroup,
     )

-    x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
-    y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
+    x = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
+    y = self.prng.uniform(-1, 1, spec.shape).astype(dtype)

     self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)


 class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):

-  @staticmethod
-  def wgmma_kernel_with_tma_cases(abtype: jnp.dtype):
-    @dataclasses.dataclass(frozen=True)
-    class TestCaseInput:
-      shape_a: tuple[int, ...] = ()
-      shape_b: tuple[int, ...] = ()
-      shape_res: tuple[int, ...] = ()
-      transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
-      transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
-      transpose_a: bool = False
-      transpose_b: bool = False
-      load_a_in_registers: bool = False
-
-    result = []
-    for swizzle in [
-        # TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
-        mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
-        mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
-        mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
-    ]:
-      k = swizzle // np.dtype(abtype).itemsize
-      groups_m = 4
-      groups_n = 1
-      groups_k = 1
-      result.extend([
-          TestCaseInput(
-              shape_a=[groups_m * 64, groups_k * k],
-              shape_b=[groups_k * k, groups_n * k],
-              shape_res=[groups_m * 64, groups_n * k],
-          ),
-          TestCaseInput(
-              shape_a=[groups_m * 64, groups_k * k],
-              shape_b=[groups_n * k, groups_k * k],
-              shape_res=[groups_m * 64, groups_n * k],
-              transpose_b=True,
-          ),
-          TestCaseInput(
-              shape_a=[groups_m * 64, groups_k * k],
-              shape_b=[groups_k * k, groups_n * k],
-              shape_res=[groups_m * 64, groups_n * k],
-              transforms_a=[Tile([64, k]), Swizzle(swizzle)],
-              transforms_b=[Tile([k, k]), Swizzle(swizzle)],
-          ),
-          TestCaseInput(
-              shape_a=[groups_m * 64, groups_k * k],
-              shape_b=[groups_k * k, groups_n * k],
-              shape_res=[groups_m * 64, groups_n * k],
-              transforms_a=[Tile([64, k]), Swizzle(swizzle)],
-              load_a_in_registers=True,
-          ),
-      ])
-      # The below only works for 128-byte swizzling. Regardless of transposing,
-      # TMA needs the size of the last dimension to be compatible with the
-      # swizzle.
-      if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
-        result.append(
-            TestCaseInput(
-                shape_a=[groups_k * k, groups_m * 64],
-                shape_b=[groups_k * k, groups_n * k],
-                shape_res=[groups_m * 64, groups_n * k],
-                transpose_a=True,
-            )
-        )
-    return result
-
-  @parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
-  def test_wgmma_kernel_with_tma(self, test_case):
+  @parameterized.named_parameters(
+      (
+          f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}",
+          swizzle,
+          transpose_lhs,
+          transpose_rhs,
+          lhs_in_registers,
+      )
+      for swizzle in mgpu_dialect.SwizzlingMode
+      for transpose_lhs in [False, True]
+      for transpose_rhs in [False, True]
+      for lhs_in_registers in [False, True]
+  )
+  def test_wgmma_kernel_with_tma(
+      self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers
+  ):
+    if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
+      self.skipTest("No swizzle is not supported by wgmma")
+
+    if transpose_lhs or transpose_rhs:
+      self.skipTest("Transposes are not supported by transform inference yet.")
+
+    swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize
+    tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems
+    groups_m, groups_n, groups_k = 4, 1, 1
+    m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k
+    lhs_shape = (k, m) if transpose_lhs else (m, k)
+    rhs_shape = (n, k) if transpose_rhs else (k, n)
+    out_shape = (m, n)

     def matmul(
         ctx: launch_context.LaunchContext,
-        a_gmem_ref: ir.Value,
-        b_gmem_ref: ir.Value,
+        lhs_gmem_ref: ir.Value,
+        rhs_gmem_ref: ir.Value,
         result_gmem_ref: ir.Value,
         smem: list[ir.Value],
     ):
       del ctx
-      a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
-      a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
-      b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
+      lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem
       dialect_barrier = tma_barrier.as_dialect_barrier_memref()

-      ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
-      bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a)
-      bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b)
+      operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type
+      bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape)
+      bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape)

       mgpu_dialect.arrive_expect_tx(
           barrier=dialect_barrier,
@@ -2869,19 +2771,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
       zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)

       # GMEM -> SMEM
       mgpu_dialect.async_load(
-          source=a_gmem_ref,
-          destination=a_smem_ref,
+          source=lhs_gmem_ref,
+          destination=lhs_smem_ref,
           barrier=dialect_barrier,
-          indices=[zero_i32] * len(test_case.shape_a),
-          slice_lengths=test_case.shape_a,
+          indices=[zero_i32] * len(lhs_shape),
+          slice_lengths=lhs_shape,
           collective=ir.ArrayAttr.get([]),
       )
       mgpu_dialect.async_load(
-          source=b_gmem_ref,
-          destination=b_smem_ref,
+          source=rhs_gmem_ref,
+          destination=rhs_smem_ref,
           barrier=dialect_barrier,
-          indices=[zero_i32] * len(test_case.shape_b),
-          slice_lengths=test_case.shape_b,
+          indices=[zero_i32] * len(rhs_shape),
+          slice_lengths=rhs_shape,
           collective=ir.ArrayAttr.get([]),
       )
@@ -2889,29 +2791,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
       parity, _ = tma_barrier.update_parities(parities)
       mgpu_dialect.wait(dialect_barrier, parity)

-      # SMEM -> Registers
-      a_operand = a_smem_ref
-      zero_index = arith.constant(ir.IndexType.get(), 0)
-      if test_case.load_a_in_registers:
-        a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
-        zero_vector_indices = [zero_index] * len(test_case.shape_a)
-        a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
-
       # Computation
       shape_result = ir.MemRefType(result_gmem_ref.type).shape
       result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
+      acc_elt_type = ir.F32Type.get()
+      acc_type = ir.VectorType.get(shape_result, acc_elt_type)
       zero_acc = arith.constant(
-          result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0)
+          result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0)
       )
-      accumulator = vector.splat(
-          ir.VectorType.get(shape_result, result_elt_type), zero_acc
-      )
+      accumulator = vector.splat(acc_type, zero_acc)
+
+      if transpose_lhs:
+        lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0))
+      if transpose_rhs:
+        rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0))
+
+      zero_index = arith.constant(ir.IndexType.get(), 0)
+      if load_a_in_registers:
+        # SMEM -> Registers
+        lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type)
+        zero_vector_indices = [zero_index] * len(lhs_shape)
+        lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices)
+      else:
+        lhs_operand = lhs_smem_ref

       result = mgpu_dialect.wgmma(
           accumulator,
-          a_operand,
-          b_smem_ref,
-          transpose_a=test_case.transpose_a,
-          transpose_b=test_case.transpose_b,
+          lhs_operand,
+          rhs_smem_ref,
       )

       nvvm.wgmma_commit_group_sync_aligned()
@@ -2929,38 +2836,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
       )
       nvvm.cp_async_bulk_wait_group(0)

-    abtype = jnp.bfloat16
+    operand_type = jnp.bfloat16
     acctype = jnp.float32
-    a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype)
-    b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype)
-    result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype)
+    lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type)
+    rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type)
+    result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype)
     kernel = mgpu.as_gpu_kernel(
         matmul,
         grid=(1, 1, 1),
         block=(128, 1, 1),
-        in_shape=(a_jax_shape, b_jax_shape),
+        in_shape=(lhs_jax_shape, rhs_jax_shape),
         out_shape=result_jax_shape,
         smem_scratch_shape=[
-            a_jax_shape,
-            b_jax_shape,
+            lhs_jax_shape,
+            rhs_jax_shape,
             result_jax_shape,
             core.TMABarrier(1),
         ],
         thread_semantics=mgpu.ThreadSemantics.Warpgroup,
     )

-    x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
-    y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
+    prng_key = jax.random.key(1234)
+    k0, k1 = jax.random.split(prng_key, 2)
+    x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type)
+    y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type)

     transpose = lambda x, t: x.T if t else x
     self.assertArraysAllClose(
         jax.jit(kernel)(x, y),
         np.matmul(
-            transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
-            transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
+            transpose(x, transpose_lhs),
+            transpose(y, transpose_rhs)
         ),
-        atol=1e-5,
-        rtol=1e-5,
+        atol=0,
+        rtol=0,
     )