1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

[Mosaic GPU] Ensure that lowering InitializeBarrierOp preserves the result's type.

Otherwise, the lowered IR won't be type-correct.

PiperOrigin-RevId: 695339726
This commit is contained in:
Benjamin Chetioui 2024-11-11 08:01:12 -08:00 committed by jax authors
parent 1d24630b41
commit 8a7bf2e4b0
2 changed files with 15 additions and 3 deletions
jax/experimental/mosaic/gpu
tests/mosaic

@ -26,7 +26,7 @@ from jaxlib.mlir import ir
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import nvvm
from .utils import c, single_thread_predicate
from .utils import c, ptr_as_memref, single_thread_predicate
# mypy: ignore-errors
@ -89,7 +89,12 @@ def _initialize_barrier_op_lowering_rule(
predicate=predicate
)
return initialize_barrier_op.base_pointer,
barrier_base_ptr = llvm.getelementptr(
ir.Type.parse("!llvm.ptr"),
initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type)
return ptr_as_memref(
barrier_base_ptr, initialize_barrier_op.barriers_ref.type),
def lower_mgpu_dialect(module: ir.Module):

@ -25,6 +25,7 @@ from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import scf
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
@ -512,11 +513,17 @@ class DialectLoweringTest(DialectTest):
arrival_count = 1337
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
barriers_ref = mgpu.initialize_barrier(
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=arrival_count)
# Add a user for barriers_ref to make sure that the lowering keeps types
# consistent.
memref.copy(barriers_ref, barriers_ref)
self.assertTrue(self.module.operation.verify())
lower_mgpu_dialect(self.module)
self.assertTrue(self.module.operation.verify())
all_mbarrier_init_shared_ops = find_if(
self.module,