# rocm_jax/jax/experimental/mosaic/gpu/layout_inference.py

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layout inference pass for the MLIR Mosaic GPU dialect."""
from collections.abc import Callable
import enum
from functools import partial
from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from . import fragmented_array as fa
from . import layouts as layouts_lib

# mypy: ignore-errors

OptionalLayouts = tuple[list[ir.Attribute], list[ir.Attribute]] | None
LayoutInferenceRule = Callable[[ir.OpView], OptionalLayouts]

_layout_inference_rules: dict[str, LayoutInferenceRule] = {}


def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule):
_layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error


def _set_layout_attributes(
op: ir.OpView,
in_layouts: list[ir.Attribute],
out_layouts: list[ir.Attribute],
):
op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)


def _choose_representative_layout(
layouts: set[ir.Attribute],
) -> ir.Attribute | None:
"""Chooses an appropriate layout from a given set of possible layouts.
Given the input set of possible layouts, this function extracts a single
representative layout. Currently, this function only works with strided,
splat, and WGMMA fragmented layouts.
Returns:
A single layout that can be used to annotate the operation, or None if the
input set is empty.
"""
if not layouts:
return None
strided_layouts: list[fa.WGStridedFragLayout] = [
layouts_lib.from_layout_attr(layout)
for layout in layouts
if layouts_lib.is_strided_fragmented_layout(layout)
]
splat_layouts: list[fa.WGSplatFragLayout] = list(
map(
layouts_lib.from_layout_attr,
filter(layouts_lib.is_splat_fragmented_layout, layouts),
)
)
wgmma_layouts: list[fa.WGMMAFragLayout] = list(
map(
layouts_lib.from_layout_attr,
filter(layouts_lib.is_wgmma_fragmented_layout, layouts),
)
)
if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len(
layouts
):
raise ValueError(
f"Expected only strided, splat, and wgmma layouts, got {layouts}"
)
if len(splat_layouts) > 1:
raise NotImplementedError(
"Finding a representative layout for several distinct splat layouts "
"is not supported."
)
if len(strided_layouts) > 1:
raise NotImplementedError(
"Finding a representative layout for several distinct strided layouts "
"is not supported."
)
  if wgmma_layouts and strided_layouts:
    raise NotImplementedError(
        "Mixing strided and WGMMA layouts is not supported."
    )
if wgmma_layouts:
return layouts_lib.to_layout_attr(wgmma_layouts[0])
if strided_layouts:
[strided_layout] = strided_layouts
return layouts_lib.to_layout_attr(strided_layout)
[splat_layout] = splat_layouts
return layouts_lib.to_layout_attr(splat_layout)
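

# Illustrative note (a sketch in a comment, not executed at import time; the
# shape and vec_size below are made up): given one splat and one strided layout,
# the function prefers the strided layout, since a splat value is replicated
# across threads and can be converted to any other layout without data movement.
#
#   splat = layouts_lib.to_layout_attr(fa.WGSplatFragLayout(shape=(128, 128)))
#   strided = layouts_lib.to_layout_attr(
#       fa.WGStridedFragLayout(shape=(128, 128), vec_size=4))
#   assert _choose_representative_layout({splat, strided}) == strided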


def _in_layout_for_operand(
op: ir.OpView,
operand: ir.Value,
) -> ir.Attribute | None:
"""Returns the layout of the operand in the given operation if it is set.
Raises:
ValueError: If `operand` is not an operand of `op`, or if `operand` is not a
Vector.
"""
if not ir.VectorType.isinstance(operand.type):
raise ValueError(f"{operand} is not a vector.")
operand_number = [
o for o in op.operands if ir.VectorType.isinstance(o.type)
].index(operand)
if not layouts_lib.has_in_layouts_set(op):
return None
return layouts_lib.in_layouts(op)[operand_number]


def _value_layout(value: ir.Value) -> ir.Attribute | None:
  """Returns the layout for a given value as defined by its owner.

  Raises:
    ValueError: If `value` is not a Vector.
  """
if not ir.VectorType.isinstance(value.type):
raise ValueError(f"{value} is not a vector.")
owner = value.owner
if isinstance(owner, ir.Operation):
if not layouts_lib.has_out_layouts_set(owner):
return None
value_result_number = [
r for r in owner.results if ir.VectorType.isinstance(r.type)
].index(value)
return layouts_lib.out_layouts(owner)[value_result_number]
# Block case, useful when attempting to derive layouts for ops
# depending on function parameters, or loop block arguments.
if isinstance(owner, ir.Block):
owner_op = owner.owner
block = cast(ir.Block, owner)
if not layouts_lib.has_in_layouts_set(owner_op):
return None
value_arg_number = [
r for r in block.arguments if ir.VectorType.isinstance(r.type)
].index(value)
return layouts_lib.in_layouts(owner_op)[value_arg_number]
  raise NotImplementedError(
      f"{owner} is neither a block nor an operation."
  )


def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts:
def is_array(v: ir.Value) -> bool:
return ir.VectorType.isinstance(v.type)
num_vector_operands = len([o for o in op.operands if is_array(o)])
num_vector_results = len([r for r in op.results if is_array(r)])
if layouts_lib.has_in_layouts_set(op):
op_in_layouts = layouts_lib.in_layouts(op)
if op_in_layouts:
layout = op_in_layouts[0]
return (num_vector_operands * [layout], num_vector_results * [layout])
if layouts_lib.has_out_layouts_set(op):
op_out_layouts = layouts_lib.out_layouts(op)
if op_out_layouts:
layout = op_out_layouts[0]
return (num_vector_operands * [layout], num_vector_results * [layout])
layouts = set()
# We can also try to infer layouts from the layout of producer and
# consumer operations.
#
  # We first look at producers; this enables e.g. propagating splat layouts as
  # far down as possible, since we may be able to propagate a splat layout
  # further down before a relayout becomes necessary.
all_inputs_have_layout = True
for operand in op.operands:
if not ir.VectorType.isinstance(operand.type):
continue
if (layout := _value_layout(operand)) is not None:
layouts.add(layout)
else:
all_inputs_have_layout = False
# We only look at consumers if we haven't found a possible layout yet. This is
# to avoid propagating more complicated layouts up, to e.g. preserve splat
# layouts as far down as possible.
if not layouts:
for op_result in op.results:
if not ir.VectorType.isinstance(op_result.type):
continue
      for op_result_use in cast(ir.OpResult, op_result).uses:
        consumer = op_result_use.owner
        consumer_operand = consumer.operands[op_result_use.operand_number]
        layout = _in_layout_for_operand(consumer, consumer_operand)
        if layout is not None:
          layouts.add(layout)
# TODO(bchetioui): when propagating up, the representative layout should be
# chosen in the opposite way as when propagating down. E.g., when propagating
# down, we should pick a strided layout over a splat layout; when propagating
# up, we should pick a splat layout over a strided layout.
# This is left for a future change, and currently we only do "down
# propagation".
layout = _choose_representative_layout(layouts)
  if layout is None:
    return None
  # It is unsafe to conclude that this op produces a splat if not all inputs
  # have been inferred: some of them might turn out not to be splats!
  if (layouts_lib.is_splat_fragmented_layout(layout)
      and not all_inputs_have_layout):
    return None
return (num_vector_operands * [layout], num_vector_results * [layout])
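

# Illustrative note (hypothetical IR): for a pointwise op such as
#
#   %2 = arith.addf %0, %1 : vector<128xf32>
#
# if the op's own layouts, a producer, or a consumer is annotated with some
# layout, say #strided, the rule returns ([#strided, #strided], [#strided]):
# pointwise ops use a single layout for all vector operands and results.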


for op in [
arith.AddIOp, arith.AddFOp,
arith.AndIOp,
arith.CmpFOp,
arith.CmpIOp,
arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp,
arith.MaximumFOp,
arith.MaxUIOp, arith.MaxSIOp,
arith.MinimumFOp,
arith.MinUIOp, arith.MinSIOp,
arith.MulIOp, arith.MulFOp,
arith.OrIOp,
arith.FloorDivSIOp, arith.DivUIOp, arith.DivFOp,
arith.RemUIOp, arith.RemSIOp, arith.RemFOp,
arith.SubIOp, arith.SubFOp,
arith.TruncFOp, arith.TruncIOp,
arith.XOrIOp,
vector.LoadOp,
vector.StoreOp,
]:
_add_layout_inference_rule(op, _infer_pointwise_op_layouts)


@partial(_add_layout_inference_rule, arith.ConstantOp)
def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts:
if not ir.VectorType.isinstance(constant_op.result.type):
return None
shaped_ty = cast(ir.ShapedType, constant_op.result.type)
value = constant_op.value
layout = None
if (
ir.DenseElementsAttr.isinstance(value)
and ir.DenseElementsAttr(value).is_splat
):
layout = layouts_lib.to_splat_fragmented_layout_attr(
fa.WGSplatFragLayout(shape=shaped_ty.shape)
)
# If the constant is not a splat, there is no obvious good choice of layout.
# We need to look at the consumers of the constant to find a layout that works
# for them. If there are several users with N different layouts, we can
# arbitrarily choose any one of them for the constant, since we expect
# whichever choice we make to lead to N-1 relayouts, which all have the same
# cost.
elif constant_op.result.uses:
for use in cast(ir.OpResult, constant_op.result).uses:
consumer = use.owner
operand = consumer.operands[use.operand_number]
layout = _in_layout_for_operand(consumer, operand)
if layout is not None:
break
# If the constant is not a splat, has no user, or a layout could not be
# determined from looking at the users, we assign a strided layout for
# completeness.
if layout is None:
layout = layouts_lib.to_strided_fragmented_layout_attr(
fa.WGStridedFragLayout.from_shaped_type(shaped_ty)
)
return [], [layout]
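

# Illustrative note: a splat constant such as
#
#   %cst = arith.constant dense<1.0> : vector<128x128xf32>
#
# gets WGSplatFragLayout(shape=(128, 128)), while a non-splat constant with
# users annotated, say, #strided and #wgmma gets whichever in_layout is found
# first; any choice leads to the same number of relayouts (N - 1 for N
# distinct user layouts).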


@partial(_add_layout_inference_rule, scf.YieldOp)
def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts:
layouts = []
for result in op.results_:
if not ir.VectorType.isinstance(result.type):
continue
if (layout := _value_layout(result)) is not None:
if layouts_lib.is_splat_fragmented_layout(layout):
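        # A splat layout is an unsafe choice for a loop-carried value, since
        # the value may stop being a splat after the first iteration.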
return None
layouts.append(layout)
    else:
      # Not all operand layouts could be inferred; give up for now.
      return None
return (layouts, [])


@partial(_add_layout_inference_rule, scf.ForOp)
def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts:
yield_op = op.body.operations[len(op.body.operations) - 1]
assert isinstance(yield_op, scf.YieldOp)
if layouts_lib.has_in_layouts_set(yield_op):
yield_layouts = list(layouts_lib.in_layouts(yield_op))
if any(
layouts_lib.is_splat_fragmented_layout(layout)
for layout in yield_layouts
):
return None
return (yield_layouts, yield_layouts)
# TODO(bchetioui): we don't attempt to propagate from outside for the moment.
# For the existing kernels, propagating from the YieldOp should be enough.
return None


@partial(_add_layout_inference_rule, vector.SplatOp)
def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
layout = layouts_lib.to_splat_fragmented_layout_attr(
fa.WGSplatFragLayout(
shape=cast(ir.ShapedType, splat_op.result.type).shape
)
)
return [], [layout]


@partial(_add_layout_inference_rule, mgpu.WGMMAOp)
def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout())
if ir.VectorType.isinstance(wgmma_op.a.type):
return [layout, layout], [layout]
return [layout], [layout]
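

# Note: the `a` operand of `wgmma` may live either in registers (a vector,
# hence two vector in_layouts: one for the accumulator and one for `a`) or in
# shared memory (a memref, hence a single vector in_layout for the
# accumulator). All vector operands and results use the WGMMA layout.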


class TraversalOrder(enum.Enum):
  """Traversal orders with respect to the data flow for IR."""

  FORWARD = 1
  BACKWARDS = 2


def traverse_op(
op: ir.OpView,
callback: Callable[[ir.OpView], None],
traversal_order: TraversalOrder = TraversalOrder.FORWARD,
):
"""Traverses the operation and applies the callback in the given order."""
for region in op.operation.regions:
for block in region:
if traversal_order == TraversalOrder.FORWARD:
ops_to_traverse = block
else:
ops_to_traverse = reversed(list(block))
for block_op in ops_to_traverse:
traverse_op(block_op, callback, traversal_order)
callback(op)
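

# Illustrative note (a sketch; `module_op` is hypothetical): collecting op
# names in post-order, with ops inside each block visited last-to-first.
#
#   names: list[str] = []
#   traverse_op(module_op, lambda o: names.append(o.operation.name),
#               TraversalOrder.BACKWARDS)
#
# Nested ops are always visited before their parent; the traversal order only
# controls the direction in which ops within a block are walked.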


def infer_layout(module: ir.Module):
  def inference_step(op: ir.Operation):
    if not layouts_lib.should_have_layout(op):
      return
    inference_rule = _layout_inference_rules.get(op.OPERATION_NAME, None)  # pytype: disable=attribute-error
    if inference_rule is None:
      raise NotImplementedError(f"Cannot infer layout for {op}")
    maybe_layouts = inference_rule(op)
    if maybe_layouts is None:
      return
    _set_layout_attributes(op, *maybe_layouts)

# TODO(bchetioui): consider switching the order of the passes. This would
# allow propagating "simpler" layouts further down in the computation, which
# is more efficient when possible.
#
# We run two passes over the module, in order to make sure that layouts
# defined in the middle of the computation are propagated wherever they need
# to be propagated. We start with a backwards (root-to-parameters) pass to
# propagate the information as far up as possible, and then a forward pass
# (parameters-to-root).
#
# Backwards pass
for op in module.body:
traverse_op(op, inference_step, TraversalOrder.BACKWARDS)

  # Forward pass
  for op in module.body:
traverse_op(op, inference_step, TraversalOrder.FORWARD)
# At this point, layouts have been propagated as far as they could be
# propagated. However, it is possible for some operations to remain
# unannotated---for example, if there were no annotations on any operation in
# the module at the start of this function. We annotate all the remaining ops
# that should be annotated with a strided fragmented layout.

  def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
if not ir.VectorType.isinstance(ty):
return None
layout = fa.WGStridedFragLayout.from_shaped_type(ty)
return layouts_lib.to_strided_fragmented_layout_attr(layout)

  def set_default_layout(op: ir.OpView):
if (layouts_lib.should_have_layout(op) and
not layouts_lib.has_any_layout_set(op)):
in_layouts = []
for operand in op.operands:
if (layout := to_default_layout(operand.type)) is not None:
in_layouts.append(layout)
out_layouts = []
for result in op.results:
if (layout := to_default_layout(result.type)) is not None:
out_layouts.append(layout)
_set_layout_attributes(op, in_layouts, out_layouts)

  for op in module.body:
traverse_op(op, set_default_layout)
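

# A minimal usage sketch (hypothetical `kernel_asm`; assumes the Mosaic GPU
# dialect exposes a `register_dialect` entry point, as generated MLIR Python
# dialects usually do):
#
#   with ir.Context() as ctx, ir.Location.unknown():
#     mgpu.register_dialect(ctx)
#     module = ir.Module.parse(kernel_asm)
#     infer_layout(module)
#
# After the pass, every op for which `layouts_lib.should_have_layout` holds
# carries `in_layouts`/`out_layouts` ArrayAttrs; ops whose layouts could not
# be inferred are annotated with default strided fragmented layouts.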