mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering.
A couple of dummy transform inference rules needed to be added in order to contend with parts of the lowering that do not use the dialect yet, along with a transform inference rule for `memref.view`. PiperOrigin-RevId: 738089782
This commit is contained in:
parent
547d602760
commit
875099b25d
@ -41,20 +41,16 @@ from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
import numpy as np
|
||||
|
||||
if dialect is not None:
|
||||
from . import dialect_lowering
|
||||
from . import layout_inference
|
||||
else:
|
||||
dialect_lowering = None
|
||||
layout_inference = None
|
||||
|
||||
from . import profiler
|
||||
from . import utils
|
||||
from . import launch_context
|
||||
from . import tcgen05
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
from . import dialect_lowering
|
||||
from . import launch_context
|
||||
from . import layout_inference
|
||||
from . import profiler
|
||||
from . import tcgen05
|
||||
from . import transform_inference
|
||||
from . import utils
|
||||
|
||||
# MLIR can't find libdevice unless we point it to the CUDA path
|
||||
# TODO(apaszke): Unify with jax._src.lib.cuda_path
|
||||
CUDA_ROOT = "/usr/local/cuda"
|
||||
@ -584,6 +580,7 @@ def as_gpu_kernel(
|
||||
# Run Python lowering passes. The remaining passes will be run in C++ in
|
||||
# jax/jaxlib/mosaic/gpu/custom_call.cc
|
||||
layout_inference.infer_layout(module) # pytype: disable=attribute-error
|
||||
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
|
||||
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
|
||||
|
||||
_initialize_scratch(launch_ctx, scratch_arr)
|
||||
@ -666,6 +663,7 @@ def as_torch_gpu_kernel(
|
||||
# Run Python lowering passes. The remaining passes will be run in C++ in
|
||||
# jax/jaxlib/mosaic/gpu/custom_call.cc
|
||||
layout_inference.infer_layout(module) # pytype: disable=attribute-error
|
||||
transform_inference.infer_transforms(module) # pytype: disable=attribute-error
|
||||
dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
|
||||
|
||||
_initialize_scratch(launch_ctx, scratch_arr)
|
||||
|
@ -17,6 +17,7 @@
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Any, Sequence, Type, cast
|
||||
|
||||
@ -58,7 +59,7 @@ class LoweringContext:
|
||||
if not _should_lower(op):
|
||||
return
|
||||
|
||||
if (name := op.OPERATION_NAME) not in _lowerings:
|
||||
if (name := op.OPERATION_NAME) not in _lowerings: # pytype: disable=attribute-error
|
||||
raise NotImplementedError(f"Missing lowering rule for {op}")
|
||||
|
||||
lowering_rule = _lowerings[name]
|
||||
@ -227,6 +228,60 @@ def _arith_constant_op_lowering_rule(
|
||||
]
|
||||
|
||||
|
||||
def _check_transforms_and_swizzle_are_supported(
    ref_ty: ir.MemRefType,
    transforms: Sequence[launch_context.MemRefTransform],
    swizzle: mgpu.SwizzlingMode,
    minimum_swizzle: mgpu.SwizzlingMode = mgpu.SwizzlingMode.kNoSwizzle,
):
  """Checks that the list of provided transforms and swizzle are supported.

  Currently, we allow the following:
    - any swizzle that is larger than or equal to `minimum_swizzle`;
    - optionally, a single tile transform (with rank equal to the rank of the
      memref being annotated);
    - optionally, a single transpose transform.

  Args:
    ref_ty: the type of the memref being annotated; only its rank
      (`len(ref_ty.shape)`) is inspected.
    transforms: the transforms to validate, in any order.
    swizzle: the swizzle to validate.
    minimum_swizzle: the smallest swizzle that is accepted.

  Raises:
    NotImplementedError: if the swizzle is too small, if there is more than
      one tile transform or more than one non-tile transform, if the tile
      transform's rank differs from the memref's rank, or if the non-tile
      transform is not a transpose.
  """
  if swizzle < minimum_swizzle:
    raise NotImplementedError(
        f"Unsupported swizzle {swizzle} smaller than {minimum_swizzle}."
    )

  # Partition over the *whole* sequence. `itertools.groupby` only merges
  # consecutive items, so collecting its groups into a dict keeps just the last
  # run for each key and would let e.g. [tile, transpose, tile] slip through
  # the "single tile transform" check below.
  tile_transforms = [
      t for t in transforms if isinstance(t, launch_context.TileTransform)
  ]
  other_transforms = [
      t for t in transforms if not isinstance(t, launch_context.TileTransform)
  ]

  if len(tile_transforms) > 1:
    raise NotImplementedError(
        f"{tile_transforms} contains more than one tile transform."
    )

  if len(tile_transforms) == 1:
    if len(tile_transforms[0].tiling) != len(ref_ty.shape):
      raise NotImplementedError(
          f"Only tile transforms with rank equal to the rank of the memref "
          f"being annotated are supported but got {tile_transforms[0]} for "
          f"{ref_ty}."
      )

  if len(other_transforms) > 1:
    raise NotImplementedError(
        f"{other_transforms} contains more than one transform."
    )

  if len(other_transforms) == 1:
    if not isinstance(other_transforms[0], launch_context.TransposeTransform):
      raise NotImplementedError(
          f"{other_transforms[0]} is not a transpose transform."
      )
|
||||
|
||||
|
||||
@_register_lowering(vector.LoadOp)
|
||||
def _vector_load_op_lowering_rule(
|
||||
_: LoweringContext, vector_load_op: vector.LoadOp
|
||||
@ -260,8 +315,11 @@ def _vector_load_op_lowering_rule(
|
||||
vec_size=strided_layout.vec_size,
|
||||
)
|
||||
elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
|
||||
layout = ir.MemRefType(vector_load_op.base.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
inference_utils.in_transforms(vector_load_op)[0]
|
||||
)
|
||||
ref_ty = ir.MemRefType(vector_load_op.base.type)
|
||||
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
|
||||
transformed_ref = transform_memref(vector_load_op.base, transforms)
|
||||
fragmented_array = fa.FragmentedArray.load_tiled(
|
||||
transformed_ref,
|
||||
@ -297,8 +355,22 @@ def _vector_store_op_lowering_rule(
|
||||
vector_store_op.valueToStore, to_store_layout
|
||||
)
|
||||
|
||||
# TODO(dasenov): This is not efficient for WGMMA layouts
|
||||
fragmented_array.store_untiled(vector_store_op.base)
|
||||
if fragmented_array.layout == fa.WGMMA_LAYOUT:
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
inference_utils.in_transforms(vector_store_op)[0]
|
||||
)
|
||||
ref_ty = ir.MemRefType(vector_store_op.base.type)
|
||||
_check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle)
|
||||
fragmented_array.store_tiled(
|
||||
transform_memref(vector_store_op.base, transforms), swizzle
|
||||
)
|
||||
elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or
|
||||
isinstance(fragmented_array.layout, fa.WGSplatFragLayout)):
|
||||
fragmented_array.store_untiled(vector_store_op.base)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{vector_store_op} has an unsupported layout: {to_store_layout}"
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
@ -362,39 +434,43 @@ def _vector_reduction_op_lowering_rule(
|
||||
return [_fragmented_array_to_ir(result, op.result.type)]
|
||||
|
||||
|
||||
def memref_layout_to_swizzle_and_transforms(
    layout: ir.Attribute,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
  """Returns the swizzle and transforms that are encoded in the given layout.

  If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the
  transforms are empty. Otherwise, the layout may have at most one swizzle
  transform and any combination of tiling and transpose transforms.

  Raises:
    ValueError: if a swizzling transform is followed by any transform, or if
      the layout contains a transform of an unsupported kind.
  """
  swizzle = None
  gmem_transforms: list[launch_context.MemRefTransform] = []

  if mgpu.LayoutAttr.isinstance(layout):
    transforms_attr = mgpu.LayoutAttr(layout).transforms
    for transform in transforms_attr:
      # The swizzle, if any, must be the last transform of the layout.
      if swizzle is not None:
        raise ValueError(f"{layout} contains more transforms after the initial swizzle.")
      if mgpu.SwizzleTransformAttr.isinstance(transform):
        # TODO(dasenov): Swizzling can change if the ref is sliced in certain
        # ways. We might want to enforce some restrictions here.
        swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
      elif mgpu.TileTransformAttr.isinstance(transform):
        tiling = mgpu.TileTransformAttr(transform).tiling
        tiling_transform = launch_context.TileTransform(tuple(tiling))
        gmem_transforms.append(tiling_transform)
      elif mgpu.TransposeTransformAttr.isinstance(transform):
        permutation = mgpu.TransposeTransformAttr(transform).permutation
        transpose_transform = launch_context.TransposeTransform(
            tuple(permutation)
        )
        gmem_transforms.append(transpose_transform)
      else:
        raise ValueError(f"{layout} has an unsupported transform: {transform}")

  return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)


def swizzle_and_transforms_from_transforms_attr(
    transforms: ir.ArrayAttr,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
  """Returns the swizzle and MemrefTransforms for the given transforms.

  Args:
    transforms: a list of transform attributes.

  Returns:
    A tuple containing the swizzle mode and MemRefTransforms corresponding to
    the parameter transforms. If `transforms` is empty, or does not contain
    any swizzling transform, the swizzle mode is assumed to be kNoSwizzle.

  Raises:
    ValueError: if a swizzling transform is followed by any transform, or if
      a transform is neither a swizzle, a tile, nor a transpose transform.
  """
  swizzle = None
  gmem_transforms: list[launch_context.MemRefTransform] = []

  for transform in transforms:
    # The swizzle, if any, must be the last transform of the list.
    if swizzle is not None:
      raise ValueError(f"{transforms} contain more transforms after swizzle.")
    if mgpu.SwizzleTransformAttr.isinstance(transform):
      # TODO(dasenov): Swizzling can change if the ref is sliced in certain
      # ways. We might want to enforce some restrictions here.
      swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
    elif mgpu.TileTransformAttr.isinstance(transform):
      tiling = mgpu.TileTransformAttr(transform).tiling
      tiling_transform = launch_context.TileTransform(tuple(tiling))
      gmem_transforms.append(tiling_transform)
    elif mgpu.TransposeTransformAttr.isinstance(transform):
      permutation = mgpu.TransposeTransformAttr(transform).permutation
      transpose_transform = launch_context.TransposeTransform(
          tuple(permutation)
      )
      gmem_transforms.append(transpose_transform)
    else:
      # This literal was previously missing its `f` prefix, so the offending
      # transform was never interpolated into the message.
      raise ValueError(f"Unknown transform: {transform}")

  return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)
|
||||
|
||||
@ -434,8 +510,14 @@ def _mgpu_async_load_op_lowering_rule(
|
||||
assert ctx.launch_context is not None
|
||||
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
|
||||
|
||||
dst_layout = ir.MemRefType(load_op.destination.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)
|
||||
if inference_utils.has_in_transforms_set(load_op):
|
||||
[transforms] = inference_utils.in_transforms(load_op)
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
transforms
|
||||
)
|
||||
else:
|
||||
swizzle = mgpu.SwizzlingMode.kNoSwizzle
|
||||
transforms = ()
|
||||
|
||||
gmem_slice = []
|
||||
for idx_i32, size in zip(load_op.indices, load_op.slice_lengths):
|
||||
@ -464,8 +546,14 @@ def _mgpu_async_store_op_lowering_rule(
|
||||
) -> Sequence[ir.Value]:
|
||||
assert ctx.launch_context is not None
|
||||
|
||||
src_layout = ir.MemRefType(store_op.source.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)
|
||||
if inference_utils.has_in_transforms_set(store_op):
|
||||
[transforms] = inference_utils.in_transforms(store_op)
|
||||
swizzle, transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
transforms
|
||||
)
|
||||
else:
|
||||
swizzle = mgpu.SwizzlingMode.kNoSwizzle
|
||||
transforms = ()
|
||||
|
||||
gmem_slice = []
|
||||
for idx_i32, size in zip(store_op.indices, store_op.slice_lengths):
|
||||
@ -673,6 +761,9 @@ def _bitcast_op_lowering_rule(
|
||||
def _mgpu_wgmma_op_lowering_rule(
|
||||
_: LoweringContext, wgmma_op: mgpu.WGMMAOp
|
||||
) -> Sequence[ir.Value]:
|
||||
if wgmma_op.transpose_a or wgmma_op.transpose_b:
|
||||
raise ValueError("Transpose arguments are to be deleted.")
|
||||
|
||||
fa_layouts = (
|
||||
*inference_utils.in_layouts(wgmma_op),
|
||||
*inference_utils.out_layouts(wgmma_op),
|
||||
@ -691,25 +782,38 @@ def _mgpu_wgmma_op_lowering_rule(
|
||||
regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
|
||||
acc = wgmma.WGMMAAccumulator.from_registers(regs)
|
||||
|
||||
b_layout = ir.MemRefType(wgmma_op.b.type).layout
|
||||
b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
|
||||
if ir.VectorType.isinstance(wgmma_op.a.type):
|
||||
a_transforms = None
|
||||
b_transforms = inference_utils.in_transforms(wgmma_op)[0]
|
||||
else:
|
||||
a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op)
|
||||
|
||||
b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
b_transforms
|
||||
)
|
||||
minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle
|
||||
ref_ty = ir.MemRefType(wgmma_op.b.type)
|
||||
_check_transforms_and_swizzle_are_supported(
|
||||
ref_ty, b_transforms, b_swizzle, minimum_swizzle
|
||||
)
|
||||
b_operand = transform_memref(wgmma_op.b, b_transforms)
|
||||
if wgmma_op.transpose_b:
|
||||
b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))
|
||||
|
||||
if ir.VectorType.isinstance(wgmma_op.a.type):
|
||||
a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
|
||||
else:
|
||||
a_layout = ir.MemRefType(wgmma_op.a.type).layout
|
||||
a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
|
||||
a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr(
|
||||
a_transforms
|
||||
)
|
||||
ref_ty = ir.MemRefType(wgmma_op.a.type)
|
||||
_check_transforms_and_swizzle_are_supported(
|
||||
ref_ty, a_transforms, a_swizzle, minimum_swizzle
|
||||
)
|
||||
if a_swizzle != b_swizzle:
|
||||
raise ValueError(
|
||||
f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !="
|
||||
f" {b_swizzle}"
|
||||
)
|
||||
a_operand = transform_memref(wgmma_op.a, a_transforms)
|
||||
if wgmma_op.transpose_a:
|
||||
a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))
|
||||
|
||||
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
|
||||
|
||||
@ -902,7 +1006,7 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
|
||||
def _should_lower(op: ir.OpView) -> bool:
  """Tells whether `op` has to be processed by the dialect lowering.

  An op qualifies if it belongs to the `mosaic_gpu.` dialect, if
  `inference_utils.should_have_layout` accepts it, or if any of its regions
  contains a non-empty block.
  """
  if op.OPERATION_NAME.startswith("mosaic_gpu."):  # pytype: disable=attribute-error
    return True
  if needs_layout := inference_utils.should_have_layout(op):
    return needs_layout
  # Does it have subblocks?
  return any(bool(block) for region in op.regions for block in region)
