mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
247 lines
8.1 KiB
Python
247 lines
8.1 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Transform inference pass for the MLIR Mosaic GPU dialect.
|
|
|
|
The transform inference pass is meant to run on IR that has already been
|
|
annotated with layouts (see `layout_inference.py` for the relevant pass).
|
|
"""
|
|
|
|
from collections.abc import Callable
|
|
from functools import partial
|
|
import itertools
|
|
from typing import cast
|
|
|
|
from jax._src.lib import mosaic_gpu_dialect as mgpu
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import arith
|
|
from jax._src.lib.mlir.dialects import vector
|
|
|
|
from . import fragmented_array as fa
|
|
from . import inference_utils
|
|
from . import layouts as layouts_lib
|
|
from . import utils
|
|
|
|
# mypy: ignore-errors
|
|
|
|
# Result of a transform inference rule: an `(in_transforms, out_transforms)`
# pair of attribute lists, or `None` when the rule has nothing to annotate.
OptionalTransforms = tuple[list[ir.Attribute], list[ir.Attribute]] | None

# A rule receives an operation and returns the transforms inferred for it.
TransformInferenceRule = Callable[[ir.OpView], OptionalTransforms]

# Registry of transform inference rules, keyed by operation name.
_transform_inference_rules: dict[str, TransformInferenceRule] = {}
|
|
|
|
|
|
def _add_transform_inference_rule(
    op: type[ir.OpView], rule: TransformInferenceRule
):
  """Registers `rule` as the transform inference rule for `op`.

  The function returns `rule` so that it can be used as a decorator once `op`
  has been bound with `functools.partial`.

  Args:
    op: The op class whose `OPERATION_NAME` keys the registry. When `op` is
      `None`, nothing is registered.
    rule: The inference rule to associate with `op`.

  Returns:
    `rule`, unchanged.
  """
  if op is None:
    return rule

  op_name = op.OPERATION_NAME  # pytype: disable=attribute-error
  _transform_inference_rules[op_name] = rule
  return rule
|
|
|
|
|
|
def _set_transform_attributes(
    op: ir.OpView,
    in_transforms: list[ir.Attribute],
    out_transforms: list[ir.Attribute],
):
  """Stores the given transforms on `op` as array attributes.

  Args:
    op: The operation to annotate.
    in_transforms: Attributes stored under the `in_transforms` attribute.
    out_transforms: Attributes stored under the `out_transforms` attribute.
  """
  named_transforms = (
      ("in_transforms", in_transforms),
      ("out_transforms", out_transforms),
  )
  for attribute_name, transforms in named_transforms:
    op.attributes[attribute_name] = ir.ArrayAttr.get(transforms)
|
|
|
|
|
|
def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr:
  """Infers the tile and swizzle transforms for a memref used by WGMMA.

  The major dimension is always tiled by 8. The minor dimension is tiled by
  the largest swizzle width (expressed in elements) that evenly divides it.

  Args:
    ref_ty: The type of the memref to infer transforms for.

  Returns:
    An array attribute holding a tile transform followed by a swizzle
    transform.

  Raises:
    ValueError: If `ref_ty` is not 2D, or if no swizzling mode yields a tiling
      that divides the minor dimension.
    NotImplementedError: If the stride of the first dimension is smaller than
      the stride of the second one.
  """
  if len(ref_ty.shape) != 2:
    raise ValueError(f"Expected a 2D memref, got {ref_ty}")

  element_bytewidth = utils.bytewidth(ref_ty.element_type)
  strides, _ = ref_ty.get_strides_and_offset()

  if strides[0] < strides[1]:
    raise NotImplementedError("Transpositions aren't handled yet.")

  major_tiling = 8
  minor_dim = ref_ty.shape[1]

  # Candidates are ordered from the largest swizzle to the smallest, so that
  # the first match is the largest usable one.
  swizzle_candidates = (
      mgpu.SwizzlingMode.k128ByteSwizzle,
      mgpu.SwizzlingMode.k64ByteSwizzle,
      mgpu.SwizzlingMode.k32ByteSwizzle,
      mgpu.SwizzlingMode.kNoSwizzle,
  )
  minor_tiling = None
  for candidate in swizzle_candidates:
    elems_per_swizzle = candidate // element_bytewidth
    if minor_dim % elems_per_swizzle == 0:
      minor_tiling = elems_per_swizzle
      break

  if minor_tiling is None:
    # No valid tile transform can be inferred.
    raise ValueError(
        f"{ref_ty.shape} is not a valid WGMMA shape"
    )

  tile_transform = mgpu.TileTransformAttr.get((major_tiling, minor_tiling))
  swizzle_transform = mgpu.SwizzleTransformAttr.get(
      minor_tiling * element_bytewidth
  )
  return ir.ArrayAttr.get([tile_transform, swizzle_transform])
|
|
|
|
|
|
@partial(_add_transform_inference_rule, mgpu.WGMMAOp)
def infer_wgmma_transforms(op: mgpu.WGMMAOp) -> OptionalTransforms:
  """Infers the input transforms of a WGMMA op.

  Transforms are always inferred for the `b` operand. They are also inferred
  for the `a` operand when it is a memref, in which case they come first.

  Args:
    op: The WGMMA op to infer transforms for.

  Returns:
    A pair of the inferred input transforms and an empty list of output
    transforms.
  """
  in_transforms = [infer_transforms_for_wgmma_ref(ir.MemRefType(op.b.type))]

  a_type = op.a.type
  if ir.MemRefType.isinstance(a_type):
    a_transforms = infer_transforms_for_wgmma_ref(
        cast(ir.MemRefType, a_type)
    )
    in_transforms.insert(0, a_transforms)

  return in_transforms, []
|
|
|
|
|
|
@partial(_add_transform_inference_rule, mgpu.AsyncStoreOp)
def _infer_async_store_transforms(op: mgpu.AsyncStoreOp) -> OptionalTransforms:
  """Forwards the transforms of the store's source to the op.

  Returns:
    `None` if `inference_utils.value_transforms` reports no transforms for
    `op.source`; otherwise those transforms as the only input transforms, with
    no output transforms.
  """
  source_transforms = inference_utils.value_transforms(op.source)
  if source_transforms is None:
    return None
  return [source_transforms], []
|
|
|
|
|
|
@partial(_add_transform_inference_rule, mgpu.AsyncLoadOp)
def _infer_async_load_transforms(op: mgpu.AsyncLoadOp) -> OptionalTransforms:
  """Forwards the transforms of the load's destination to the op.

  Returns:
    `None` if `inference_utils.value_transforms` reports no transforms for
    `op.destination`; otherwise those transforms as the only input transforms,
    with no output transforms.
  """
  destination_transforms = inference_utils.value_transforms(op.destination)
  if destination_transforms is None:
    return None
  return [destination_transforms], []
|
|
|
|
|
|
@partial(_add_transform_inference_rule, vector.LoadOp)
@partial(_add_transform_inference_rule, vector.StoreOp)
def _infer_vector_load_store_transforms(
    op: vector.LoadOp | vector.StoreOp,
) -> OptionalTransforms:
  """Infers the transforms of the base ref of a vector load or store.

  The transforms implied by the layout of the loaded/stored vector are
  reconciled with the transforms that `inference_utils.value_transforms`
  reports for `op.base`.

  Args:
    op: The vector load or store to infer transforms for.

  Returns:
    `None` if neither source provides transforms; otherwise a pair holding the
    transforms of the base ref as the only input transforms, and no output
    transforms.

  Raises:
    NotImplementedError: If an index is not a constant with value 0, if the
      vector layout is not supported, or if the two sources of transforms
      disagree.
  """
  for i in op.indices:
    # A block argument is owned by an `ir.Block`, which has no `opview`. Treat
    # such an index like any other non-constant index rather than failing with
    # an `AttributeError`.
    index_defining_op = getattr(i.owner, "opview", None)
    if (
        not isinstance(index_defining_op, arith.ConstantOp)
        or index_defining_op.literal_value != 0
    ):
      # TODO(bchetioui): handle slicing.
      raise NotImplementedError(
          f"Only constants with value 0 are supported as indices for {op}"
      )

  # The vector is the result of a load, but an operand of a store.
  if isinstance(op, vector.LoadOp):
    [layout_attr] = inference_utils.out_layouts(op)
  else:
    assert isinstance(op, vector.StoreOp)
    [layout_attr] = inference_utils.in_layouts(op)

  layout = layouts_lib.from_layout_attr(layout_attr)
  transforms = inference_utils.value_transforms(op.base)

  if layout == fa.WGMMA_LAYOUT:
    layout_transforms = infer_transforms_for_wgmma_ref(
        ir.MemRefType(op.base.type)
    )
  elif isinstance(layout, (fa.WGStridedFragLayout, fa.WGSplatFragLayout)):
    # These layouts do not constrain the transforms of the ref.
    layout_transforms = None
  else:
    raise NotImplementedError(
        f"Got layout {layout} which is not yet supported"
    )

  if transforms is None:
    # Only the layout (if anything) has an opinion.
    return None if layout_transforms is None else ([layout_transforms], [])

  if layout_transforms is not None and transforms != layout_transforms:
    raise NotImplementedError(
        f"Conflicting transforms for {op.base} in {op}: "
        f"{transforms} != {layout_transforms}."
    )

  return [transforms], []
|
|
|
|
|
|
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)


@partial(_add_transform_inference_rule, SliceSMEMOp)
def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
  """Infers the output transforms of a slice op from the uses of its result.

  Each use of the result is queried with
  `inference_utils.in_transforms_for_operand`. All uses that report transforms
  must agree.

  Args:
    op: The slice op to infer transforms for.

  Returns:
    `None` if no use reports transforms; otherwise a pair of an empty list of
    input transforms and the agreed-upon transforms as the only output
    transforms.

  Raises:
    NotImplementedError: If two uses report different transforms.
  """
  inferred = None

  for use in cast(ir.OpResult, op.result).uses:
    user_op = use.owner
    operand = user_op.operands[use.operand_number]
    user_transforms = inference_utils.in_transforms_for_operand(
        user_op, operand
    )
    if user_transforms is None:
      continue

    if inferred is None:
      inferred = user_transforms
    elif inferred != user_transforms:
      raise NotImplementedError(
          f"Conflicting transforms for {operand} in {op}: "
          f"{inferred} != {user_transforms}."
      )

  if inferred is None:
    return None
  return [], [inferred]
|
|
|
|
|
|
def _should_have_transforms(op: ir.OpView) -> bool:
  """Returns 'True' if the operation should be assigned in/out transforms.

  This is the case when `inference_utils.is_transformable_smem_memref` holds
  for at least one operand or result of `op`.
  """
  return any(
      inference_utils.is_transformable_smem_memref(value)
      for value in itertools.chain(op.operands, op.results)
  )
|
|
|
|
|
|
def infer_transforms(module: ir.Module):
  """Infers transforms for the given module.

  Transforms are to memrefs what layouts are to vectors. More specifically,
  transforms describe mappings between SMEM refs and GMEM refs, and are
  determined based on how SMEM refs are used. For that reason, transforms are
  always annotated on SMEM refs.

  The pass is meant to be called on a module where layouts have been fully
  specified. We error out if two distinct sets of transforms are competing to
  annotate the same memref.

  Args:
    module: The module whose operations are annotated in place.

  Raises:
    NotImplementedError: If an operation should have transforms but no
      inference rule is registered for it.
  """

  def inference_step(op: ir.Operation):
    """Annotates `op` with whatever transforms its rule can infer."""
    if not _should_have_transforms(op):
      return

    rule = _transform_inference_rules.get(op.OPERATION_NAME, None)  # pytype: disable=attribute-error
    if not rule:
      raise NotImplementedError(f"Can not infer transforms for {op}")

    inferred = rule(op)
    if inferred is not None:
      _set_transform_attributes(op, *inferred)

  # It's enough to do a single backwards propagation (starting from vector
  # users), and then a single forward propagation (to feed into the async loads
  # and stores).
  traversal_orders = (
      inference_utils.TraversalOrder.BACKWARDS,
      inference_utils.TraversalOrder.FORWARD,
  )
  for order in traversal_orders:
    for op in module.body:
      inference_utils.traverse_op(op, inference_step, order)
|