# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Lowering rules and pass for the MLIR Mosaic GPU dialect."""

from collections.abc import Callable
import dataclasses
import functools
import operator
from typing import Any, Sequence, Type, cast

from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import math as mlir_math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
import numpy as np

from . import fragmented_array as fa
from . import inference_utils
from . import launch_context
from . import layouts
from . import utils
from . import wgmma

# mypy: ignore-errors


@dataclasses.dataclass()
class LoweringContext:
  launch_context: launch_context.LaunchContext | None
  single_thread_per_block_predicate: ir.Value | None
  single_thread_per_warpgroup_predicate: ir.Value | None
  lowered_operations: set[ir.Operation | ir.OpView] = dataclasses.field(
      default_factory=set
  )

  def lower_op(self, op: ir.OpView):
    if not _should_lower(op):
      return

    if (name := op.OPERATION_NAME) not in _lowerings:
      raise NotImplementedError(f"Missing lowering rule for {op}")

    lowering_rule = _lowerings[name]

    # TODO(bchetioui): make sure all layouts are set here.
    if inference_utils.should_have_layout(
        op
    ) and not inference_utils.has_any_layout_set(op):
      raise ValueError(f"{op} is missing a layout and cannot be lowered.")

    new_results = lowering_rule(self, op)
    if new_results is not RECURSED:
      for old, new in zip(op.results, new_results):
        old.replace_all_uses_with(new)
      self.lowered_operations.add(op)


class Recursed:
  pass
RECURSED = Recursed()

MlirLoweringRuleResult = Sequence[ir.Value] | Recursed
MlirLoweringRule = Callable[
    [LoweringContext, ir.Operation | ir.OpView], MlirLoweringRuleResult
]


_lowerings: dict[str, MlirLoweringRule] = {}
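

# Lowering is driven by a simple dispatch table: `_lowerings` maps an MLIR
# operation name to its rule. A rule either returns the replacement values for
# the op's results, or the RECURSED sentinel to signal that it only descended
# into the op's regions and the op itself should be left in place.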


def _fragmented_array_to_ir(
    fragmented_array: fa.FragmentedArray, ty: ir.Type
) -> ir.Value:
  """Converts a FragmentedArray to an IR value.

  The fragmented array's signedness is omitted from the IR representation.
  """
  conversion_cast = builtin.UnrealizedConversionCastOp(
      [ty], fragmented_array.registers.flatten().tolist()
  )

  conversion_cast.attributes["registers_shape"] = ir.ArrayAttr.get([
      ir.IntegerAttr.get(ir.IntegerType.get_signless(64), s)
      for s in fragmented_array.registers.shape
  ])

  conversion_cast.attributes["layout"] = layouts.to_layout_attr(
      fragmented_array.layout
  )

  return conversion_cast.result


def _fragmented_array_from_ir(
    fragmented_array_as_ir: ir.Value,
    layout: ir.Attribute,
    is_signed: bool | None = None,
) -> fa.FragmentedArray:

  conversion_cast = cast(
      builtin.UnrealizedConversionCastOp, fragmented_array_as_ir.owner.opview  # pytype: disable=attribute-error
  )

  if not isinstance(conversion_cast, builtin.UnrealizedConversionCastOp):
    raise ValueError(f"{conversion_cast} is not a conversion_cast")

  converted_outputs = builtin.unrealized_conversion_cast(
      [operand.type for operand in conversion_cast.operands],
      conversion_cast.results,
  )
  if not isinstance(converted_outputs, list):
    converted_outputs = [converted_outputs]

  reverse_conversion_cast = converted_outputs[0].owner.opview
  for attribute in conversion_cast.attributes:
    attribute = cast(ir.NamedAttribute, attribute)
    reverse_conversion_cast.attributes[attribute.name] = attribute.attr

  registers = np.array(list(converted_outputs)).reshape(
      [attr.value for attr in conversion_cast.attributes["registers_shape"]]
  )
  producer_layout = layouts.from_layout_attr(conversion_cast.attributes["layout"])

  if ir.IntegerType.isinstance(conversion_cast.outputs[0].type.element_type):
    is_signed = False if is_signed is None else is_signed

  return fa.FragmentedArray(
      _registers=registers, _layout=producer_layout, _is_signed=is_signed
  ).to_layout(layouts.from_layout_attr(layout))
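

# Note on the two helpers above: fragmented arrays have no first-class MLIR
# type, so they travel through the IR as the register values of a
# `builtin.unrealized_conversion_cast`, with the register shape and layout
# stashed as attributes on the cast op. A sketch of the round trip:
#
#   value = _fragmented_array_to_ir(fragmented_array, vec_ty)  # FA -> ir.Value
#   array = _fragmented_array_from_ir(value, layout_attr)      # ir.Value -> FA
#
# `_fragmented_array_from_ir` emits the reverse cast to recover the registers
# and copies the producer's attributes onto it, presumably so that matching
# cast pairs can be cleaned up later.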


def _register_lowering(
    op: str | Type[ir.OpView] | None
) -> Callable[[MlirLoweringRule], MlirLoweringRule]:
  def wrapper(f):
    if op is not None:
      op_name = op if isinstance(op, str) else op.OPERATION_NAME  # pytype: disable=attribute-error
      _lowerings[op_name] = f
    return f

  return wrapper
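

# Usage sketch (the op is hypothetical): decorating a rule registers it in
# `_lowerings` under the op's canonical name, so `LoweringContext.lower_op`
# can dispatch on `op.OPERATION_NAME`:
#
#   @_register_lowering(arith.NegFOp)
#   def _negf_op_lowering_rule(ctx, op):
#     ...
#     return [new_result]
#
# Passing `None` (see `SliceSMEMOp` below) makes the decorator a no-op, which
# keeps the registration site valid on jaxlib versions that lack the op.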


def _lowered_barrier_type() -> ir.Type:
  return ir.IntegerType.get_signless(64)


@_register_lowering(mgpu.InitializeBarrierOp)
def _initialize_barrier_op_lowering_rule(
    ctx: LoweringContext,
    initialize_barrier_op: mgpu.InitializeBarrierOp,
) -> Sequence[ir.Value]:

  shape = initialize_barrier_op.barriers_ref.type.shape
  num_barriers = functools.reduce(operator.mul, shape, 1)

  i32 = ir.IntegerType.get_signless(32)
  workgroup_nvptx_address_space = utils.gpu_address_space_to_nvptx(
      gpu.AddressSpace.Workgroup)
  ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")

  lowered_barrier_type = _lowered_barrier_type()

  for i in range(num_barriers):
    nvvm.mbarrier_init_shared(
        llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i],
                           lowered_barrier_type),
        utils.c(initialize_barrier_op.arrival_count.value, i32),
        predicate=ctx.single_thread_per_block_predicate
    )

  gpu.barrier()

  barrier_base_ptr = llvm.getelementptr(
      ir.Type.parse("!llvm.ptr"),
      initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type)

  return [utils.ptr_as_memref(
      barrier_base_ptr, initialize_barrier_op.barriers_ref.type)]
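

# The loop above initializes each mbarrier exactly once: `mbarrier.init` is
# issued under the single-thread-per-block predicate, and the trailing
# `gpu.barrier` makes the initialized barriers visible to all threads before
# any of them may arrive or wait.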


@_register_lowering(arith.ConstantOp)
def _arith_constant_op_lowering_rule(
    _: LoweringContext, op: arith.ConstantOp
) -> Sequence[ir.Value]:
  if not ir.DenseElementsAttr.isinstance(op.value):
    raise NotImplementedError(f"Unsupported constant op: {op}")

  value = ir.DenseElementsAttr(op.value)
  if not value.is_splat:
    raise NotImplementedError(f"Unsupported constant op: {op}")

  ty = ir.VectorType(op.result.type)
  is_signed = False if ir.IntegerType.isinstance(ty.element_type) else None

  return [
      _fragmented_array_to_ir(
          fa.FragmentedArray.splat(
              arith.constant(ty.element_type, value.get_splat_value()),
              tuple(ty.shape),
              layouts.from_layout_attr(op.attributes["out_layouts"][0]),
              is_signed=is_signed,
          ),
          op.result.type,
      )
  ]
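

# Only splat (all-elements-equal) vector constants are lowered above: the
# scalar splat value is materialized with `arith.constant` and broadcast into
# a fragmented array, so every thread recreates the same register contents.
# Non-splat dense constants would need a layout-aware distribution of elements
# across threads, which is not implemented.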


@_register_lowering(vector.LoadOp)
def _vector_load_op_lowering_rule(
    _: LoweringContext, vector_load_op: vector.LoadOp
) -> Sequence[ir.Value]:
  (out_layout_attr,) = cast(
      ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
  )

  for i in vector_load_op.indices:
    index_defining_op = i.owner.opview
    if (
        not isinstance(index_defining_op, arith.ConstantOp)
        or index_defining_op.literal_value != 0
    ):
      # TODO(bchetioui,dasenov): support non-zero indices.
      raise NotImplementedError(
          "Only constants with value 0 are supported as indices "
          f"for {vector_load_op}"
      )

  element_type = vector_load_op.result.type.element_type
  is_signed = False if ir.IntegerType.isinstance(element_type) else None

  if layouts.is_strided_fragmented_layout(out_layout_attr):
    strided_layout = layouts.from_strided_fragmented_layout_attr(
        out_layout_attr
    )
    fragmented_array = fa.FragmentedArray.load_strided(
        vector_load_op.base,
        is_signed=is_signed,
        vec_size=strided_layout.vec_size,
    )
  elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT:
    layout = ir.MemRefType(vector_load_op.base.type).layout
    swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
    transformed_ref = transform_memref(vector_load_op.base, transforms)
    fragmented_array = fa.FragmentedArray.load_tiled(
        transformed_ref,
        swizzle=swizzle,
        is_signed=is_signed,
        layout=fa.WGMMA_LAYOUT,
    )
  else:
    raise ValueError(
        f"{vector_load_op} has an unsupported layout: {out_layout_attr}"
    )
  return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)]


@_register_lowering(vector.StoreOp)
def _vector_store_op_lowering_rule(
    _: LoweringContext, vector_store_op: vector.StoreOp
) -> Sequence[ir.Value]:
  for i in vector_store_op.indices:
    index_defining_op = i.owner.opview
    if (
        not isinstance(index_defining_op, arith.ConstantOp)
        or index_defining_op.literal_value != 0
    ):
      # TODO(bchetioui,dasenov): support non-zero indices.
      raise NotImplementedError(
          "Only constants with value 0 are supported as indices "
          f"for {vector_store_op}"
      )

  [to_store_layout] = inference_utils.in_layouts(vector_store_op)
  fragmented_array = _fragmented_array_from_ir(
      vector_store_op.valueToStore, to_store_layout
  )

  # TODO(dasenov): This is not efficient for WGMMA layouts
  fragmented_array.store_untiled(vector_store_op.base)

  return []


@_register_lowering(vector.SplatOp)
def _vector_splat_op_lowering_rule(
    _: LoweringContext, vector_splat_op: vector.SplatOp
) -> Sequence[ir.Value]:

  out_vec_ty = ir.VectorType(vector_splat_op.aggregate.type)
  is_signed = (
      False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None
  )
  fragmented_array = fa.FragmentedArray.splat(
      vector_splat_op.input,
      tuple(out_vec_ty.shape),
      layouts.from_layout_attr(vector_splat_op.attributes["out_layouts"][0]),
      is_signed=is_signed,
  )
  return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]


@_register_lowering(vector.ShapeCastOp)
def _vector_shape_cast_op_lowering_rule(
    _: LoweringContext, op: vector.ShapeCastOp
) -> Sequence[ir.Value]:
  [layout] = inference_utils.in_layouts(op)
  out_vec_ty = ir.VectorType(op.result.type)
  assert out_vec_ty.has_static_shape
  is_signed = (
      False if ir.IntegerType.isinstance(out_vec_ty.element_type) else None
  )
  a = _fragmented_array_from_ir(op.source, layout, is_signed)
  return [_fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)]


@_register_lowering(vector.ReductionOp)
def _vector_reduction_op_lowering_rule(
    ctx: LoweringContext, op: vector.ReductionOp
) -> Sequence[ir.Value]:
  del ctx  # Unused.
  [layout] = inference_utils.in_layouts(op)
  () = inference_utils.out_layouts(op)
  element_type = ir.VectorType(op.vector.type).element_type
  is_signed = False if ir.IntegerType.isinstance(element_type) else None
  a = _fragmented_array_from_ir(op.vector, layout, is_signed)
  match str(op.kind):
    case "#vector.kind<add>":
      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
      scratch = _slice_smem(
          ir.MemRefType.get([4], element_type, memory_space=smem),
          arith.constant(None, op.attributes["offset"]),
      )
      result = a.reduce_sum(scratch)
    case (
        "#vector.kind<maxsi>" | "#vector.kind<maxui>" | "#vector.kind<maximumf>"
    ):
      # TODO(slebedev): Implement this and remove the raise below.
      raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
    case _:
      raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
  return [_fragmented_array_to_ir(result, op.result.type)]


def memref_layout_to_swizzle_and_transforms(
    layout: ir.Attribute,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
  """Returns the swizzle and transforms that are encoded in the given layout.

  If the layout is not a LayoutAttr, the swizzle is kNoSwizzle and the
  transforms are empty. Otherwise, the layout may have at most one swizzle
  transform and any combination of tiling and transpose transforms.
  """
  swizzle = None
  gmem_transforms: list[launch_context.MemRefTransform] = []

  if mgpu.LayoutAttr.isinstance(layout):
    transforms_attr = mgpu.LayoutAttr(layout).transforms
    for transform in transforms_attr:
      if swizzle is not None:
        raise ValueError(
            f"{layout} contains more transforms after the initial swizzle."
        )
      if mgpu.SwizzleTransformAttr.isinstance(transform):
        # TODO(dasenov): Swizzling can change if the ref is sliced in certain
        # ways. We might want to enforce some restrictions here.
        swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
      elif mgpu.TileTransformAttr.isinstance(transform):
        tiling = mgpu.TileTransformAttr(transform).tiling
        tiling_transform = launch_context.TileTransform(tuple(tiling))
        gmem_transforms.append(tiling_transform)
      elif mgpu.TransposeTransformAttr.isinstance(transform):
        permutation = mgpu.TransposeTransformAttr(transform).permutation
        transpose_transform = launch_context.TransposeTransform(
            tuple(permutation)
        )
        gmem_transforms.append(transpose_transform)
      else:
        raise ValueError(f"{layout} has an unsupported transform: {transform}")

  return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms)
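

# Note: the `swizzle is not None` check above enforces that a swizzle
# transform, if present, is the last entry in the transform list: once a
# swizzle has been recorded, encountering any further transform raises.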


def transform_memref(
    mem_ref: ir.Value, transforms: tuple[launch_context.MemRefTransform, ...]
) -> ir.Value:
  """Reinterprets the memref to one where the shape is transformed as given."""
  if not transforms:
    return mem_ref

  mem_ref_type = ir.MemRefType(mem_ref.type)
  if mem_ref_type.memory_space != ir.Attribute.parse(
      "#gpu.address_space<workgroup>"
  ):
    raise ValueError(f"Only workgroup memory is supported but got {mem_ref}.")

  shape = mem_ref_type.shape
  for t in transforms:
    shape = t.transform_shape(shape)

  memref_new_type = ir.MemRefType.get(
      shape,
      mem_ref_type.element_type,
      memory_space=mem_ref_type.memory_space,
  )

  ms = utils.WORKGROUP_NVPTX_ADDRESS_SPACE
  ptr = utils.memref_ptr(mem_ref, memory_space=ms)
  return utils.ptr_as_memref(ptr, memref_new_type, ptr_memory_space=ms)
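

# This is a shape-only reinterpretation: the underlying SMEM pointer is
# unchanged and only the memref type is rebuilt around it. For example (a
# sketch, shapes hypothetical), a TileTransform with tiling (64, 32) applied
# to a (128, 128) ref yields a (2, 4, 64, 32) view of the same allocation.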


@_register_lowering(mgpu.AsyncLoadOp)
def _mgpu_async_load_op_lowering_rule(
    ctx: LoweringContext, load_op: mgpu.AsyncLoadOp
) -> Sequence[ir.Value]:
  assert ctx.launch_context is not None
  barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)

  dst_layout = ir.MemRefType(load_op.destination.type).layout
  swizzle, transforms = memref_layout_to_swizzle_and_transforms(dst_layout)

  gmem_slice = []
  for idx_i32, size in zip(load_op.indices, load_op.slice_lengths):
    idx = arith.index_cast(ir.IndexType.get(), idx_i32)
    v = idx if size < 0 else utils.DynamicSlice(idx, size)
    gmem_slice.append(v)

  # TODO(dasenov): Add support for the remaining op properties.
  ctx.launch_context.async_copy(
      src_ref=load_op.source,
      dst_ref=transform_memref(load_op.destination, transforms),
      gmem_slice=tuple(gmem_slice),
      barrier=barrier,
      arrive=False,
      uniform=True,
      swizzle=swizzle,
      gmem_transform=transforms,
      predicate=ctx.single_thread_per_warpgroup_predicate,
  )
  return []
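

# In this rule and the store rule below, a negative entry in `slice_lengths`
# marks a dimension that is indexed into (and thus squeezed) rather than
# sliced, so the bare index is passed through instead of a
# `utils.DynamicSlice`.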


@_register_lowering(mgpu.AsyncStoreOp)
def _mgpu_async_store_op_lowering_rule(
    ctx: LoweringContext, store_op: mgpu.AsyncStoreOp
) -> Sequence[ir.Value]:
  assert ctx.launch_context is not None

  src_layout = ir.MemRefType(store_op.source.type).layout
  swizzle, transforms = memref_layout_to_swizzle_and_transforms(src_layout)

  gmem_slice = []
  for idx_i32, size in zip(store_op.indices, store_op.slice_lengths):
    idx = arith.index_cast(ir.IndexType.get(), idx_i32)
    v = idx if size < 0 else utils.DynamicSlice(idx, size)
    gmem_slice.append(v)

  # TODO(dasenov): Add support for the remaining op properties.
  ctx.launch_context.async_copy(
      src_ref=transform_memref(store_op.source, transforms),
      dst_ref=store_op.destination,
      gmem_slice=tuple(gmem_slice),
      swizzle=swizzle,
      gmem_transform=transforms,
      uniform=True,
      predicate=ctx.single_thread_per_warpgroup_predicate,
      arrive=store_op.commit_group,
  )
  return []


def _conversion_op_lowering_rule(
    _: LoweringContext,
    op: ir.OpView,
    source_is_signed: bool | None,
    target_is_signed: bool | None,
) -> Sequence[ir.Value]:
  [in_layout] = inference_utils.in_layouts(op)
  [layout] = inference_utils.out_layouts(op)
  if in_layout != layout:
    raise ValueError("Layout mismatch")

  target_ty = op.result.type.element_type  # pytype: disable=attribute-error
  operand = _fragmented_array_from_ir(op.operands[0], layout, source_is_signed)
  converted = operand.astype(target_ty, is_signed=target_is_signed)
  return [_fragmented_array_to_ir(converted, op.result.type)]


for op, source_is_signed, target_is_signed in [
    (arith.ExtFOp, None, None),
    (arith.ExtSIOp, True, True),
    (arith.ExtUIOp, False, False),
    (arith.FPToSIOp, None, True),
    (arith.FPToUIOp, None, False),
    (arith.SIToFPOp, True, None),
    (arith.TruncFOp, None, None),
    (arith.TruncIOp, False, False),
    (arith.UIToFPOp, False, None),
]:
  _lowerings[op.OPERATION_NAME] = functools.partial(
      _conversion_op_lowering_rule,
      source_is_signed=source_is_signed,
      target_is_signed=target_is_signed,
  )
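

# MLIR integer types are signless; signedness lives in the op, not the type.
# The tables here and below therefore pin `is_signed` per op: e.g.
# `arith.extsi` sign-extends (True/True) while `arith.extui` zero-extends
# (False/False), and float ops use `None` since signedness does not apply.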


def _unary_op_lowering_rule(
    _: LoweringContext,
    op: Any,
    impl: Callable[[fa.FragmentedArray], fa.FragmentedArray],
    is_signed: bool | None = None,
) -> Sequence[ir.Value]:
  in_layouts = inference_utils.in_layouts(op)
  [layout] = inference_utils.out_layouts(op)
  if any(in_layout != layout for in_layout in in_layouts):
    raise ValueError("Layout mismatch")
  kwargs = {}
  if hasattr(op, "fastmath"):
    kwargs = dict(
        approx=op.fastmath == ir.Attribute.parse("#arith.fastmath<afn>")
    )
  a = _fragmented_array_from_ir(op.operand, layout, is_signed)
  return [_fragmented_array_to_ir(impl(a, **kwargs), op.result.type)]


for op, impl, is_signed in [
    (mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None),
    (mlir_math.ExpOp, fa.FragmentedArray.exp, None),
    (mlir_math.Exp2Op, fa.FragmentedArray.exp2, None),
    (mlir_math.LogOp, fa.FragmentedArray.log, None),
    (mlir_math.TanhOp, fa.FragmentedArray.tanh, None),
]:
  _lowerings[op.OPERATION_NAME] = functools.partial(
      _unary_op_lowering_rule, impl=impl, is_signed=is_signed
  )


def _binary_op_lowering_rule(
    _: LoweringContext,
    op: Any,
    is_signed: bool | None,
    impl: Callable[
        [fa.FragmentedArray, fa.FragmentedArray], fa.FragmentedArray
    ],
) -> Sequence[ir.Value]:
  in_layouts = inference_utils.in_layouts(op)
  [layout] = inference_utils.out_layouts(op)
  if any(in_layout != layout for in_layout in in_layouts):
    raise ValueError("Layout mismatch")
  lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
  rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
  return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]


for op, impl, is_signed in [
    (arith.AddIOp, operator.add, False),
    (arith.AddFOp, operator.add, None),
    (arith.SubIOp, operator.sub, False),
    (arith.SubFOp, operator.sub, None),
    (arith.MulIOp, operator.mul, False),
    (arith.MulFOp, operator.mul, None),
    (arith.FloorDivSIOp, operator.floordiv, True),
    (arith.DivUIOp, operator.floordiv, False),
    (arith.DivFOp, operator.truediv, None),
    (arith.RemSIOp, operator.mod, True),
    (arith.RemUIOp, operator.mod, False),
    (arith.RemFOp, operator.mod, None),
    (arith.AndIOp, operator.and_, False),
    (arith.OrIOp, operator.or_, False),
    (arith.XOrIOp, operator.xor, False),
    (arith.MaxSIOp, fa.FragmentedArray.max, True),
    (arith.MaxUIOp, fa.FragmentedArray.max, False),
    (arith.MaximumFOp, fa.FragmentedArray.max, None),
    (arith.MinSIOp, fa.FragmentedArray.min, True),
    (arith.MinUIOp, fa.FragmentedArray.min, False),
    (arith.MinimumFOp, fa.FragmentedArray.min, None),
]:
  _lowerings[op.OPERATION_NAME] = functools.partial(
      _binary_op_lowering_rule, impl=impl, is_signed=is_signed
  )


CMPI_IMPLS = {
    arith.CmpIPredicate.eq: (operator.eq, False),
    arith.CmpIPredicate.ne: (operator.ne, False),
    arith.CmpIPredicate.slt: (operator.lt, True),
    arith.CmpIPredicate.sle: (operator.le, True),
    arith.CmpIPredicate.sgt: (operator.gt, True),
    arith.CmpIPredicate.sge: (operator.ge, True),
    arith.CmpIPredicate.ult: (operator.lt, False),
    arith.CmpIPredicate.ule: (operator.le, False),
    arith.CmpIPredicate.ugt: (operator.gt, False),
    arith.CmpIPredicate.uge: (operator.ge, False),
}


@_register_lowering(arith.CmpIOp)
def _cmpi_op_lowering_rule(
    _: LoweringContext, op: arith.CmpIOp
) -> Sequence[ir.Value]:
  in_layouts = inference_utils.in_layouts(op)
  [layout] = inference_utils.out_layouts(op)
  if any(in_layout != layout for in_layout in in_layouts):
    raise ValueError("Layout mismatch")
  impl, is_signed = CMPI_IMPLS[op.predicate.value]
  lhs = _fragmented_array_from_ir(op.lhs, layout, is_signed)
  rhs = _fragmented_array_from_ir(op.rhs, layout, is_signed)
  return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]


CMPF_IMPLS = {
    arith.CmpFPredicate.OEQ: operator.eq,
    arith.CmpFPredicate.UNE: operator.ne,
    arith.CmpFPredicate.OLT: operator.lt,
    arith.CmpFPredicate.OLE: operator.le,
    arith.CmpFPredicate.OGT: operator.gt,
    arith.CmpFPredicate.OGE: operator.ge,
}


@_register_lowering(arith.CmpFOp)
def _cmpf_op_lowering_rule(
    _: LoweringContext, op: arith.CmpFOp
) -> Sequence[ir.Value]:
  in_layouts = inference_utils.in_layouts(op)
  [layout] = inference_utils.out_layouts(op)
  if any(in_layout != layout for in_layout in in_layouts):
    raise ValueError("Layout mismatch")
  impl = CMPF_IMPLS[op.predicate.value]
  lhs = _fragmented_array_from_ir(op.lhs, layout)
  rhs = _fragmented_array_from_ir(op.rhs, layout)
  return [_fragmented_array_to_ir(impl(lhs, rhs), op.result.type)]
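

# Only the ordered float predicates (plus UNE) appear in CMPF_IMPLS; the
# remaining unordered variants would presumably need explicit NaN handling on
# top of the plain Python comparison operators used here.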


@_register_lowering(arith.BitcastOp)
def _bitcast_op_lowering_rule(
    _: LoweringContext, op: arith.BitcastOp
) -> Sequence[ir.Value]:
  in_layouts = inference_utils.in_layouts(op)
  [layout] = inference_utils.out_layouts(op)
  if any(in_layout != layout for in_layout in in_layouts):
    raise ValueError("Layout mismatch")
  in_ = _fragmented_array_from_ir(op.in_, layout)
  out_element_type = ir.VectorType(op.result.type).element_type
  out = in_.bitcast(
      out_element_type,
      output_is_signed=False
      if ir.IntegerType.isinstance(out_element_type)
      else None,
  )
  return [_fragmented_array_to_ir(out, op.result.type)]


@_register_lowering(mgpu.WGMMAOp)
def _mgpu_wgmma_op_lowering_rule(
    _: LoweringContext, wgmma_op: mgpu.WGMMAOp
) -> Sequence[ir.Value]:
  fa_layouts = (
      *inference_utils.in_layouts(wgmma_op),
      *inference_utils.out_layouts(wgmma_op),
  )
  is_supported_layout = (
      lambda l: layouts.from_tiled_layout_attr(l) == fa.WGMMA_LAYOUT
  )
  if not all(map(is_supported_layout, fa_layouts)):
    raise ValueError("Layout mismatch")
  wgmma_layout = fa_layouts[0]

  # TODO(dasenov): Move the value -> accumulator conversion outside of wgmma.
  # The associated fence could be a little expensive and is not needed if the
  # result of a wgmma feeds into another wgmma (even in another loop step).
  acc_in = _fragmented_array_from_ir(wgmma_op.accumulator, wgmma_layout)
  regs = acc_in.to_layout(fa.WGMMA_LAYOUT)
  acc = wgmma.WGMMAAccumulator.from_registers(regs)

  b_layout = ir.MemRefType(wgmma_op.b.type).layout
  b_swizzle, b_transforms = memref_layout_to_swizzle_and_transforms(b_layout)
  b_operand = transform_memref(wgmma_op.b, b_transforms)
  if wgmma_op.transpose_b:
    b_operand = utils.memref_transpose(b_operand, (0, 1, 3, 2))

  if ir.VectorType.isinstance(wgmma_op.a.type):
    a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout)
  else:
    a_layout = ir.MemRefType(wgmma_op.a.type).layout
    a_swizzle, a_transforms = memref_layout_to_swizzle_and_transforms(a_layout)
    if a_swizzle != b_swizzle:
      raise ValueError(
          f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !="
          f" {b_swizzle}"
      )
    a_operand = transform_memref(wgmma_op.a, a_transforms)
    if wgmma_op.transpose_a:
      a_operand = utils.memref_transpose(a_operand, (0, 1, 3, 2))

  new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)

  return [
      _fragmented_array_to_ir(
          new_acc.value.to_layout(fa.WGMMA_LAYOUT),
          wgmma_op.accumulator.type,
      )
  ]
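

# The `a` operand may live either in registers (a fragmented array) or in
# SMEM (a memref); `b` always comes from SMEM. When both come from SMEM, their
# swizzles have to agree because a single swizzle value is passed to the
# underlying wgmma implementation.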


@_register_lowering(mgpu.ArriveExpectTxOp)
def _mgpu_arrive_expect_tx_op_lowering_rule(
    ctx: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp
) -> Sequence[ir.Value]:

  barrier = utils.BarrierRef.from_dialect_barrier_memref(
      arrive_expect_tx_op.barrier
  )
  barrier.arrive_expect_tx(
      arrive_expect_tx_op.expect_tx.value,
      ctx.single_thread_per_warpgroup_predicate,
  )

  return []


@_register_lowering(mgpu.WaitOp)
def _mgpu_wait_op_lowering_rule(
    _: LoweringContext, wait_op: mgpu.WaitOp
) -> Sequence[ir.Value]:

  barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier)
  barrier.wait_parity(wait_op.parity)

  return []


# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)


@_register_lowering(SliceSMEMOp)
def _mgpu_slice_smem_op_lowering_rule(
    ctx: LoweringContext, op: SliceSMEMOp
) -> Sequence[ir.Value]:
  del ctx
  return [_slice_smem(op.result.type, op.offset)]


def _slice_smem(result: ir.Type, offset: ir.Value):
  i8 = ir.IntegerType.get_signless(8)
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  smem_base = gpu.dynamic_shared_memory(
      ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
  )
  offset = arith.index_cast(ir.IndexType.get(), offset)
  return memref.view(result, smem_base, offset, [])
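

# `_slice_smem` carves a typed view out of the kernel's dynamic shared memory:
# it obtains the base of the dynamic SMEM region as a 1-D i8 memref and wraps
# the requested byte offset into a memref of the requested type via
# `memref.view`.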


@_register_lowering(scf.ForOp)
def _for_op_lowering_rule(
    ctx: LoweringContext, for_op: scf.ForOp
) -> MlirLoweringRuleResult:
  if not inference_utils.should_have_layout(for_op):
    return _traverse_op_lowering_rule(ctx, for_op)
  in_layouts = inference_utils.in_layouts(for_op)
  out_layouts = inference_utils.out_layouts(for_op)
  yield_op = for_op.body.operations[len(for_op.body.operations) - 1]
  yield_layouts = inference_utils.in_layouts(yield_op)
  if in_layouts != out_layouts or in_layouts != yield_layouts:
    raise ValueError("Layout mismatch")
  fa_layouts = in_layouts

  fa_layouts_it = iter(fa_layouts)
  arg_template = [
      (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type)
      if ir.VectorType.isinstance(arg.type)
      else (arg, arg.type)
      for arg in for_op.initArgs
  ]
  def lower_carry(carry):
    fa_layouts_it = iter(fa_layouts)
    carry_with_fas = [
        _fragmented_array_from_ir(arg, next(fa_layouts_it))
        if ir.VectorType.isinstance(arg.type)
        else arg
        for arg in carry
    ]
    lowered_carry = []
    for c in carry_with_fas:
      if isinstance(c, fa.FragmentedArray):
        lowered_carry.extend(c.registers.flat)
      else:
        lowered_carry.append(c)
    return lowered_carry

  def recreate_carry(lowered_carry):
    recreated_carry = []
    arg_it = iter(lowered_carry)
    for arg_value, arg_type in arg_template:
      if isinstance(arg_value, fa.FragmentedArray):
        carry_registers = np.asarray(
            [next(arg_it) for _ in arg_value.registers.flat], dtype=object
        )
        carry_registers = carry_registers.reshape(arg_value.registers.shape)
        carry = fa.FragmentedArray(
            _registers=carry_registers,
            _layout=arg_value.layout,
            _is_signed=arg_value.is_signed,
        )
        recreated_carry.append(_fragmented_array_to_ir(carry, arg_type))
      else:
        recreated_carry.append(next(arg_it))
    return recreated_carry

  new_for_op = scf.ForOp(
      for_op.lowerBound,
      for_op.upperBound,
      for_op.step,
      lower_carry(for_op.initArgs),
  )
  with ir.InsertionPoint(new_for_op.body):
    recreated_carry = recreate_carry(new_for_op.body.arguments[1:])
    ops_to_lower = []
    for op in for_op.body:
      if op == yield_op:
        continue
      mgpu.private_operation_remove_from_parent(op)
      mgpu.private_block_append_owned_operation(new_for_op.body, op)
      ops_to_lower.append(op)
    new_args = (new_for_op.induction_variable, *recreated_carry)
    for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True):
      old_carry.replace_all_uses_with(new_carry)

  for op in ops_to_lower:
    with ir.InsertionPoint(op):
      ctx.lower_op(op)

  with ir.InsertionPoint(new_for_op.body):
    new_yield_operands = lower_carry(yield_op.operands)
    yield_op.erase()
    scf.yield_(new_yield_operands)
  return recreate_carry(new_for_op.results)
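

# The scf.for rule flattens each vector carry into its individual per-thread
# registers: a new loop is created whose carried values are the registers of
# every fragmented array plus all non-vector carries, the old body is moved
# into it, and the carries are reassembled into fragmented arrays (via
# `recreate_carry`) both inside the new body and for the loop's results.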


@_register_lowering(func.FuncOp)
@_register_lowering(gpu.LaunchOp)
@_register_lowering(scf.IfOp)  # TODO(apaszke,bchetioui): Add a proper rule.
@_register_lowering(scf.IndexSwitchOp)  # TODO(apaszke,bchetioui): Add a proper rule.
def _traverse_op_lowering_rule(
    ctx: LoweringContext, op: ir.OpView
) -> MlirLoweringRuleResult:
  if inference_utils.should_have_layout(op):
    raise ValueError(
        f"Rule cannot handle an op with vector operands or results: {op}"
    )
  for region in op.operation.regions:
    for block in region:
      for block_op in list(block):
        with ir.InsertionPoint(block_op):
          ctx.lower_op(block_op)
  return RECURSED
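

# Returning RECURSED tells `lower_op` that this rule only descended into the
# op's regions: the op itself produced no replacement values and must not be
# erased or have its results rewritten.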


def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
  """Returns a single thread predicate per block and one per warpgroup."""
  block_predicate = warpgroup_predicate = None
  for op in module.body.operations:
    for region in op.operation.regions:
      for block in region.blocks:
        for sub_op in block.operations:
          if sub_op.operation.name == "gpu.launch":
            with ir.InsertionPoint.at_block_begin(
                sub_op.operation.regions[0].blocks[0]
            ):
              assert block_predicate is None
              block_predicate = utils.single_thread_predicate(per_block=True)
              warpgroup_predicate = utils.single_thread_predicate(
                  per_block=False
              )

  if block_predicate is None:
    raise ValueError(
        "No suitable function found to instantiate the single thread"
        " predicates."
    )

  return block_predicate, warpgroup_predicate


def _should_lower(op: ir.OpView) -> bool:
  """Returns `True` if the operation should be lowered."""
  return (
      op.OPERATION_NAME.startswith("mosaic_gpu.")
      or inference_utils.should_have_layout(op)
      or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
  )


def lower_mgpu_dialect(
    module: ir.Module,
    launch_context: launch_context.LaunchContext | None,
):
  # TODO(apaszke,bchetioui): Make sure the layouts match.
  # TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full
  # module and to traverse all `gpu.LaunchOp`s if we have a `LaunchContext`
  # that references a single `gpu.LaunchOp`.
  #
  # A `LaunchContext` should have all the information needed to lower a single
  # kernel.
  module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
  module.context.load_all_available_dialects()

  # TODO(bchetioui): fix tests to not have a test-only path polluting the API.
  if launch_context is None:  # this case is used in some tests
    block_predicate = warpgroup_predicate = None
  else:
    block_predicate, warpgroup_predicate = single_thread_predicates(module)

  ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate)
  with ir.InsertionPoint(module.body):
    for op in list(module.body):
      ctx.lower_op(op)

  for lowered_op in ctx.lowered_operations:
    lowered_op.erase()
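

# Lowered ops are erased only after the whole module has been processed: rules
# replace the uses of an op's results immediately, but deferring the erasure
# keeps traversal of the IR safe while lowering is still in flight.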