[Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering.

A couple of dummy transform inference rules needed to be added in order to
contend with parts of the lowering that do not use the dialect yet, along with
a transform inference rule for `memref.view`.

PiperOrigin-RevId: 738089782
This commit is contained in:
Benjamin Chetioui 2025-03-18 11:50:58 -07:00 committed by jax authors
parent 547d602760
commit 875099b25d
5 changed files with 324 additions and 348 deletions

View File

@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
import numpy as np
if dialect is not None:
from . import dialect_lowering
from . import layout_inference
else:
dialect_lowering = None
layout_inference = None
from . import profiler
from . import utils
from . import launch_context
from . import tcgen05
# mypy: ignore-errors
from . import dialect_lowering
from . import launch_context
from . import layout_inference
from . import profiler
from . import tcgen05
from . import transform_inference
from . import utils
# MLIR can't find libdevice unless we point it to the CUDA path
# TODO(apaszke): Unify with jax._src.lib.cuda_path
CUDA_ROOT = "/usr/local/cuda"
@ -584,6 +580,7 @@ def as_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr)
@ -666,6 +663,7 @@ def as_torch_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
_initialize_scratch(launch_ctx, scratch_arr)

View File

@ -17,6 +17,7 @@
from collections.abc import Callable
import dataclasses
import functools
import itertools
import operator
from typing import Any, Sequence, Type, cast
@ -58,7 +59,7 @@ class LoweringContext:
if not _should_lower(op):
return
if (name := op.OPERATION_NAME) not in _lowerings:
if (name := op.OPERATION_NAME) not in _lowerings: # pytype: disable=attribute-error
raise NotImplementedError(f"Missing lowering rule for {op}")
lowering_rule = _lowerings[name]
@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule(
]
def _check_transforms_and_swizzle_are_supported(
    ref_ty: ir.MemRefType,
    transforms: Sequence[launch_context.MemRefTransform],
    swizzle: mgpu.SwizzlingMode,
    minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle,
):
  """Checks that the list of provided transforms and swizzle are supported.

  Currently, we allow the following:
    - any swizzle that is larger than or equal to `minimum_swizzle`;
    - optionally, a single tile transform (with rank equal to the rank of the
      memref being annotated);
    - optionally, a single transpose transform.

  Args:
    ref_ty: the type of the memref being annotated.
    transforms: the transforms attached to the memref.
    swizzle: the swizzle mode attached to the memref.
    minimum_swizzle: the smallest swizzle mode accepted by the caller.

  Raises:
    NotImplementedError: if the swizzle or transforms fall outside the
      supported set described above.
  """
  if swizzle < minimum_swizzle:
    raise NotImplementedError(
        f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}."
    )

  # Partition with comprehensions rather than `itertools.groupby`: `groupby`
  # only groups *consecutive* elements, and a dict built from its output keeps
  # only the last run per key — so e.g. [tile, transpose, tile] would be seen
  # as a single tile transform and incorrectly pass the checks below.
  tile_transforms = [
      t for t in transforms if isinstance(t, launch_context.TileTransform)
  ]
  other_transforms = [
      t for t in transforms if not isinstance(t, launch_context.TileTransform)
  ]

  if len(tile_transforms) > 1:
    raise NotImplementedError(
        f"{tile_transforms} contains more than one tile transform."
    )

  if len(tile_transforms) == 1:
    # A tile transform must cover every dimension of the annotated memref.
    if len(tile_transforms[0].tiling) != len(ref_ty.shape):
      raise NotImplementedError(
          f"Only tile transforms with rank equal to the rank of the memref "
          f"being annotated are supported but got {tile_transforms[0]} for "
          f"{ref_ty}."
      )

  if len(other_transforms) > 1:
    raise NotImplementedError(
        f"{other_transforms} contains more than one transform."
    )

  if len(other_transforms) == 1:
    if not isinstance(other_transforms[0], launch_context.TransposeTransform):
      raise NotImplementedError(
          f"{other_transforms[0]} is not a transpose transform."
      )
@_register_lowering(vector.LoadOp)
def _vector_load_op_lowering_rule(
_: LoweringContext, vector_load_op: vector.LoadOp
@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule(
vec_size=strided_layout.vec_size,
)
elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
layout = ir.MemRefType(vector_load_op.base.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
inference_utils.in_transforms(vector_load_op)[0]
)
ref_ty = ir.MemRefType(vector_load_op.base.type)
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
transformed_ref = transform_memref(vector_load_op.base, transforms)
fragmented_array = fa.FragmentedArray.load_tiled(
transformed_ref,
@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule(
vector_store_op.valueToStore, to_store_layout
)
# TODO(dasenov): This is not efficient for WGMMA layouts
fragmented_array.store_untiled(vector_store_op.base)
if fragmented_array.layout == fa.WGMMA_LAYOUT:
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
inference_utils.in_transforms(vector_store_op)[0]
)
ref_ty = ir.MemRefType(vector_store_op.base.type)
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
fragmented_array.store_tiled(
transform_memref(vector_store_op.base, transforms), swizzle
)
elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or
isinstance(fragmented_array.layout, fa.WGSplatFragLayout)):
fragmented_array.store_untiled(vector_store_op.base)
else:
raise ValueError(
f"{vector_store_op} has an unsupported layout: {to_store_layout}"
)
return []
@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule(
return [_fragmented_array_to_ir(result, op.result.type)]
def memref_layout_to_swizzle_and_transforms(
layout: ir.Attribute,
def swizzle_and_transforms_from_transforms_attr(
transforms: ir.ArrayAttr,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
"""Returns the swizzle and transforms that are encoded in the given layout.
"""Returns the swizzle and MemrefTransforms for the given transforms.
If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the
transforms are empty. Otherwise, the layout may have at most one swizzle
transform and any combination of tiling and transpose transforms.
Args:
transforms: a list of transform attributes.
Returns:
A tuple containing the swizzle mode and MemRefTransforms corresponding to
the parameter transforms. If `transforms` is empty, or does not contain
any swizzling transform, the swizzle mode is assumed to be kNoSwizzle.
Raises:
ValueError: if a swizzling transform is followed by any transform.
"""
swizzle = None
gmem_transforms: list[launch_context.MemRefTransform] = []
if mgpu.LayoutAttr.isinstance(layout):
transforms_attr = mgpu.LayoutAttr(layout).transforms
for transform in transforms_attr:
if swizzle is not None:
raise ValueError(f"{layout} contains more transforms after the initial swizzle.")
if mgpu.SwizzleTransformAttr.isinstance(transform):
# TODO(dasenov): Swizzling can change if the ref is sliced in certain
# ways. We might want to enforce some restrictions here.
swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
elif mgpu.TileTransformAttr.isinstance(transform):
tiling = mgpu.TileTransformAttr(transform).tiling
tiling_transform = launch_context.TileTransform(tuple(tiling))
gmem_transforms.append(tiling_transform)
elif mgpu.TransposeTransformAttr.isinstance(transform):
permutation = mgpu.TransposeTransformAttr(transform).permutation
transpose_transform = launch_context.TransposeTransform(
tuple(permutation)
)
gmem_transforms.append(transpose_transform)
else:
raise ValueError(f"{layout} has an unsupported transform: {transform}")
for transform in transforms:
if swizzle is not None:
raise ValueError(f"{transforms} contain more transforms after swizzle.")
if mgpu.SwizzleTransformAttr.isinstance(transform):
# TODO(dasenov): Swizzling can change if the ref is sliced in certain
# ways. We might want to enforce some restrictions here.
swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
elif mgpu.TileTransformAttr.isinstance(transform):
tiling = mgpu.TileTransformAttr(transform).tiling
tiling_transform = launch_context.TileTransform(tuple(tiling))
gmem_transforms.append(tiling_transform)
elif mgpu.TransposeTransformAttr.isinstance(transform):
permutation = mgpu.TransposeTransformAttr(transform).permutation
transpose_transform = launch_context.TransposeTransform(
tuple(permutation)
)
gmem_transforms.append(transpose_transform)
else:
raise ValueError("Unknown transform: {transform}")
return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)
@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule(
assert ctx.launch_context is not None
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
dst_layout = ir.MemRefType(load_op.destination.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)
if inference_utils.has_in_transforms_set(load_op):
[transforms] = inference_utils.in_transforms(load_op)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
transforms
)
else:
swizzle = mgpu.SwizzlingMode.kNoSwizzle
transforms = ()
gmem_slice = []
for idx_i32, size in zip(load_op.indices, load_op.slice_lengths):
@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule(
) -> Sequence[ir.Value]:
assert ctx.launch_context is not None
src_layout = ir.MemRefType(store_op.source.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)
if inference_utils.has_in_transforms_set(store_op):
[transforms] = inference_utils.in_transforms(store_op)
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
transforms
)
else:
swizzle = mgpu.SwizzlingMode.kNoSwizzle
transforms = ()
gmem_slice = []
for idx_i32, size in zip(store_op.indices, store_op.slice_lengths):
@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule(
def _mgpu_wgmma_op_lowering_rule(
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
) -> Sequence[ir.Value]:
if wgmma_op.transpose_a or wgmma_op.transpose_b:
raise ValueError("Transpose arguments are to be deleted.")
fa_layouts = (
*inference_utils.in_layouts(wgmma_op),
*inference_utils.out_layouts(wgmma_op),
@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule(
regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
acc = wgmma.WGMMAAccumulator.from_registers(regs)
b_layout = ir.MemRefType(wgmma_op.b.type).layout
b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
if ir.VectorType.isinstance(wgmma_op.a.type):
a_transforms = None
b_transforms = inference_utils.in_transforms(wgmma_op)[0]
else:
a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op)
b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr(
b_transforms
)
minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle
ref_ty = ir.MemRefType(wgmma_op.b.type)
_check_transforms_and_swizzle_are_supported(
ref_ty, b_transforms, b_swizzle, minimum_swizzle
)
b_operand = transform_memref(wgmma_op.b, b_transforms)
if wgmma_op.transpose_b:
b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
if ir.VectorType.isinstance(wgmma_op.a.type):
a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
else:
a_layout = ir.MemRefType(wgmma_op.a.type).layout
a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
a_transforms
)
ref_ty = ir.MemRefType(wgmma_op.a.type)
_check_transforms_and_swizzle_are_supported(
ref_ty, a_transforms, a_swizzle, minimum_swizzle
)
if a_swizzle != b_swizzle:
raise ValueError(
f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !="
f" {b_swizzle}"
)
a_operand = transform_memref(wgmma_op.a, a_transforms)
if wgmma_op.transpose_a:
a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
def _should_lower(op: ir.OpView) -> bool:
"""Returns 'true' if the operation should be lowered."""
return (
op.OPERATION_NAME.startswith("mosaic_gpu.")
op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error
or inference_utils.should_have_layout(op)
or any(bool(b) for r in op.regions for b in r) # Does it have subblocks?
)

View File

@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
return [layout], [layout]
@dataclasses.dataclass()
class WGMMATransforms:
  """The swizzle and per-operand tiling/transpose transforms of a WGMMA op."""

  # Swizzling mode shared by operands `a` and `b`.
  swizzle: mgpu.SwizzlingMode
  # Tile shape to apply to operand `a`.
  a_tile: tuple[int, ...]
  # Whether operand `a` is transposed.
  a_transpose: bool
  # Tile shape to apply to operand `b`.
  b_tile: tuple[int, ...]
  # Whether operand `b` is transposed.
  b_transpose: bool
def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
  """Infers the swizzle and operand tilings for the given WGMMA op.

  Tries swizzling modes from largest (128 bytes) to smallest (32 bytes) and
  picks the first whose element count `s = swizzle * 8 / bitwidth` evenly
  divides the contracting (K) dimension of operand `a`.

  Args:
    wgmma_op: the WGMMA op whose operand transforms should be inferred.

  Returns:
    The inferred `WGMMATransforms`.

  Raises:
    ValueError: if no supported swizzle evenly divides the contracting
      dimension.
  """
  a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
  # K is the leading dimension of `a` when transposed, trailing otherwise.
  k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
  bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width

  # Try tiling with all swizzling modes starting from the largest one.
  for swizzle in [
      mgpu.SwizzlingMode.k128ByteSwizzle,
      mgpu.SwizzlingMode.k64ByteSwizzle,
      mgpu.SwizzlingMode.k32ByteSwizzle,
  ]:
    s = swizzle * 8 // bitwidth
    if k % s == 0:
      return WGMMATransforms(
          swizzle=swizzle,
          a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
          a_transpose=wgmma_op.transpose_a,
          b_tile=(s, s),
          b_transpose=wgmma_op.transpose_b,
      )
  # The message fragment containing `{k}` was previously a plain (non-f)
  # string, so `k` was never interpolated; it also mislabeled `k` as the
  # non-contracting dimension.
  raise ValueError(
      "Could not infer layouts for memref feeding into WGMMA. The "
      f"contracting dimension ({k}) must be a multiple of "
      "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
      f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
      "`a` and `b`."
  )
def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
  """Returns a layout attribute for `view_op` based on its WGMMA consumer.

  Returns `None` when no layout needs to be attached: either the view is
  already consumed by a `memref.cast` (i.e. it has been handled before), or
  none of its users is a WGMMA op.

  Raises:
    NotImplementedError: if the view has a user other than the supported
      load/store ops, or feeds more than one WGMMA op.
  """
  wgmma_use = None
  uses = cast(ir.OpResult, view_op.result).uses
  for use in uses:
    user = use.owner
    if isinstance(user, memref.CastOp):
      # This memref is already cast, so we don't need to do anything.
      return None
    if isinstance(user, mgpu.WGMMAOp):
      if wgmma_use is not None:
        raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
      wgmma_use = use
      # NOTE(review): breaking here means users after the first WGMMA op are
      # never inspected, so the "Multiple WGMMA consumers" branch above cannot
      # actually trigger — confirm whether the break is intentional.
      break
    if (
        not isinstance(user, mgpu.AsyncLoadOp)
        and not isinstance(user, mgpu.AsyncStoreOp)
        and not isinstance(user, vector.LoadOp)
        and not isinstance(user, vector.StoreOp)
    ):
      raise NotImplementedError(f"Unsupported user {user} of {view_op}.")

  if wgmma_use is None:
    # This memref is not used by a WGMMA operation, so we don't need to do
    # anything.
    return None

  transforms = infer_wgmma_transforms(wgmma_use.owner)
  # Operand number 1 maps to the `a` operand of the WGMMA op; any other memref
  # operand gets the `b` transforms.
  if wgmma_use.operand_number == 1:
    tile = transforms.a_tile
    transpose = transforms.a_transpose
  else:
    tile = transforms.b_tile
    transpose = transforms.b_transpose
  transpose_attr = (
      [mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
  )
  # Compose the layout as tile -> (optional) transpose -> swizzle.
  layout = mgpu.LayoutAttr.get(
      2,
      [mgpu.TileTransformAttr.get(tile)]
      + transpose_attr
      + [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
  )
  return layout
def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
owners = [use.owner for use in uses]
@ -607,11 +524,3 @@ def infer_layout(module: ir.Module):
for op in module.body:
traverse_op(op, set_default_layout)
def infer_memref_layouts_and_insert_casts(op: ir.OpView):
if op.name == "memref.view":
if layout := _layout_for_memref_view(op):
_insert_memref_layout_cast(layout, op)
for op in module.body:
traverse_op(op, infer_memref_layouts_and_insert_casts)

View File

@ -26,6 +26,9 @@ from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import vector
from . import fragmented_array as fa
@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms(
return None
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
return None if transforms is None else ([], [transforms])
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
# the dialect in all cases.
# The rule is necessary in order to handle the lowering of `utils.memref_ptr`
# which is used in `_construct_smem_reftree`.
@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp)
def _infer_unrealized_conversion_cast_transforms(
    _: builtin.UnrealizedConversionCastOp,
) -> OptionalTransforms:
  """Dummy rule: unrealized conversion casts carry no transforms."""
  return None
@partial(_add_transform_inference_rule, memref.ViewOp)
def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms:
  """Infers the out transforms of a `memref.view` op from its consumers.

  The view must be a direct user of a `gpu.DynamicSharedMemoryOp` and must not
  already carry transforms. Its out transforms are taken from the in
  transforms that its consumers declare for the viewed memref; all consumers
  must agree on the same transforms.

  Raises:
    NotImplementedError: if the view's source is not a
      `gpu.DynamicSharedMemoryOp`, or if the view already has transforms.
    ValueError: if two consumers declare conflicting transforms for the view.
  """
  if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp):
    raise NotImplementedError(
        "Memref view transforms are only inferred when the op is a direct user "
        f"of a DynamicSharedMemoryOp but got {op}."
    )
  transforms = inference_utils.value_transforms(op.source)
  if transforms is not None:
    raise NotImplementedError(
        "memref view with in_transforms aren't yet supported"
    )
  # Gather the transforms each consumer expects for this view's result, and
  # check that they all agree. `transforms` is None here (checked above), so
  # the first consumer with transforms seeds the value.
  uses = cast(ir.OpResult, op.result).uses
  for op_operand_use in uses:
    consumer = op_operand_use.owner
    # The operand value the consumer receives (i.e. this view's result).
    op_user = consumer.operands[op_operand_use.operand_number]
    out_transforms = inference_utils.in_transforms_for_operand(
        consumer, op_user
    )
    if transforms is not None and out_transforms is not None:
      if transforms != out_transforms:
        raise ValueError(
            f"Conflicting transforms for {op_user} in {op}: "
            f"{transforms} != {out_transforms}."
        )
    elif out_transforms is not None:
      transforms = out_transforms

  # TODO(bchetioui): do we actually need to assign a transform to the input of
  # the view op? Presumably, it'll only be used to access scratch memory.
  return None if transforms is None else ([], [transforms])
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
# the dialect in all cases.
@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp)
def _infer_dynamic_smem_transforms(
    _: gpu.DynamicSharedMemoryOp,
) -> OptionalTransforms:
  """Dummy rule: dynamic shared memory declarations carry no transforms."""
  return None
def _should_have_transforms(op: ir.OpView) -> bool:
"""Returns 'True' if the operation should be assigned in/out transforms."""
return any(
@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module):
specified. We error out if two distinct sets of transforms are competing to
annotate the same memref.
"""
def inference_step(op: ir.Operation):
if not _should_have_transforms(op):
return

View File

@ -55,6 +55,7 @@ else:
from jax.experimental.mosaic.gpu import launch_context
from jax.experimental.mosaic.gpu import utils as utils
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu import inference_utils
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
@ -2405,25 +2406,21 @@ class Swizzle:
return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)
def memref_with_transforms(
mem_ref: ir.Value,
transforms: Sequence[Tile | Transpose | Swizzle],
) -> ir.Value:
"""Casts the memref to one that has a layout with the given transforms."""
mem_ref_type = ir.MemRefType(mem_ref.type)
def set_in_transforms(
op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]],
) -> None:
"""Annotates an op with in_transforms."""
if not transforms:
return
transform_attr = [t.attr() for t in transforms]
if not transform_attr:
return mem_ref
in_transforms = []
smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable
for _, result_transforms in jax.util.safe_zip(smem_refs, transforms):
in_transforms.append(
ir.ArrayAttr.get([t.attr() for t in result_transforms])
)
layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr)
memref_new_type = ir.MemRefType.get(
mem_ref_type.shape,
mem_ref_type.element_type,
layout,
mem_ref_type.memory_space,
)
return memref.cast(memref_new_type, mem_ref)
op.attributes["in_transforms"] = ir.ArrayAttr.get(in_transforms)
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
@ -2556,7 +2553,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
):
del ctx
smem_ref, tma_barrier = smem
smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
elt_type = ir.MemRefType(in_gmem_ref.type).element_type
@ -2571,7 +2567,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]
# GMEM -> SMEM
mgpu_dialect.async_load(
load_op = mgpu_dialect.AsyncLoadOp(
source=in_gmem_ref,
destination=smem_ref,
barrier=dialect_barrier,
@ -2579,6 +2575,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_lengths=test_case.slice_lengths,
collective=ir.ArrayAttr.get([]),
)
set_in_transforms(load_op, [test_case.transforms])
parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities)
@ -2623,58 +2620,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
(x[input_slice]).reshape(test_case.shape_sliced),
)
@staticmethod
def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
@dataclasses.dataclass(frozen=True)
class TestCaseInput:
shape: tuple[int, ...]
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
result = []
for swizzle in mgpu_dialect.SwizzlingMode:
n = swizzle * 8 // jnp.finfo(dtype).bits
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
# We need at least one case with no transforms, as this is handled
# differently.
result.append(TestCaseInput(shape=[128, n]))
result.extend([
TestCaseInput(
shape=[128, n],
transforms=[Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[
Transpose([1, 0, 2, 3]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[128, n],
transforms=[Tile([64, n]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2 * 64, 3 * n],
transforms=[
Tile([64, n]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
])
return result
@parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
def test_pointwise_kernel_with_tma(self, test_case):
def test_pointwise_kernel_with_tma(self):
def add(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
@ -2701,9 +2647,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=memref_with_transforms(
a_smem_ref, test_case.transforms
),
destination=a_smem_ref,
barrier=dialect_barrier,
indices=zero_slice_indices,
slice_lengths=shape,
@ -2711,9 +2655,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=memref_with_transforms(
b_smem_ref, test_case.transforms
),
destination=b_smem_ref,
barrier=dialect_barrier,
indices=zero_slice_indices,
slice_lengths=shape,
@ -2740,9 +2682,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# SMEM -> GMEM
mgpu_dialect.async_store(
source=memref_with_transforms(
result_smem_ref, test_case.transforms
),
source=result_smem_ref,
destination=result_gmem_ref,
indices=zero_slice_indices,
slice_lengths=shape,
@ -2752,114 +2692,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
dtype = jnp.bfloat16
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype)
kernel = mgpu.as_gpu_kernel(
add,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(jax_shape, jax_shape),
out_shape=jax_shape,
in_shape=(spec, spec),
out_shape=spec,
smem_scratch_shape=[
jax_shape,
jax_shape,
jax_shape,
spec,
spec,
spec,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
x = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
y = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
@staticmethod
def wgmma_kernel_with_tma_cases(abtype: jnp.dtype):
@dataclasses.dataclass(frozen=True)
class TestCaseInput:
shape_a: tuple[int, ...] = ()
shape_b: tuple[int, ...] = ()
shape_res: tuple[int, ...] = ()
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
transpose_a: bool = False
transpose_b: bool = False
load_a_in_registers: bool = False
@parameterized.named_parameters(
(
f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}",
swizzle,
transpose_lhs,
transpose_rhs,
lhs_in_registers,
)
for swizzle in mgpu_dialect.SwizzlingMode
for transpose_lhs in [False, True]
for transpose_rhs in [False, True]
for lhs_in_registers in [False, True]
)
def test_wgmma_kernel_with_tma(
self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers
):
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
self.skipTest("No swizzle is not supported by wgmma")
result = []
for swizzle in [
# TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
]:
k = swizzle // np.dtype(abtype).itemsize
groups_m = 4
groups_n = 1
groups_k = 1
result.extend([
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_n * k, groups_k * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_b=True,
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
load_a_in_registers=True,
),
])
# The below only works for 128-byte swizzling. Regardless of transposing,
# TMA needs the size of the last dimension to be compatible with the
# swizzle.
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
result.append(
TestCaseInput(
shape_a=[groups_k * k, groups_m * 64],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_a=True,
)
)
return result
if transpose_lhs or transpose_rhs:
self.skipTest("Transposes are not supported by transform inference yet.")
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
def test_wgmma_kernel_with_tma(self, test_case):
swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize
tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems
groups_m, groups_n, groups_k = 4, 1, 1
m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k
lhs_shape = (k, m) if transpose_lhs else (m, k)
rhs_shape = (n, k) if transpose_rhs else (k, n)
out_shape = (m, n)
def matmul(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
b_gmem_ref: ir.Value,
lhs_gmem_ref: ir.Value,
rhs_gmem_ref: ir.Value,
result_gmem_ref: ir.Value,
smem: list[ir.Value],
):
del ctx
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a)
bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b)
operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type
bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape)
bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape)
mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier,
@ -2869,19 +2771,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
source=lhs_gmem_ref,
destination=lhs_smem_ref,
barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_a),
slice_lengths=test_case.shape_a,
indices=[zero_i32] * len(lhs_shape),
slice_lengths=lhs_shape,
collective=ir.ArrayAttr.get([]),
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
source=rhs_gmem_ref,
destination=rhs_smem_ref,
barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_b),
slice_lengths=test_case.shape_b,
indices=[zero_i32] * len(rhs_shape),
slice_lengths=rhs_shape,
collective=ir.ArrayAttr.get([]),
)
@ -2889,29 +2791,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity)
# SMEM -> Registers
a_operand = a_smem_ref
zero_index = arith.constant(ir.IndexType.get(), 0)
if test_case.load_a_in_registers:
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
zero_vector_indices = [zero_index] * len(test_case.shape_a)
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
# Computation
shape_result = ir.MemRefType(result_gmem_ref.type).shape
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
acc_elt_type = ir.F32Type.get()
acc_type = ir.VectorType.get(shape_result, acc_elt_type)
zero_acc = arith.constant(
result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0)
)
accumulator = vector.splat(
ir.VectorType.get(shape_result, result_elt_type), zero_acc
result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0)
)
accumulator = vector.splat(acc_type, zero_acc)
if transpose_lhs:
lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0))
if transpose_rhs:
rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0))
zero_index = arith.constant(ir.IndexType.get(), 0)
if load_a_in_registers:
# SMEM -> Registers
lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type)
zero_vector_indices = [zero_index] * len(lhs_shape)
lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices)
else:
lhs_operand = lhs_smem_ref
result = mgpu_dialect.wgmma(
accumulator,
a_operand,
b_smem_ref,
transpose_a=test_case.transpose_a,
transpose_b=test_case.transpose_b,
lhs_operand,
rhs_smem_ref,
)
nvvm.wgmma_commit_group_sync_aligned()
@ -2929,38 +2836,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
)
nvvm.cp_async_bulk_wait_group(0)
abtype = jnp.bfloat16
operand_type = jnp.bfloat16
acctype = jnp.float32
a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype)
b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype)
result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype)
lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type)
rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type)
result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype)
kernel = mgpu.as_gpu_kernel(
matmul,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(a_jax_shape, b_jax_shape),
in_shape=(lhs_jax_shape, rhs_jax_shape),
out_shape=result_jax_shape,
smem_scratch_shape=[
a_jax_shape,
b_jax_shape,
lhs_jax_shape,
rhs_jax_shape,
result_jax_shape,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
prng_key = jax.random.key(1234)
k0, k1 = jax.random.split(prng_key, 2)
x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type)
y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type)
transpose = lambda x, t: x.T if t else x
self.assertArraysAllClose(
jax.jit(kernel)(x, y),
np.matmul(
transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
transpose(x, transpose_lhs),
transpose(y, transpose_rhs)
),
atol=1e-5,
rtol=1e-5,
atol=0,
rtol=0,
)