[Mosaic GPU] Use a single instance of the single_thread_predicate in the MLIR dialect lowering.

PiperOrigin-RevId: 720155654
This commit is contained in:
Dimitar (Mitko) Asenov 2025-01-27 07:03:29 -08:00 committed by jax authors
parent 9b5cb45bc3
commit a0db6c5cf6

View File

@ -15,6 +15,7 @@
"""Lowering rules and pass for the MLIR Mosaic GPU dialect."""
from collections.abc import Callable
import dataclasses
import functools
import operator
from typing import Sequence, Type, cast
@ -38,8 +39,15 @@ from . import utils
# mypy: ignore-errors
@dataclasses.dataclass()
class LoweringContext:
  """Context threaded through every MLIR Mosaic GPU lowering rule."""
  # The launch context of the kernel being lowered. May be None (e.g. some
  # tests lower a module without a launch context).
  launch_context: launch_context.LaunchContext | None
  # Predicate that is true for exactly one thread per block. None when no
  # launch context is available (the predicates are instantiated inside the
  # gpu.launch body by instantiate_single_thread_predicates).
  single_thread_per_block_predicate: ir.Value | None
  # Predicate that is true for exactly one thread per warpgroup; same
  # None-semantics as the per-block predicate above.
  single_thread_per_warpgroup_predicate: ir.Value | None