|
||||
|
@ -383,89 +383,6 @@ def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
|
||||
|
||||
return [layout], [layout]
|
||||
|
||||
@dataclasses.dataclass
class WGMMATransforms:
  """Swizzle, tiling and transposition chosen for the operands of a WGMMA."""

  # Swizzling mode shared by both operands.
  swizzle: mgpu.SwizzlingMode
  # Tile shape and transposition flag for operand `a`.
  a_tile: tuple[int, ...]
  a_transpose: bool
  # Tile shape and transposition flag for operand `b`.
  b_tile: tuple[int, ...]
  b_transpose: bool
|
||||
|
||||
|
||||
def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
  """Infers the swizzle and tilings for the operands of `wgmma_op`.

  The largest swizzle among 128, 64 and 32 bytes is picked such that the
  corresponding number of elements, `s = swizzle * 8 // bitwidth`, divides the
  contracting dimension `k` of operand `a`.

  Args:
    wgmma_op: the WGMMA operation whose operands need transforms.

  Returns:
    The selected swizzle together with the tile shapes and transposition flags
    of operands `a` and `b`.

  Raises:
    ValueError: if no supported swizzle yields a tile size that divides `k`.
  """
  a_shape = cast(ir.ShapedType, wgmma_op.a.type).shape
  # `a` is (m, k), or (k, m) when it is transposed.
  k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
  bitwidth = cast(ir.ShapedType, wgmma_op.a.type).element_type.width

  # Try tiling with all swizzling modes starting from the largest one.
  for swizzle in [
      mgpu.SwizzlingMode.k128ByteSwizzle,
      mgpu.SwizzlingMode.k64ByteSwizzle,
      mgpu.SwizzlingMode.k32ByteSwizzle,
  ]:
    s = swizzle * 8 // bitwidth
    if k % s == 0:
      return WGMMATransforms(
          swizzle=swizzle,
          a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
          a_transpose=wgmma_op.transpose_a,
          b_tile=(s, s),
          b_transpose=wgmma_op.transpose_b,
      )
  # The second fragment used to lack its `f` prefix (so `{k}` was printed
  # verbatim) and described `k` as the non-contracting dimension.
  raise ValueError(
      "Could not infer layouts for memref feeding into WGMMA. The "
      f"contracting dimension ({k}) must be a multiple of "
      "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
      f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
      "`a` and `b`."
  )
