mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Ensure that lowering InitializeBarrierOp
preserves the result's type.
Otherwise, the lowered IR won't be type-correct. PiperOrigin-RevId: 695339726
This commit is contained in:
parent
1d24630b41
commit
8a7bf2e4b0
@ -26,7 +26,7 @@ from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from .utils import c, single_thread_predicate
|
||||
from .utils import c, ptr_as_memref, single_thread_predicate
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
@ -89,7 +89,12 @@ def _initialize_barrier_op_lowering_rule(
|
||||
predicate=predicate
|
||||
)
|
||||
|
||||
return initialize_barrier_op.base_pointer,
|
||||
barrier_base_ptr = llvm.getelementptr(
|
||||
ir.Type.parse("!llvm.ptr"),
|
||||
initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type)
|
||||
|
||||
return ptr_as_memref(
|
||||
barrier_base_ptr, initialize_barrier_op.barriers_ref.type),
|
||||
|
||||
|
||||
def lower_mgpu_dialect(module: ir.Module):
|
||||
|
@ -25,6 +25,7 @@ from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import func
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import llvm
|
||||
from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import nvvm
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
|
||||
@ -512,11 +513,17 @@ class DialectLoweringTest(DialectTest):
|
||||
arrival_count = 1337
|
||||
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
barriers_ref = mgpu.initialize_barrier(
|
||||
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=arrival_count)
|
||||
# Add a user for barriers_ref to make sure that the lowering keeps types
|
||||
# consistent.
|
||||
memref.copy(barriers_ref, barriers_ref)
|
||||
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
lower_mgpu_dialect(self.module)
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
|
||||
all_mbarrier_init_shared_ops = find_if(
|
||||
self.module,
|
||||
|
Loading…
x
Reference in New Issue
Block a user