MlirLoweringRule = Callable[
[launch_context.LaunchContext, ir.Operation | ir.OpView], Sequence[ir.Value]
[LoweringContext, ir.Operation | ir.OpView], Sequence[ir.Value]
]
@ -130,7 +138,7 @@ def _lowered_barrier_type() -> ir.Type:
@_register_lowering(InitializeBarrierOp)
def _initialize_barrier_op_lowering_rule(
_: launch_context.LaunchContext,
ctx: LoweringContext,
initialize_barrier_op: InitializeBarrierOp,
) -> Sequence[ir.Value]:
@ -144,15 +152,16 @@ def _initialize_barrier_op_lowering_rule(
lowered_barrier_type = _lowered_barrier_type()
predicate = utils.single_thread_predicate(per_block=True)
for i in range(num_barriers):
nvvm.mbarrier_init_shared(
llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i],
lowered_barrier_type),
utils.c(initialize_barrier_op.arrival_count.value, i32),
predicate=predicate
predicate=ctx.single_thread_per_block_predicate
)
gpu.barrier()
barrier_base_ptr = llvm.getelementptr(
ir.Type.parse("!llvm.ptr"),
initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type)
@ -163,7 +172,7 @@ def _initialize_barrier_op_lowering_rule(
@_register_lowering(vector.LoadOp)
def _vector_load_op_lowering_rule(
_: launch_context.LaunchContext, vector_load_op: vector.LoadOp
_: LoweringContext, vector_load_op: vector.LoadOp
) -> Sequence[ir.Value]:
(out_layout_attr,) = cast(
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
@ -192,7 +201,7 @@ def _vector_load_op_lowering_rule(
@_register_lowering(vector.StoreOp)
def _vector_store_op_lowering_rule(
_: launch_context.LaunchContext, vector_store_op: vector.StoreOp
_: LoweringContext, vector_store_op: vector.StoreOp
) -> Sequence[ir.Value]:
for i in vector_store_op.indices:
index_defining_op = i.owner.opview
@ -216,38 +225,40 @@ def _vector_store_op_lowering_rule(
@_register_lowering(mgpu.AsyncLoadOp)
def _mgpu_async_load_op_lowering_rule(
launch_context: launch_context.LaunchContext, load_op: mgpu.AsyncLoadOp
ctx: LoweringContext, load_op: mgpu.AsyncLoadOp
) -> Sequence[ir.Value]:
with utils.single_thread():
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
# TODO(dasenov): Add support for the remaining op properties.
launch_context.async_copy(
src_ref=load_op.source,
dst_ref=load_op.destination,
barrier=barrier,
arrive=load_op.arrive,
uniform=False,
swizzle=load_op.swizzle.value,
)
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
# TODO(dasenov): Add support for the remaining op properties.
ctx.launch_context.async_copy(
src_ref=load_op.source,
dst_ref=load_op.destination,
barrier=barrier,
arrive=load_op.arrive,
uniform=True,
swizzle=load_op.swizzle.value,
predicate=ctx.single_thread_per_warpgroup_predicate,
)
return []
@_register_lowering(mgpu.AsyncStoreOp)
def _mgpu_async_store_op_lowering_rule(
launch_context: launch_context.LaunchContext, store_op: mgpu.AsyncStoreOp
ctx: LoweringContext, store_op: mgpu.AsyncStoreOp
) -> Sequence[ir.Value]:
# TODO(dasenov): Add support for the remaining op properties.
launch_context.async_copy(
ctx.launch_context.async_copy(
src_ref=store_op.source,
dst_ref=store_op.destination,
swizzle=store_op.swizzle.value,
uniform=True,
predicate=ctx.single_thread_per_warpgroup_predicate,
)
return []
@_register_lowering(arith.AddFOp)
def _arith_addf_op_lowering_rule(
_: launch_context.LaunchContext, add: arith.AddFOp
_: LoweringContext, add: arith.AddFOp
) -> Sequence[ir.Value]:
fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
@ -260,13 +271,47 @@ def _arith_addf_op_lowering_rule(
]
def instantiate_single_thread_predicates(
    module: ir.Module,
) -> tuple[ir.Value, ir.Value]:
  """Instantiates the single-thread predicates inside the module's gpu.launch.

  Walks the top-level operations of `module` looking for a `gpu.launch` op
  and, at the beginning of its body, creates one per-block and one
  per-warpgroup single-thread predicate. Creating them once here lets every
  lowering rule reuse the same two values instead of materializing its own.

  Note: the return type annotation was corrected — the function returns the
  two predicate values, not a LoweringContext (the caller unpacks a 2-tuple).

  Args:
    module: the MLIR module to scan; it must contain a `gpu.launch` op.

  Returns:
    A `(block_predicate, warpgroup_predicate)` pair of `ir.Value`s.

  Raises:
    ValueError: if no `gpu.launch` op is found in `module`.
  """
  block_predicate = None
  warpgroup_predicate = None
  for op in module.body.operations:
    for region in op.operation.regions:
      for block in region.blocks:
        for sub_op in block.operations:
          if sub_op.operation.name == "gpu.launch":
            with ir.InsertionPoint.at_block_begin(
                sub_op.operation.regions[0].blocks[0]
            ):
              # There must be at most one gpu.launch per module; a second
              # match would silently clobber the first predicates otherwise.
              assert block_predicate is None
              block_predicate = utils.single_thread_predicate(per_block=True)
              warpgroup_predicate = utils.single_thread_predicate(
                  per_block=False
              )

  if block_predicate is None:
    raise ValueError(
        "No suitable function found to instantiate the single thread"
        " predicates."
    )

  return block_predicate, warpgroup_predicate
def lower_mgpu_dialect(
module: ir.Module, launch_context: launch_context.LaunchContext
module: ir.Module, launch_context: launch_context.LaunchContext | None
):
module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
module.context.load_all_available_dialects()
lowered_operations: set[ir.Operation | ir.OpView] = set()
if launch_context is None: # this case is used in some tests
block_predicate = warpgroup_predicate = None
else:
block_predicate, warpgroup_predicate = instantiate_single_thread_predicates(
module
)
ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate)
def _lower_op(op: ir.OpView):
if op.name not in _lowerings:
@ -277,7 +322,7 @@ def lower_mgpu_dialect(
if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
raise ValueError(f"{op} is missing a layout and can not be lowered.")
new_results = lowering_rule(launch_context, op)
new_results = lowering_rule(ctx, op)
for old, new in zip(op.results, new_results):
old.replace_all_uses_with(new)