|
||||
|
||||
def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
  """Computes the layout to attach to the result of `view_op`, if any.

  Returns `None` when the view is already consumed by a `memref.cast`, or when
  no WGMMA operation is found among its users. Otherwise returns a rank-2
  `LayoutAttr` made of a tile transform, an optional transpose transform and a
  swizzle transform inferred from the WGMMA consumer.

  Raises:
    NotImplementedError: if a user visited before the WGMMA consumer is not a
      cast, WGMMA, async load/store or vector load/store operation.
  """
  wgmma_use = None
  for use in cast(ir.OpResult, view_op.result).uses:
    user = use.owner
    if isinstance(user, memref.CastOp):
      # This memref is already cast, so we don't need to do anything.
      return None
    if isinstance(user, mgpu.WGMMAOp):
      # NOTE(review): because of the `break` below this check can never fire;
      # it is kept to preserve the original code path.
      if wgmma_use is not None:
        raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
      wgmma_use = use
      break
    supported = isinstance(
        user,
        (mgpu.AsyncLoadOp, mgpu.AsyncStoreOp, vector.LoadOp, vector.StoreOp),
    )
    if not supported:
      raise NotImplementedError(f"Unsupported user {user} of {view_op}.")

  if wgmma_use is None:
    # This memref is not used by a WGMMA operation, so we don't need to do
    # anything.
    return None

  transforms = infer_wgmma_transforms(wgmma_use.owner)
  # Operand number 1 selects the `a` transforms; anything else selects `b`.
  feeds_a = wgmma_use.operand_number == 1
  tile = transforms.a_tile if feeds_a else transforms.b_tile
  transpose = transforms.a_transpose if feeds_a else transforms.b_transpose

  layout_transforms = [mgpu.TileTransformAttr.get(tile)]
  if transpose:
    layout_transforms.append(mgpu.TransposeTransformAttr.get([1, 0, 2, 3]))
  layout_transforms.append(mgpu.SwizzleTransformAttr.get(transforms.swizzle))

  return mgpu.LayoutAttr.get(2, layout_transforms)
