diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 3534508b6..dcdfe62bb 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -237,7 +237,7 @@ class ModuleContext: program_ids: Sequence[ir.Value] | None approx_math: bool single_wg_lane_predicate: ir.Value - runtime_smem: ir.Value # ir.MemRefType + smem_requested_bytes: int smem_used_bytes: int runtime_barriers: MutableMapping[ mgpu.Barrier, MutableSequence[mgpu.BarrierRef] @@ -279,25 +279,38 @@ class ModuleContext: and the second element is a sequence of memref views into the runtime scratch buffer. """ - smem_scratch_bytes = math.prod(ir.MemRefType(self.runtime_smem.type).shape) - + smem_base = None + smem = ir.Attribute.parse("#gpu.address_space") + i8 = ir.IntegerType.get_signless(8) + i32 = ir.IntegerType.get_signless(32) + if self.thread_semantics == mgpu.ThreadSemantics.Lane: + smem_base = gpu_dialect.dynamic_shared_memory( + ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem) + ) views = [] off = initial_used_bytes = self.smem_used_bytes assert off % _SMEM_ALIGNMENT == 0 - smem = ir.Attribute.parse("#gpu.address_space") for s in structs: scratch_ty = ir.MemRefType.get( s.shape, mgpu_utils.dtype_to_ir_type(s.dtype), memory_space=smem, ) - views.append( - memref_dialect.view(scratch_ty, self.runtime_smem, _as_index(off), []) - ) + # The below code emission relies on the assumption that the first scratch + # operand provided by Mosaic GPU always begins at the beginning of + # dynamic SMEM. Mosaic GPU is expected to uphold that invariant. + if self.thread_semantics == mgpu.ThreadSemantics.Lane: + view = memref_dialect.view( + scratch_ty, smem_base, _as_index(off), [] + ) + else: + view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32)) + views.append(view) + off += _align_to( math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, _SMEM_ALIGNMENT ) - assert off <= smem_scratch_bytes, "Ran out of scoped SMEM" + assert off <= self.smem_requested_bytes, "Ran out of scoped SMEM" assert off % _SMEM_ALIGNMENT == 0 self.smem_used_bytes = off @@ -596,7 +609,7 @@ def lower_jaxpr_to_module( [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, mgpu.single_thread_predicate(per_block=False), - runtime_smem, + smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 4602ba99e..024b0c67b 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -28,6 +28,7 @@ from jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector @@ -598,15 +599,25 @@ def _mgpu_wait_op_lowering_rule( return [] -@_register_lowering(WaitOp) -def _for_op_lowering_rule( - _: LoweringContext, wait_op: scf.ForOp +# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. +SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) + + +@_register_lowering(SliceSMEMOp) +def _mgpu_slice_smem_op_lowering_rule( + ctx: LoweringContext, op: SliceSMEMOp ) -> Sequence[ir.Value]: + del ctx + i8 = ir.IntegerType.get_signless(8) + smem = ir.Attribute.parse("#gpu.address_space") - barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier) - barrier.wait_parity(wait_op.parity) + smem_base = gpu.dynamic_shared_memory( + ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) + ) - return [] + offset = arith.index_cast(ir.IndexType.get(), op.offset) + + return [memref.view(op.result.type, smem_base, offset, [])] @_register_lowering(scf.ForOp) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index eb26c811b..48c6a0464 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -372,6 +372,14 @@ def MosaicGPU_WGMMALayout : let cppNamespace = "::mosaic_gpu"; } + +def MosaicGPU_SliceSMEMOp : Op { + let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address."; + + let arguments = (ins I32:$offset); + let results = (outs MemRefOf<[AnyType]>); +} + def MosaicGPU_WGMMAOp : Op { let summary = "Multiply two matrices asyncronously using warpgroup level matrix multiply operations."; let description = [{ diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 1bee3123d..94d5d6714 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -802,6 +802,28 @@ class DialectLoweringTest(MosaicGpuTest): reg_vec_ty = ir.VectorType.get((2,), i32) self.assertSequenceEqual(result_types, [i32, reg_vec_ty, reg_vec_ty]) + def test_lowering_slice_smem_op(self): + shift = 1234 + offset = None + + def body(): + nonlocal offset + i32 = ir.IntegerType.get_signless(32) + offset = arith.constant(i32, shift) + mgpu.dialect.slice_smem(i32, offset) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func()(body) + + mgpu.lower_mgpu_dialect(self.module, None) + # Avoid making a change detector, only validate that lowering runs as + # expected. + self.assertEmpty( + find_if( + self.module, lambda op: isinstance(op, mgpu.dialect.SliceSMEMOp) + ) + ) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader())