mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic GPU] Use a single instance of the single_thread_predicate
in the MLIR dialect lowering.
PiperOrigin-RevId: 720155654
This commit is contained in:
parent
9b5cb45bc3
commit
a0db6c5cf6
@ -15,6 +15,7 @@
|
||||
"""Lowering rules and pass for the MLIR Mosaic GPU dialect."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
import functools
|
||||
import operator
|
||||
from typing import Sequence, Type, cast
|
||||
@ -38,8 +39,15 @@ from . import utils
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
@dataclasses.dataclass
class LoweringContext:
  """State handed to every lowering rule of the Mosaic GPU dialect lowering.

  All fields may be `None`; `lower_mgpu_dialect` sets both predicates to
  `None` when it is called without a launch context.
  """

  # Launch context; the async load/store rules call `async_copy` on it.
  launch_context: launch_context.LaunchContext | None
  # Predicate passed to `nvvm.mbarrier_init_shared` when lowering
  # `InitializeBarrierOp`.
  single_thread_per_block_predicate: ir.Value | None
  # Predicate passed to `async_copy` by the async load/store rules.
  single_thread_per_warpgroup_predicate: ir.Value | None
|
||||
|
||||
|
||||
MlirLoweringRule = Callable[
|
||||
[launch_context.LaunchContext, ir.Operation | ir.OpView], Sequence[ir.Value]
|
||||
[LoweringContext, ir.Operation | ir.OpView], Sequence[ir.Value]
|
||||
]
|
||||
|
||||
|
||||
@ -130,7 +138,7 @@ def _lowered_barrier_type() -> ir.Type:
|
||||
|
||||
@_register_lowering(InitializeBarrierOp)
|
||||
def _initialize_barrier_op_lowering_rule(
|
||||
_: launch_context.LaunchContext,
|
||||
ctx: LoweringContext,
|
||||
initialize_barrier_op: InitializeBarrierOp,
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
@ -144,15 +152,16 @@ def _initialize_barrier_op_lowering_rule(
|
||||
|
||||
lowered_barrier_type = _lowered_barrier_type()
|
||||
|
||||
predicate = utils.single_thread_predicate(per_block=True)
|
||||
for i in range(num_barriers):
|
||||
nvvm.mbarrier_init_shared(
|
||||
llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i],
|
||||
lowered_barrier_type),
|
||||
utils.c(initialize_barrier_op.arrival_count.value, i32),
|
||||
predicate=predicate
|
||||
predicate=ctx.single_thread_per_block_predicate
|
||||
)
|
||||
|
||||
gpu.barrier()
|
||||
|
||||
barrier_base_ptr = llvm.getelementptr(
|
||||
ir.Type.parse("!llvm.ptr"),
|
||||
initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type)
|
||||
@ -163,7 +172,7 @@ def _initialize_barrier_op_lowering_rule(
|
||||
|
||||
@_register_lowering(vector.LoadOp)
|
||||
def _vector_load_op_lowering_rule(
|
||||
_: launch_context.LaunchContext, vector_load_op: vector.LoadOp
|
||||
_: LoweringContext, vector_load_op: vector.LoadOp
|
||||
) -> Sequence[ir.Value]:
|
||||
(out_layout_attr,) = cast(
|
||||
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
|
||||
@ -192,7 +201,7 @@ def _vector_load_op_lowering_rule(
|
||||
|
||||
@_register_lowering(vector.StoreOp)
|
||||
def _vector_store_op_lowering_rule(
|
||||
_: launch_context.LaunchContext, vector_store_op: vector.StoreOp
|
||||
_: LoweringContext, vector_store_op: vector.StoreOp
|
||||
) -> Sequence[ir.Value]:
|
||||
for i in vector_store_op.indices:
|
||||
index_defining_op = i.owner.opview
|
||||
@ -216,38 +225,40 @@ def _vector_store_op_lowering_rule(
|
||||
|
||||
@_register_lowering(mgpu.AsyncLoadOp)
|
||||
def _mgpu_async_load_op_lowering_rule(
|
||||
launch_context: launch_context.LaunchContext, load_op: mgpu.AsyncLoadOp
|
||||
ctx: LoweringContext, load_op: mgpu.AsyncLoadOp
|
||||
) -> Sequence[ir.Value]:
|
||||
with utils.single_thread():
|
||||
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
|
||||
# TODO(dasenov): Add support for the remaining op properties.
|
||||
launch_context.async_copy(
|
||||
src_ref=load_op.source,
|
||||
dst_ref=load_op.destination,
|
||||
barrier=barrier,
|
||||
arrive=load_op.arrive,
|
||||
uniform=False,
|
||||
swizzle=load_op.swizzle.value,
|
||||
)
|
||||
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
|
||||
# TODO(dasenov): Add support for the remaining op properties.
|
||||
ctx.launch_context.async_copy(
|
||||
src_ref=load_op.source,
|
||||
dst_ref=load_op.destination,
|
||||
barrier=barrier,
|
||||
arrive=load_op.arrive,
|
||||
uniform=True,
|
||||
swizzle=load_op.swizzle.value,
|
||||
predicate=ctx.single_thread_per_warpgroup_predicate,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
@_register_lowering(mgpu.AsyncStoreOp)
|
||||
def _mgpu_async_store_op_lowering_rule(
|
||||
launch_context: launch_context.LaunchContext, store_op: mgpu.AsyncStoreOp
|
||||
ctx: LoweringContext, store_op: mgpu.AsyncStoreOp
|
||||
) -> Sequence[ir.Value]:
|
||||
# TODO(dasenov): Add support for the remaining op properties.
|
||||
launch_context.async_copy(
|
||||
ctx.launch_context.async_copy(
|
||||
src_ref=store_op.source,
|
||||
dst_ref=store_op.destination,
|
||||
swizzle=store_op.swizzle.value,
|
||||
uniform=True,
|
||||
predicate=ctx.single_thread_per_warpgroup_predicate,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
@_register_lowering(arith.AddFOp)
|
||||
def _arith_addf_op_lowering_rule(
|
||||
_: launch_context.LaunchContext, add: arith.AddFOp
|
||||
_: LoweringContext, add: arith.AddFOp
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
|
||||
@ -260,13 +271,47 @@ def _arith_addf_op_lowering_rule(
|
||||
]
|
||||
|
||||
|
||||
def instantiate_single_thread_predicates(
    module: ir.Module,
) -> tuple[ir.Value, ir.Value]:
  """Emits the single-thread predicates at the start of the `gpu.launch` body.

  Scans the operations nested one level below the module's top-level
  operations for a `gpu.launch` op and, at the beginning of the first block of
  its first region, creates one per-block and one per-warpgroup single-thread
  predicate via `utils.single_thread_predicate`.

  Args:
    module: The module to scan and to insert the predicate computation into.

  Returns:
    A `(block_predicate, warpgroup_predicate)` pair.

  Raises:
    ValueError: If no `gpu.launch` op is found.
  """
  block_predicate = None
  warpgroup_predicate = None
  for op in module.body.operations:
    for region in op.operation.regions:
      for block in region.blocks:
        for sub_op in block.operations:
          if sub_op.operation.name != "gpu.launch":
            continue
          # Only a single launch op is supported: a second one would silently
          # replace the predicates created for the first.
          assert block_predicate is None, "Expected at most one gpu.launch op."
          launch_body = sub_op.operation.regions[0].blocks[0]
          with ir.InsertionPoint.at_block_begin(launch_body):
            block_predicate = utils.single_thread_predicate(per_block=True)
            warpgroup_predicate = utils.single_thread_predicate(
                per_block=False
            )

  if block_predicate is None:
    raise ValueError(
        "No suitable function found to instantiate the single thread"
        " predicates."
    )

  return block_predicate, warpgroup_predicate
|
||||
|
||||
|
||||
def lower_mgpu_dialect(
|
||||
module: ir.Module, launch_context: launch_context.LaunchContext
|
||||
module: ir.Module, launch_context: launch_context.LaunchContext | None
|
||||
):
|
||||
module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
|
||||
module.context.load_all_available_dialects()
|
||||
|
||||
lowered_operations: set[ir.Operation | ir.OpView] = set()
|
||||
if launch_context is None: # this case is used in some tests
|
||||
block_predicate = warpgroup_predicate = None
|
||||
else:
|
||||
block_predicate, warpgroup_predicate = instantiate_single_thread_predicates(
|
||||
module
|
||||
)
|
||||
|
||||
ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate)
|
||||
|
||||
def _lower_op(op: ir.OpView):
|
||||
if op.name not in _lowerings:
|
||||
@ -277,7 +322,7 @@ def lower_mgpu_dialect(
|
||||
if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
|
||||
raise ValueError(f"{op} is missing a layout and can not be lowered.")
|
||||
|
||||
new_results = lowering_rule(launch_context, op)
|
||||
new_results = lowering_rule(ctx, op)
|
||||
|
||||
for old, new in zip(op.results, new_results):
|
||||
old.replace_all_uses_with(new)
|
||||
|
Loading…
x
Reference in New Issue
Block a user