|
||||
|
||||
|
||||
def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
|
||||
owners = [use.owner for use in uses]
|
||||
@ -607,11 +524,3 @@ def infer_layout(module: ir.Module):
|
||||
|
||||
for op in module.body:
|
||||
traverse_op(op, set_default_layout)
|
||||
|
||||
def infer_memref_layouts_and_insert_casts(op: ir.OpView):
|
||||
if op.name == "memref.view":
|
||||
if layout := _layout_for_memref_view(op):
|
||||
_insert_memref_layout_cast(layout, op)
|
||||
|
||||
for op in module.body:
|
||||
traverse_op(op, infer_memref_layouts_and_insert_casts)
|
||||
|
@ -26,6 +26,9 @@ from typing import cast
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import builtin
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
|
||||
from . import fragmented_array as fa
|
||||
@ -169,7 +172,6 @@ def _infer_vector_load_store_transforms(
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
|
||||
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
|
||||
|
||||
@ -196,6 +198,60 @@ def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
|
||||
return None if transforms is None else ([], [transforms])
|
||||
|
||||
|
||||
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
|
||||
# the dialect in all cases.
|
||||
# The rule is necessary in order to handle the lowering of `utils.memref_ptr`
|
||||
# which is used in `_construct_smem_reftree`.
|
||||
@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp)
def _infer_unrealized_conversion_cast_transforms(
    _: builtin.UnrealizedConversionCastOp,
) -> OptionalTransforms:
  """Placeholder rule: never infers transforms for the cast (returns None)."""
  return None
