rocm_jax/jax/experimental/mosaic/gpu/layout_inference.py
Sergei Lebedev a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00

618 lines
20 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layout inference pass for the MLIR Mosaic GPU dialect."""
from collections.abc import Callable, Sequence
import dataclasses
import enum
from functools import partial
import math
from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import math as mlir_math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
import numpy as np
from . import fragmented_array as fa
from . import inference_utils
from . import layouts as layouts_lib
from . import utils
# mypy: ignore-errors
OptionalLayouts = tuple[list[ir.Attribute], list[ir.Attribute]] | None
LayoutInferenceRule = Callable[[ir.OpView], OptionalLayouts]
_layout_inference_rules: dict[str, LayoutInferenceRule] = {}
def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule):
_layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
def _set_layout_attributes(
    op: ir.OpView,
    in_layouts: list[ir.Attribute],
    out_layouts: list[ir.Attribute],
):
  """Stores `in_layouts` and `out_layouts` on `op` as array attributes.

  Any existing `in_layouts`/`out_layouts` attributes on `op` are overwritten.
  """
  attributes = op.attributes
  for name, layouts in (("in_layouts", in_layouts), ("out_layouts", out_layouts)):
    attributes[name] = ir.ArrayAttr.get(layouts)
def _choose_representative_layout(
    layouts: set[ir.Attribute],
) -> ir.Attribute | None:
  """Picks one layout that can stand in for all of the given layouts.

  Only strided, splat, and tiled layouts are supported. At most one layout of
  each kind may be present, and strided and tiled layouts may not be mixed.
  When several kinds are present, a tiled or strided layout is chosen over a
  splat layout.

  Args:
    layouts: the set of candidate layouts.

  Returns:
    The chosen layout attribute, or None if `layouts` is empty.

  Raises:
    ValueError: if `layouts` contains a layout that is not a strided, splat or
      tiled layout.
    NotImplementedError: if several distinct layouts of the same kind are
      present, or if strided and tiled layouts are mixed.
  """
  if not layouts:
    return None

  def parse_matching(predicate):
    # Parses every candidate accepted by `predicate` into its Python form.
    return [layouts_lib.from_layout_attr(attr) for attr in layouts if predicate(attr)]

  strided: list[fa.WGStridedFragLayout] = parse_matching(
      layouts_lib.is_strided_fragmented_layout
  )
  splat: list[fa.WGSplatFragLayout] = parse_matching(
      layouts_lib.is_splat_fragmented_layout
  )
  tiled: list[fa.TiledLayout] = parse_matching(layouts_lib.is_tiled_layout)

  if len(splat) + len(strided) + len(tiled) != len(layouts):
    raise ValueError(
        f"Expected only strided, splat, and tiled layouts, got {layouts}"
    )

  for kind, found in (("splat", splat), ("strided", strided), ("tiled", tiled)):
    if len(found) > 1:
      raise NotImplementedError(
          "Finding a representative layout for several distinct "
          f"{kind} layouts is not supported."
      )

  if tiled and strided:
    raise NotImplementedError(
        "Mixing strided and tiled layouts is not supported."
    )

  # Each list now holds at most one layout and at least one list is non-empty.
  # Tiled wins over strided, and splat is only used as a last resort.
  (chosen,) = tiled or strided or splat
  return layouts_lib.to_layout_attr(chosen)
def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
  """Infers a single layout shared by all vector operands and results of `op`.

  The layout is taken, in order of preference, from:
    1. the first of the `in_layouts` already set on `op`;
    2. the first of the `out_layouts` already set on `op`;
    3. the layouts of the producers of `op`'s vector operands;
    4. the layouts that consumers of `op`'s vector results expect.
  In cases 3 and 4, a representative is chosen among the candidate layouts by
  `_choose_representative_layout`.

  Returns:
    A pair `(in_layouts, out_layouts)` that assigns the chosen layout to every
    vector operand and result of `op`. Returns None if no layout could be
    determined yet, or if a splat layout was found while some vector operands
    still lack a layout.
  """

  def is_array(v: ir.Value) -> bool:
    return ir.VectorType.isinstance(v.type)

  num_vector_operands = len([o for o in op.operands if is_array(o)])
  num_vector_results = len([r for r in op.results if is_array(r)])

  def broadcast(layout: ir.Attribute) -> OptionalLayouts:
    # Pointwise ops use the same layout for every vector operand and result.
    return (num_vector_operands * [layout], num_vector_results * [layout])

  if inference_utils.has_in_layouts_set(op):
    op_in_layouts = inference_utils.in_layouts(op)
    if op_in_layouts:
      return broadcast(op_in_layouts[0])

  if inference_utils.has_out_layouts_set(op):
    op_out_layouts = inference_utils.out_layouts(op)
    if op_out_layouts:
      return broadcast(op_out_layouts[0])

  # Otherwise, try to infer layouts from producer and consumer operations.
  #
  # Producers are examined first. This lets simple layouts (e.g. splats)
  # propagate as far down the computation as possible before a relayout is
  # required.
  layouts = set()
  all_inputs_have_layout = True
  for operand in op.operands:
    if not is_array(operand):
      continue
    if (layout := inference_utils.value_layout(operand)) is not None:
      layouts.add(layout)
    else:
      all_inputs_have_layout = False

  # Consumers are only examined if no producer provided a layout. This avoids
  # propagating more complicated layouts up, e.g. to keep splat layouts as far
  # down as possible.
  if not layouts:
    for op_result in op.results:
      if not is_array(op_result):
        continue
      for use in cast(ir.OpResult, op_result).uses:
        consumer = use.owner
        consumed_value = consumer.operands[use.operand_number]
        layout = inference_utils.in_layout_for_operand(consumer, consumed_value)
        if layout is not None:
          layouts.add(layout)

  # TODO(bchetioui): when propagating up, the representative layout should be
  # chosen in the opposite way as when propagating down. E.g., when propagating
  # down, we should pick a strided layout over a splat layout; when propagating
  # up, we should pick a splat layout over a strided layout.
  # This is left for a future change, and currently we only do "down
  # propagation".
  layout = _choose_representative_layout(layouts)
  if layout is None:
    return None

  # It is unsafe to conclude that this op produces a splat if not all inputs
  # have been inferred: some of them might turn out not to be splats!
  if (
      layouts_lib.is_splat_fragmented_layout(layout)
      and not all_inputs_have_layout
  ):
    return None

  return broadcast(layout)
# The ops below are all handled by the generic pointwise rule, which assigns
# one shared layout to all of an op's vector operands and results.
for op in (
    # arith dialect.
    arith.AddIOp,
    arith.AddFOp,
    arith.AndIOp,
    arith.BitcastOp,
    arith.CmpFOp,
    arith.CmpIOp,
    arith.ExtFOp,
    arith.ExtSIOp,
    arith.ExtUIOp,
    arith.FPToSIOp,
    arith.FPToUIOp,
    arith.MaximumFOp,
    arith.MaxUIOp,
    arith.MaxSIOp,
    arith.MinimumFOp,
    arith.MinUIOp,
    arith.MinSIOp,
    arith.MulIOp,
    arith.MulFOp,
    arith.OrIOp,
    arith.FloorDivSIOp,
    arith.DivUIOp,
    arith.DivFOp,
    arith.RemUIOp,
    arith.RemSIOp,
    arith.RemFOp,
    arith.SIToFPOp,
    arith.UIToFPOp,
    arith.SubIOp,
    arith.SubFOp,
    arith.TruncFOp,
    arith.TruncIOp,
    arith.XOrIOp,
    # math dialect.
    mlir_math.ExpOp,
    mlir_math.Exp2Op,
    mlir_math.LogOp,
    mlir_math.RsqrtOp,
    mlir_math.TanhOp,
    # vector dialect loads and stores.
    vector.LoadOp,
    vector.StoreOp,
):
  _add_layout_inference_rule(op, _infer_pointwise_op_layouts)
@partial(_add_layout_inference_rule, arith.ConstantOp)
def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts:
  """Infers the output layout of a vector-typed `arith.constant`.

  Splat constants get a splat layout. Any other constant takes the layout
  expected by its first consumer that has one. If no such consumer exists, a
  strided layout is assigned for completeness. Non-vector constants are not
  annotated (None is returned).
  """
  if not ir.VectorType.isinstance(constant_op.result.type):
    return None
  shaped_ty = cast(ir.ShapedType, constant_op.result.type)

  value = constant_op.value
  if (
      ir.DenseElementsAttr.isinstance(value)
      and ir.DenseElementsAttr(value).is_splat
  ):
    splat = fa.WGSplatFragLayout(shape=shaped_ty.shape)
    return [], [layouts_lib.to_splat_fragmented_layout_attr(splat)]

  # For a non-splat constant there is no obvious good layout, so we ask its
  # consumers. With N users having different layouts, any choice should lead
  # to N-1 relayouts of equal cost, so the first one found is used.
  layout = None
  for use in cast(ir.OpResult, constant_op.result).uses:
    consumer = use.owner
    layout = inference_utils.in_layout_for_operand(
        consumer, consumer.operands[use.operand_number]
    )
    if layout is not None:
      break

  if layout is None:
    # No user, or no user with a known layout: fall back to a strided layout.
    layout = layouts_lib.to_strided_fragmented_layout_attr(
        fa.WGStridedFragLayout.from_shaped_type(shaped_ty)
    )
  return [], [layout]
@partial(_add_layout_inference_rule, scf.YieldOp)
def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts:
  """Uses the layouts of the yielded vector values as the op's in-layouts.

  Returns None unless every yielded vector value already has a layout, or if
  any of those layouts is a splat layout.
  """
  yielded_layouts = []
  vector_values = (v for v in op.results_ if ir.VectorType.isinstance(v.type))
  for value in vector_values:
    layout = inference_utils.value_layout(value)
    # Bail out if this layout is still unknown, or if it is a splat layout.
    if layout is None or layouts_lib.is_splat_fragmented_layout(layout):
      return None
    yielded_layouts.append(layout)
  return (yielded_layouts, [])
@partial(_add_layout_inference_rule, scf.ForOp)
def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts:
  """Propagates the in-layouts of the loop body's terminating `scf.yield`.

  The same layouts are used for the for op's vector inputs and results.
  Returns None if the yield has no in-layouts yet or if any of them is a splat
  layout.
  """
  body_ops = op.body.operations
  terminator = body_ops[len(body_ops) - 1]
  assert isinstance(terminator, scf.YieldOp)

  # TODO(bchetioui): we don't attempt to propagate from outside for the moment.
  # For the existing kernels, propagating from the YieldOp should be enough.
  if not inference_utils.has_in_layouts_set(terminator):
    return None

  yield_layouts = list(inference_utils.in_layouts(terminator))
  if any(map(layouts_lib.is_splat_fragmented_layout, yield_layouts)):
    return None
  return (yield_layouts, yield_layouts)
@partial(_add_layout_inference_rule, vector.SplatOp)
def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
  """Gives the result of a `vector.splat` a splat layout of the same shape."""
  result_shape = cast(ir.ShapedType, splat_op.result.type).shape
  splat_layout = fa.WGSplatFragLayout(shape=result_shape)
  return [], [layouts_lib.to_splat_fragmented_layout_attr(splat_layout)]
def _update_layout_shape(
    layout: ir.Attribute, shape: Sequence[int], origin: str
) -> ir.Attribute:
  """Returns a copy of a splat or strided `layout` with its shape replaced.

  Args:
    layout: a splat or strided fragmented layout attribute.
    shape: the new shape.
    origin: a label for the layout, used only in the error message.

  Raises:
    NotImplementedError: if `layout` is neither splat nor strided.
  """
  is_supported = layouts_lib.is_splat_fragmented_layout(
      layout
  ) or layouts_lib.is_strided_fragmented_layout(layout)
  if not is_supported:
    raise NotImplementedError(f"Unsupported {origin} layout: {layout}.")
  reshaped = dataclasses.replace(
      layouts_lib.from_layout_attr(layout), shape=shape
  )
  return layouts_lib.to_layout_attr(reshaped)
@partial(_add_layout_inference_rule, vector.ShapeCastOp)
def _infer_shape_cast_op_layout(op: vector.ShapeCastOp) -> OptionalLayouts:
  """Infers layouts for `vector.shape_cast` by reshaping a known layout.

  If the source layout is known, the result layout is derived from it.
  Otherwise the source layout is derived from a known result layout. Returns
  None if neither is known.
  """
  source_layout = inference_utils.value_layout(op.source)
  if source_layout is not None:
    result_shape = ir.VectorType(op.result.type).shape
    result_layout = _update_layout_shape(source_layout, result_shape, "result")
    return [source_layout], [result_layout]

  result_layout = inference_utils.value_layout(op.result)
  if result_layout is None:
    return None
  source_shape = ir.VectorType(op.source.type).shape
  source_layout = _update_layout_shape(result_layout, source_shape, "source")
  return [source_layout], [result_layout]
@partial(_add_layout_inference_rule, vector.ReductionOp)
def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
  """Uses the layout of the reduced vector as the op's only input layout.

  No output layouts are returned, because `vector.reduction` produces a scalar.
  Returns None while the input vector's layout is still unknown.
  """
  # Compare explicitly against None, like the other rules, rather than relying
  # on the truthiness of an `ir.Attribute`.
  layout = inference_utils.value_layout(op.vector)
  if layout is None:
    return None
  return [layout], []
@partial(_add_layout_inference_rule, mgpu.WGMMAOp)
def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
  """Assigns the WGMMA fragmented layout to the op's vector values.

  The result always gets the WGMMA layout, as does the first input layout
  entry. A second input entry is added only when operand `a` is a vector;
  otherwise (presumably when `a` is a memref) `a` gets no layout.
  """
  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
  a_is_vector = ir.VectorType.isinstance(wgmma_op.a.type)
  in_layouts = [wgmma_layout, wgmma_layout] if a_is_vector else [wgmma_layout]
  return in_layouts, [wgmma_layout]
@dataclasses.dataclass
class WGMMATransforms:
  """Memref transforms inferred for the `a` and `b` operands of a WGMMA op."""

  # Swizzling mode, shared by the `a` and `b` operands.
  swizzle: mgpu.SwizzlingMode
  # Tile shape applied to operand `a`.
  a_tile: tuple[int, ...]
  # Whether operand `a` is transposed.
  a_transpose: bool
  # Tile shape applied to operand `b`.
  b_tile: tuple[int, ...]
  # Whether operand `b` is transposed.
  b_transpose: bool
def infer_wgmma_transforms(wgmma_op: mgpu.WGMMAOp) -> WGMMATransforms:
  """Infers the tiling/swizzling transforms for the operands of a WGMMA op.

  The largest swizzle (128, then 64, then 32 bytes) is chosen such that the
  contracting dimension `k` of `a` is a multiple of
  `s = swizzle * 8 // bitwidth`, where `bitwidth` is the element width of `a`.

  Raises:
    ValueError: if no swizzling mode evenly divides `k`.
  """
  a_type = cast(ir.ShapedType, wgmma_op.a.type)
  a_shape = a_type.shape
  # `a` is (m, k), or (k, m) when transposed: k is the contracting dimension.
  k = a_shape[0] if wgmma_op.transpose_a else a_shape[1]
  bitwidth = a_type.element_type.width

  # Try tiling with all swizzling modes starting from the largest one.
  for swizzle in [
      mgpu.SwizzlingMode.k128ByteSwizzle,
      mgpu.SwizzlingMode.k64ByteSwizzle,
      mgpu.SwizzlingMode.k32ByteSwizzle,
  ]:
    # Number of elements that fit in one swizzle span.
    s = swizzle * 8 // bitwidth
    if k % s == 0:
      return WGMMATransforms(
          swizzle=swizzle,
          a_tile=(s, 64) if wgmma_op.transpose_a else (64, s),
          a_transpose=wgmma_op.transpose_a,
          b_tile=(s, s),
          b_transpose=wgmma_op.transpose_b,
      )

  raise ValueError(
      "Could not infer layouts for memref feeding into WGMMA. The "
      f"contracting dimension ({k}) must be a multiple of "
      "s = swizzle * (8 / bitwidth) where swizzle is a valid swizzle "
      f"(32, 64, or 128) and bitwidth ({bitwidth}) is the element size of "
      "`a` and `b`."
  )
def _layout_for_memref_view(view_op: memref.ViewOp) -> ir.Attribute | None:
  """Returns the memref layout that the result of `view_op` should be cast to.

  A layout is only needed when the view feeds into a WGMMA op. In that case it
  encodes the tiling, optional transposition, and swizzling inferred by
  `infer_wgmma_transforms` for the WGMMA operand the view is used as.

  Returns:
    The layout attribute, or None if the view has no WGMMA consumer or is
    already consumed by a `memref.cast`.

  Raises:
    NotImplementedError: if the view has several WGMMA consumers, or a user
      that is not a WGMMA op, `memref.cast`, async load/store, or vector
      load/store.
  """
  wgmma_use = None
  for use in cast(ir.OpResult, view_op.result).uses:
    user = use.owner
    if isinstance(user, memref.CastOp):
      # This memref is already cast, so we don't need to do anything.
      return None
    if isinstance(user, mgpu.WGMMAOp):
      if wgmma_use is not None:
        raise NotImplementedError(f"Multiple WGMMA consumers of {view_op}.")
      wgmma_use = use
      # Keep scanning the remaining uses so that further WGMMA consumers,
      # existing casts, and unsupported users are still detected.
      continue
    if not isinstance(
        user,
        (mgpu.AsyncLoadOp, mgpu.AsyncStoreOp, vector.LoadOp, vector.StoreOp),
    ):
      raise NotImplementedError(f"Unsupported user {user} of {view_op}.")

  if wgmma_use is None:
    # This memref is not used by a WGMMA operation, so we don't need to do
    # anything.
    return None

  transforms = infer_wgmma_transforms(wgmma_use.owner)
  # Operand 1 of the WGMMA op is `a`; any other memref operand is treated as
  # `b`.
  if wgmma_use.operand_number == 1:
    tile = transforms.a_tile
    transpose = transforms.a_transpose
  else:
    tile = transforms.b_tile
    transpose = transforms.b_transpose

  # When transposed, the permutation swaps the two leading dimensions.
  transpose_attr = (
      [mgpu.TransposeTransformAttr.get([1, 0, 2, 3])] if transpose else []
  )
  layout = mgpu.LayoutAttr.get(
      2,
      [mgpu.TileTransformAttr.get(tile)]
      + transpose_attr
      + [mgpu.SwizzleTransformAttr.get(transforms.swizzle)],
  )
  return layout
def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir.OpView:
  """Returns the first op in `regions` that owns one of `uses`.

  Regions, their blocks, and the ops in each block are scanned in iteration
  order. Only ops directly contained in those blocks are considered; ops
  nested inside their own regions are not searched.

  Raises:
    ValueError: if no owner of `uses` is found in `regions`.
  """
  owners = [use.owner for use in uses]
  for region in regions:
    for block in region:
      for op in block:
        if op in owners:
          return op
  raise ValueError("None of the uses are in the given regions")
def _insert_memref_layout_cast(layout: ir.Attribute, view_op: memref.ViewOp):
  """Re-types the result of `view_op` with `layout` through a `memref.cast`.

  The cast is inserted before the earliest user of the view found by
  `_earliest_use`. Every pre-existing use of the view result is then rewired
  to the cast result.
  """
  source_type = ir.MemRefType(view_op.result.type)
  target_type = ir.MemRefType.get(
      source_type.shape,
      source_type.element_type,
      layout,
      source_type.memory_space,
  )
  # Snapshot the uses first: the cast created below becomes a new user of the
  # view result and must not be rewired to itself.
  original_uses = list(view_op.result.uses)
  anchor = _earliest_use(view_op.parent.regions, original_uses)
  with ir.InsertionPoint(anchor):
    casted = memref.cast(target_type, view_op.result)
  for use in original_uses:
    use.owner.operands[use.operand_number] = casted
class TraversalOrder(enum.Enum):
  """Direction in which the ops of each block are visited, relative to data flow."""

  # Visit ops in program order.
  FORWARD = enum.auto()
  # Visit ops in reverse program order.
  BACKWARDS = enum.auto()
def traverse_op(
    op: ir.OpView,
    callback: Callable[[ir.OpView], None],
    traversal_order: TraversalOrder = TraversalOrder.FORWARD,
):
  """Recursively applies `callback` to `op` and every op nested inside it.

  Nested ops are visited before `op` itself. Within each block, ops are visited
  in program order for `TraversalOrder.FORWARD` and in reverse program order
  otherwise.
  """
  go_backwards = traversal_order != TraversalOrder.FORWARD
  for region in op.operation.regions:
    for block in region:
      children = reversed(list(block)) if go_backwards else block
      for child in children:
        traverse_op(child, callback, traversal_order)
  callback(op)
def infer_layout(module: ir.Module):
  """Infers layouts for the ops of `module` and attaches them in place.

  The pass runs in three phases:
    1. A backwards and then a forward traversal of `module` that apply the
       registered layout inference rules to every op that should have a
       layout.
    2. Every op that should have a layout but is still unannotated gets a
       default strided fragmented layout. One vector size is used for the
       whole module.
    3. The result of each `memref.view` that feeds a WGMMA op is cast to a
       memref type carrying the transforms expected by that WGMMA op.

  Raises:
    NotImplementedError: if an op that should have a layout has no registered
      inference rule.
  """

  def inference_step(op: ir.Operation):
    if not inference_utils.should_have_layout(op):
      return
    inference_rule = _layout_inference_rules.get(op.OPERATION_NAME, None)  # pytype: disable=attribute-error
    if inference_rule is None:
      raise NotImplementedError(f"Can not infer layout for {op}")
    maybe_layouts = inference_rule(op)
    if maybe_layouts is None:
      return
    _set_layout_attributes(op, *maybe_layouts)

  # TODO(bchetioui): consider switching the order of the passes. This would
  # allow propagating "simpler" layouts further down in the computation, which
  # is more efficient when possible.
  #
  # We run two passes over the module, in order to make sure that layouts
  # defined in the middle of the computation are propagated wherever they need
  # to be propagated. We start with a backwards (root-to-parameters) pass to
  # propagate the information as far up as possible, and then a forward pass
  # (parameters-to-root).
  #
  # Backwards pass
  for op in module.body:
    inference_utils.traverse_op(
        op, inference_step, inference_utils.TraversalOrder.BACKWARDS
    )

  # Forward pass
  for op in module.body:
    inference_utils.traverse_op(
        op, inference_step, inference_utils.TraversalOrder.FORWARD
    )

  # At this point, layouts have been propagated as far as they could be.
  # However, some operations may remain unannotated, e.g. if no operation in
  # the module was annotated when this function was called. We annotate all
  # remaining ops that should be annotated with a strided fragmented layout.
  # Its vector size is derived from the narrowest type and vector size used in
  # the program. A single vector size is used to avoid relayouts at lowering
  # time.
  default_vector_size = math.inf

  def update_default_vector_size(op: ir.OpView):
    nonlocal default_vector_size
    for v in list(op.operands) + list(op.results):
      if ir.VectorType.isinstance(v.type):
        max_vec_size_for_v = (
            np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE
        )
        desired_vec_size = 8 // utils.bytewidth(v.type.element_type)
        default_vector_size = min(
            default_vector_size, max_vec_size_for_v, desired_vec_size
        )

  for op in module.body:
    traverse_op(op, update_default_vector_size)

  # NOTE: `default_vector_size` stays `math.inf` only if no op has a vector
  # operand or result. In that case `to_default_layout` never builds a layout.
  # We deliberately do not return early here: the memref pass at the end of
  # this function must run even for modules without any vector values.

  def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
    if not ir.VectorType.isinstance(ty):
      return None
    layout = fa.WGStridedFragLayout(
        shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size
    )
    return layouts_lib.to_strided_fragmented_layout_attr(layout)

  def set_default_layout(op: ir.OpView):
    if not inference_utils.should_have_layout(op):
      return
    if inference_utils.has_any_layout_set(op):
      return
    in_layouts = []
    for operand in op.operands:
      if (layout := to_default_layout(operand.type)) is not None:
        in_layouts.append(layout)
    out_layouts = []
    for result in op.results:
      if (layout := to_default_layout(result.type)) is not None:
        out_layouts.append(layout)
    _set_layout_attributes(op, in_layouts, out_layouts)

  for op in module.body:
    traverse_op(op, set_default_layout)

  def infer_memref_layouts_and_insert_casts(op: ir.OpView):
    if op.name == "memref.view":
      if layout := _layout_for_memref_view(op):
        _insert_memref_layout_cast(layout, op)

  for op in module.body:
    traverse_op(op, infer_memref_layouts_and_insert_casts)