|
||||
|
||||
|
||||
@partial(_add_transform_inference_rule, memref.ViewOp)
def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms:
  """Infers the out transforms of a `memref.view` from its consumers.

  The in transforms that the consumers of the view's result declare for it are
  collected; they must all agree.

  Returns:
    `None` if no consumer declares transforms for the view, otherwise a pair
    of an empty list of in transforms and a single-element list of out
    transforms.

  Raises:
    NotImplementedError: if the source of the view is not produced by a
      `gpu.DynamicSharedMemoryOp`, or if that source already has transforms.
    ValueError: if two consumers require different transforms.
  """
  if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp):
    raise NotImplementedError(
        "Memref view transforms are only inferred when the op is a direct user "
        f"of a DynamicSharedMemoryOp but got {op}."
    )
  if inference_utils.value_transforms(op.source) is not None:
    raise NotImplementedError(
        "memref view with in_transforms aren't yet supported"
    )

  transforms = None
  for use in cast(ir.OpResult, op.result).uses:
    consumer = use.owner
    op_user = consumer.operands[use.operand_number]
    out_transforms = inference_utils.in_transforms_for_operand(
        consumer, op_user
    )
    if out_transforms is None:
      # This consumer has no opinion on the transforms.
      continue
    if transforms is None:
      transforms = out_transforms
    elif transforms != out_transforms:
      raise ValueError(
          f"Conflicting transforms for {op_user} in {op}: "
          f"{transforms} != {out_transforms}."
      )

  # TODO(bchetioui): do we actually need to assign a transform to the input of
  # the view op? Presumably, it'll only be used to access scratch memory.
  if transforms is None:
    return None
  return [], [transforms]
|
||||
|
||||
|
||||
# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use
|
||||
# the dialect in all cases.
|
||||
@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp)
def _infer_dynamic_smem_transforms(
    _: gpu.DynamicSharedMemoryOp,
) -> OptionalTransforms:
  """Placeholder rule: never infers transforms for the op (returns None)."""
  return None
|
||||
|
||||
|
||||
def _should_have_transforms(op: ir.OpView) -> bool:
|
||||
"""Returns 'True' if the operation should be assigned in/out transforms."""
|
||||
return any(
|
||||
@ -218,7 +274,6 @@ def infer_transforms(module: ir.Module):
|
||||
specified. We error out if two distinct sets of transforms are competing to
|
||||
annotate the same memref.
|
||||
"""
|
||||
|
||||
def inference_step(op: ir.Operation):
|
||||
if not _should_have_transforms(op):
|
||||
return
|
||||
|
@ -55,6 +55,7 @@ else:
|
||||
from jax.experimental.mosaic.gpu import launch_context
|
||||
from jax.experimental.mosaic.gpu import utils as utils
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.mosaic.gpu import inference_utils
|
||||
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import llvm
|
||||
@ -2405,25 +2406,21 @@ class Swizzle:
|
||||
return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)
|
||||
|
||||
|
||||
def memref_with_transforms(
    mem_ref: ir.Value,
    transforms: Sequence[Tile | Transpose | Swizzle],
) -> ir.Value:
  """Casts the memref to one that has a layout with the given transforms.

  The memref is returned unchanged when `transforms` is empty.
  """
  ref_ty = ir.MemRefType(mem_ref.type)

  attrs = [t.attr() for t in transforms]
  if not attrs:
    return mem_ref

  # Same shape, element type and memory space; only the layout changes.
  new_ref_ty = ir.MemRefType.get(
      ref_ty.shape,
      ref_ty.element_type,
      mgpu_dialect.LayoutAttr.get(ref_ty.rank, attrs),
      ref_ty.memory_space,
  )
  return memref.cast(new_ref_ty, mem_ref)


def set_in_transforms(
    op: ir.OpView, transforms: Sequence[Sequence[Tile | Transpose | Swizzle]],
) -> None:
  """Annotates an op with in_transforms.

  One entry of `transforms` is consumed for each operand of `op` accepted by
  `inference_utils.is_transformable_smem_memref`, in operand order. Nothing
  is set when `transforms` is empty.
  """
  if not transforms:
    return

  smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands)  # pylint: disable=undefined-variable
  op.attributes["in_transforms"] = ir.ArrayAttr.get([
      ir.ArrayAttr.get([t.attr() for t in ref_transforms])
      for _, ref_transforms in jax.util.safe_zip(smem_refs, transforms)
  ])
|
||||
|
||||
|
||||
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
@ -2556,7 +2553,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
):
|
||||
del ctx
|
||||
smem_ref, tma_barrier = smem
|
||||
smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
|
||||
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
|
||||
|
||||
elt_type = ir.MemRefType(in_gmem_ref.type).element_type
|
||||
@ -2571,7 +2567,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]
|
||||
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
load_op = mgpu_dialect.AsyncLoadOp(
|
||||
source=in_gmem_ref,
|
||||
destination=smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
@ -2579,6 +2575,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_lengths=test_case.slice_lengths,
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
set_in_transforms(load_op, [test_case.transforms])
|
||||
|
||||
parities = memref.load(tma_barrier.phases, [])
|
||||
parity, _ = tma_barrier.update_parities(parities)
|
||||
@ -2623,58 +2620,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
(x[input_slice]).reshape(test_case.shape_sliced),
|
||||
)
|
||||
|
||||
@staticmethod
def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
  """Builds the shape/transform combinations for the pointwise TMA test.

  For every swizzling mode, `n = swizzle * 8 // bits(dtype)` is used as the
  size of the last dimension, and cases combining swizzle, transpose and tile
  transforms are produced.
  """

  @dataclasses.dataclass(frozen=True)
  class TestCaseInput:
    shape: tuple[int, ...]
    transforms: tuple[Tile | Transpose | Swizzle, ...] = ()

  def cases_for(swizzle):
    n = swizzle * 8 // jnp.finfo(dtype).bits
    if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
      # We need at least one case with no transforms, as this is handled
      # differently.
      yield TestCaseInput(shape=[128, n])
    yield TestCaseInput(
        shape=[128, n],
        transforms=[Swizzle(swizzle)],
    )
    yield TestCaseInput(
        shape=[2, 3, 64, n],
        transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
    )
    yield TestCaseInput(
        shape=[2, 3, 64, n],
        transforms=[
            Transpose([1, 0, 2, 3]),
            Transpose([1, 0, 2, 3]),
            Swizzle(swizzle),
        ],
    )
    yield TestCaseInput(
        shape=[2, 3, 64, n],
        transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
    )
    yield TestCaseInput(
        shape=[128, n],
        transforms=[Tile([64, n]), Swizzle(swizzle)],
    )
    yield TestCaseInput(
        shape=[2 * 64, 3 * n],
        transforms=[
            Tile([64, n]),
            Transpose([1, 0, 2, 3]),
            Swizzle(swizzle),
        ],
    )

  return [
      case
      for swizzle in mgpu_dialect.SwizzlingMode
      for case in cases_for(swizzle)
  ]
|
||||
|
||||
@parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
|
||||
def test_pointwise_kernel_with_tma(self, test_case):
|
||||
def test_pointwise_kernel_with_tma(self):
|
||||
def add(
|
||||
ctx: launch_context.LaunchContext,
|
||||
a_gmem_ref: ir.Value,
|
||||
@ -2701,9 +2647,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
source=a_gmem_ref,
|
||||
destination=memref_with_transforms(
|
||||
a_smem_ref, test_case.transforms
|
||||
),
|
||||
destination=a_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=zero_slice_indices,
|
||||
slice_lengths=shape,
|
||||
@ -2711,9 +2655,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
source=b_gmem_ref,
|
||||
destination=memref_with_transforms(
|
||||
b_smem_ref, test_case.transforms
|
||||
),
|
||||
destination=b_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=zero_slice_indices,
|
||||
slice_lengths=shape,
|
||||
@ -2740,9 +2682,7 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
|
||||
# SMEM -> GMEM
|
||||
mgpu_dialect.async_store(
|
||||
source=memref_with_transforms(
|
||||
result_smem_ref, test_case.transforms
|
||||
),
|
||||
source=result_smem_ref,
|
||||
destination=result_gmem_ref,
|
||||
indices=zero_slice_indices,
|
||||
slice_lengths=shape,
|
||||
@ -2752,114 +2692,76 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
|
||||
dtype = jnp.bfloat16
|
||||
|
||||
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
|
||||
spec = jax.ShapeDtypeStruct((2, 3, 4, 64), dtype)
|
||||
kernel = mgpu.as_gpu_kernel(
|
||||
add,
|
||||
grid=(1, 1, 1),
|
||||
block=(128, 1, 1),
|
||||
in_shape=(jax_shape, jax_shape),
|
||||
out_shape=jax_shape,
|
||||
in_shape=(spec, spec),
|
||||
out_shape=spec,
|
||||
smem_scratch_shape=[
|
||||
jax_shape,
|
||||
jax_shape,
|
||||
jax_shape,
|
||||
spec,
|
||||
spec,
|
||||
spec,
|
||||
core.TMABarrier(1),
|
||||
],
|
||||
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
||||
)
|
||||
|
||||
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
|
||||
y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
|
||||
x = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
|
||||
y = self.prng.uniform(-1, 1, spec.shape).astype(dtype)
|
||||
|
||||
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
|
||||
|
||||
|
||||
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
|
||||
@staticmethod
|
||||
def wgmma_kernel_with_tma_cases(abtype: jnp.dtype):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TestCaseInput:
|
||||
shape_a: tuple[int, ...] = ()
|
||||
shape_b: tuple[int, ...] = ()
|
||||
shape_res: tuple[int, ...] = ()
|
||||
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
|
||||
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
|
||||
transpose_a: bool = False
|
||||
transpose_b: bool = False
|
||||
load_a_in_registers: bool = False
|
||||
@parameterized.named_parameters(
|
||||
(
|
||||
f"swizzle={int(swizzle)}_{transpose_lhs=}_{transpose_rhs=}_{lhs_in_registers=}",
|
||||
swizzle,
|
||||
transpose_lhs,
|
||||
transpose_rhs,
|
||||
lhs_in_registers,
|
||||
)
|
||||
for swizzle in mgpu_dialect.SwizzlingMode
|
||||
for transpose_lhs in [False, True]
|
||||
for transpose_rhs in [False, True]
|
||||
for lhs_in_registers in [False, True]
|
||||
)
|
||||
def test_wgmma_kernel_with_tma(
|
||||
self, swizzle, transpose_lhs, transpose_rhs, load_a_in_registers
|
||||
):
|
||||
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
|
||||
self.skipTest("No swizzle is not supported by wgmma")
|
||||
|
||||
result = []
|
||||
for swizzle in [
|
||||
# TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
|
||||
mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
|
||||
mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
|
||||
mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
|
||||
]:
|
||||
k = swizzle // np.dtype(abtype).itemsize
|
||||
groups_m = 4
|
||||
groups_n = 1
|
||||
groups_k = 1
|
||||
result.extend([
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_n * k, groups_k * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transpose_b=True,
|
||||
),
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
|
||||
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
|
||||
),
|
||||
TestCaseInput(
|
||||
shape_a=[groups_m * 64, groups_k * k],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
|
||||
load_a_in_registers=True,
|
||||
),
|
||||
])
|
||||
# The below only works for 128-byte swizzling. Regardless of transposing,
|
||||
# TMA needs the size of the last dimension to be compatible with the
|
||||
# swizzle.
|
||||
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
|
||||
result.append(
|
||||
TestCaseInput(
|
||||
shape_a=[groups_k * k, groups_m * 64],
|
||||
shape_b=[groups_k * k, groups_n * k],
|
||||
shape_res=[groups_m * 64, groups_n * k],
|
||||
transpose_a=True,
|
||||
)
|
||||
)
|
||||
return result
|
||||
if transpose_lhs or transpose_rhs:
|
||||
self.skipTest("Transposes are not supported by transform inference yet.")
|
||||
|
||||
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
|
||||
def test_wgmma_kernel_with_tma(self, test_case):
|
||||
swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize
|
||||
tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems
|
||||
|
||||
groups_m, groups_n, groups_k = 4, 1, 1
|
||||
m, n, k = groups_m * tiling_m, groups_n * tiling_n, groups_k * tiling_k
|
||||
|
||||
lhs_shape = (k, m) if transpose_lhs else (m, k)
|
||||
rhs_shape = (n, k) if transpose_rhs else (k, n)
|
||||
out_shape = (m, n)
|
||||
|
||||
def matmul(
|
||||
ctx: launch_context.LaunchContext,
|
||||
a_gmem_ref: ir.Value,
|
||||
b_gmem_ref: ir.Value,
|
||||
lhs_gmem_ref: ir.Value,
|
||||
rhs_gmem_ref: ir.Value,
|
||||
result_gmem_ref: ir.Value,
|
||||
smem: list[ir.Value],
|
||||
):
|
||||
del ctx
|
||||
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
|
||||
a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
|
||||
b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
|
||||
lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem
|
||||
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
|
||||
|
||||
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
|
||||
bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a)
|
||||
bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b)
|
||||
operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type
|
||||
bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape)
|
||||
bytes_b = utils.bytewidth(operand_elt_type) * math.prod(rhs_shape)
|
||||
|
||||
mgpu_dialect.arrive_expect_tx(
|
||||
barrier=dialect_barrier,
|
||||
@ -2869,19 +2771,19 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
source=a_gmem_ref,
|
||||
destination=a_smem_ref,
|
||||
source=lhs_gmem_ref,
|
||||
destination=lhs_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32] * len(test_case.shape_a),
|
||||
slice_lengths=test_case.shape_a,
|
||||
indices=[zero_i32] * len(lhs_shape),
|
||||
slice_lengths=lhs_shape,
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
source=b_gmem_ref,
|
||||
destination=b_smem_ref,
|
||||
source=rhs_gmem_ref,
|
||||
destination=rhs_smem_ref,
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32] * len(test_case.shape_b),
|
||||
slice_lengths=test_case.shape_b,
|
||||
indices=[zero_i32] * len(rhs_shape),
|
||||
slice_lengths=rhs_shape,
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
|
||||
@ -2889,29 +2791,34 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
parity, _ = tma_barrier.update_parities(parities)
|
||||
mgpu_dialect.wait(dialect_barrier, parity)
|
||||
|
||||
# SMEM -> Registers
|
||||
a_operand = a_smem_ref
|
||||
zero_index = arith.constant(ir.IndexType.get(), 0)
|
||||
if test_case.load_a_in_registers:
|
||||
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
|
||||
zero_vector_indices = [zero_index] * len(test_case.shape_a)
|
||||
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
|
||||
|
||||
# Computation
|
||||
shape_result = ir.MemRefType(result_gmem_ref.type).shape
|
||||
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
|
||||
acc_elt_type = ir.F32Type.get()
|
||||
acc_type = ir.VectorType.get(shape_result, acc_elt_type)
|
||||
zero_acc = arith.constant(
|
||||
result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0)
|
||||
)
|
||||
accumulator = vector.splat(
|
||||
ir.VectorType.get(shape_result, result_elt_type), zero_acc
|
||||
result_elt_type, ir.FloatAttr.get(acc_elt_type, 0.0)
|
||||
)
|
||||
accumulator = vector.splat(acc_type, zero_acc)
|
||||
|
||||
if transpose_lhs:
|
||||
lhs_smem_ref = utils.memref_transpose(lhs_smem_ref, (1, 0))
|
||||
if transpose_rhs:
|
||||
rhs_smem_ref = utils.memref_transpose(rhs_smem_ref, (1, 0))
|
||||
|
||||
zero_index = arith.constant(ir.IndexType.get(), 0)
|
||||
if load_a_in_registers:
|
||||
# SMEM -> Registers
|
||||
lhs_ty = ir.VectorType.get(lhs_shape, operand_elt_type)
|
||||
zero_vector_indices = [zero_index] * len(lhs_shape)
|
||||
lhs_operand = vector.load(lhs_ty, lhs_smem_ref, zero_vector_indices)
|
||||
else:
|
||||
lhs_operand = lhs_smem_ref
|
||||
|
||||
result = mgpu_dialect.wgmma(
|
||||
accumulator,
|
||||
a_operand,
|
||||
b_smem_ref,
|
||||
transpose_a=test_case.transpose_a,
|
||||
transpose_b=test_case.transpose_b,
|
||||
lhs_operand,
|
||||
rhs_smem_ref,
|
||||
)
|
||||
|
||||
nvvm.wgmma_commit_group_sync_aligned()
|
||||
@ -2929,38 +2836,41 @@ class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
||||
)
|
||||
nvvm.cp_async_bulk_wait_group(0)
|
||||
|
||||
abtype = jnp.bfloat16
|
||||
operand_type = jnp.bfloat16
|
||||
acctype = jnp.float32
|
||||
a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype)
|
||||
b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype)
|
||||
result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype)
|
||||
lhs_jax_shape = jax.ShapeDtypeStruct(lhs_shape, operand_type)
|
||||
rhs_jax_shape = jax.ShapeDtypeStruct(rhs_shape, operand_type)
|
||||
result_jax_shape = jax.ShapeDtypeStruct(out_shape, acctype)
|
||||
kernel = mgpu.as_gpu_kernel(
|
||||
matmul,
|
||||
grid=(1, 1, 1),
|
||||
block=(128, 1, 1),
|
||||
in_shape=(a_jax_shape, b_jax_shape),
|
||||
in_shape=(lhs_jax_shape, rhs_jax_shape),
|
||||
out_shape=result_jax_shape,
|
||||
smem_scratch_shape=[
|
||||
a_jax_shape,
|
||||
b_jax_shape,
|
||||
lhs_jax_shape,
|
||||
rhs_jax_shape,
|
||||
result_jax_shape,
|
||||
core.TMABarrier(1),
|
||||
],
|
||||
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
||||
)
|
||||
|
||||
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
|
||||
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
|
||||
prng_key = jax.random.key(1234)
|
||||
k0, k1 = jax.random.split(prng_key, 2)
|
||||
|
||||
x = jax.random.randint(k0, lhs_shape, 0, 2).astype(operand_type)
|
||||
y = jax.random.randint(k1, rhs_shape, 0, 2).astype(operand_type)
|
||||
|
||||
transpose = lambda x, t: x.T if t else x
|
||||
self.assertArraysAllClose(
|
||||
jax.jit(kernel)(x, y),
|
||||
np.matmul(
|
||||
transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
|
||||
transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
|
||||
transpose(x, transpose_lhs),
|
||||
transpose(y, transpose_rhs)
|
||||
),
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
atol=0,
|
||||
rtol=0,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user