Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using warpgroup semantics.

Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas module context, which should be enforced by Mosaic GPU. This is in preparation for changes implementing transform inference.

PiperOrigin-RevId: 732091266
parent 55263ce485
commit a9ab614123
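For orientation before the hunks (which appear to touch the Pallas Mosaic GPU lowering, the Mosaic GPU dialect lowering, the dialect's ODS definition, and the dialect tests; the per-file headers were lost in this view), here is a minimal sketch of what the new abstraction amounts to: a `mosaic_gpu.slice_smem` op carries an i32 byte offset, and lowering it produces a `memref.view` of the workgroup's dynamic shared memory at that offset. This only yields the right buffer if runtime SMEM really starts at offset 0, which is exactly the assumption the commit message makes explicit. The helper name `slice_dynamic_smem` is invented for illustration, the fragment assumes an active MLIR context and insertion point, and the calls simply mirror the lowering rule added further down.

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith, gpu, memref


def slice_dynamic_smem(result_type: ir.MemRefType, offset_i32: ir.Value) -> ir.Value:
  # Byte-typed view of the workgroup's entire dynamic SMEM region.
  i8 = ir.IntegerType.get_signless(8)
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  dynamic = ir.ShapedType.get_dynamic_size()  # the diff calls this utils.DYNAMIC
  smem_base = gpu.dynamic_shared_memory(
      ir.MemRefType.get((dynamic,), i8, memory_space=smem)
  )
  # `offset_i32` is a byte offset relative to the start of dynamic SMEM; this
  # is where the "runtime SMEM starts at 0" assumption matters.
  offset = arith.index_cast(ir.IndexType.get(), offset_i32)
  return memref.view(result_type, smem_base, offset, [])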
@@ -237,7 +237,7 @@ class ModuleContext:
   program_ids: Sequence[ir.Value] | None
   approx_math: bool
   single_wg_lane_predicate: ir.Value
-  runtime_smem: ir.Value  # ir.MemRefType
+  smem_requested_bytes: int
   smem_used_bytes: int
   runtime_barriers: MutableMapping[
       mgpu.Barrier, MutableSequence[mgpu.BarrierRef]
@@ -279,25 +279,38 @@ class ModuleContext:
       and the second element is a sequence of memref views into the
       runtime scratch buffer.
     """
-    smem_scratch_bytes = math.prod(ir.MemRefType(self.runtime_smem.type).shape)
+    smem_base = None
+    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    i8 = ir.IntegerType.get_signless(8)
+    i32 = ir.IntegerType.get_signless(32)
+    if self.thread_semantics == mgpu.ThreadSemantics.Lane:
+      smem_base = gpu_dialect.dynamic_shared_memory(
+          ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem)
+      )
     views = []
     off = initial_used_bytes = self.smem_used_bytes
     assert off % _SMEM_ALIGNMENT == 0
-    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
     for s in structs:
       scratch_ty = ir.MemRefType.get(
           s.shape,
           mgpu_utils.dtype_to_ir_type(s.dtype),
           memory_space=smem,
       )
-      views.append(
-          memref_dialect.view(scratch_ty, self.runtime_smem, _as_index(off), [])
-      )
+      # The below code emission relies on the assumption that the first scratch
+      # operand provided by Mosaic GPU always begins at the beginning of
+      # dynamic SMEM. Mosaic GPU is expected to uphold that invariant.
+      if self.thread_semantics == mgpu.ThreadSemantics.Lane:
+        view = memref_dialect.view(
+            scratch_ty, smem_base, _as_index(off), []
+        )
+      else:
+        view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32))
+      views.append(view)

       off += _align_to(
           math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, _SMEM_ALIGNMENT
       )
-    assert off <= smem_scratch_bytes, "Ran out of scoped SMEM"
+    assert off <= self.smem_requested_bytes, "Ran out of scoped SMEM"
     assert off % _SMEM_ALIGNMENT == 0

     self.smem_used_bytes = off
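As a reading aid (not part of the commit), the following self-contained Python sketch models the byte-offset bookkeeping that `scratch_view` performs above: offsets are relative to the start of dynamic SMEM, each buffer is rounded up to the SMEM alignment, and exhausting `smem_requested_bytes` is an error. The alignment value of 16 is an assumption for illustration; the real constant is whatever `_SMEM_ALIGNMENT` is in the Pallas lowering.

import math

SMEM_ALIGNMENT = 16  # assumed for illustration; the real value is _SMEM_ALIGNMENT


def align_to(nbytes: int, alignment: int) -> int:
  # Round nbytes up to the next multiple of `alignment`.
  return (nbytes + alignment - 1) // alignment * alignment


def scratch_offsets(shapes_and_itemsizes, requested_bytes):
  """Returns the aligned byte offset of each scratch buffer in dynamic SMEM."""
  offsets = []
  off = 0
  for shape, itemsize in shapes_and_itemsizes:
    assert off % SMEM_ALIGNMENT == 0
    offsets.append(off)  # the value fed to slice_smem / memref.view
    off += align_to(math.prod(shape) * itemsize, SMEM_ALIGNMENT)
  assert off <= requested_bytes, "Ran out of scoped SMEM"
  return offsets


# Two f32 buffers of 128x128 and 3 elements, plus one of 128, in 128 KiB.
print(scratch_offsets([((128, 128), 4), ((3,), 4), ((128,), 4)], 128 * 1024))
# [0, 65536, 65552]

With warpgroup semantics each returned offset is what gets passed to `mgpu.dialect.slice_smem`; with lane semantics it is the byte shift used in `memref_dialect.view` on the dynamic-SMEM base.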
@@ -596,7 +609,7 @@ def lower_jaxpr_to_module(
       [_program_id(axis, squashed_dims) for axis in range(len(grid))],
       approx_math,
       mgpu.single_thread_predicate(per_block=False),
-      runtime_smem,
+      smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape),
       smem_used_bytes=0,
       runtime_barriers=grouped_barriers,
       name_stack=source_info_util.NameStack(),
@@ -28,6 +28,7 @@ from jax._src.lib.mlir.dialects import builtin
 from jax._src.lib.mlir.dialects import func
 from jax._src.lib.mlir.dialects import gpu
 from jax._src.lib.mlir.dialects import llvm
+from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import nvvm
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
@@ -598,15 +599,25 @@ def _mgpu_wait_op_lowering_rule(
   return []
 
 
-@_register_lowering(WaitOp)
-def _for_op_lowering_rule(
-    _: LoweringContext, wait_op: scf.ForOp
-) -> Sequence[ir.Value]:
-  barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier)
-  barrier.wait_parity(wait_op.parity)
-  return []
+# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
+SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
+
+
+@_register_lowering(SliceSMEMOp)
+def _mgpu_slice_smem_op_lowering_rule(
+    ctx: LoweringContext, op: SliceSMEMOp
+) -> Sequence[ir.Value]:
+  del ctx
+  i8 = ir.IntegerType.get_signless(8)
+  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+
+  smem_base = gpu.dynamic_shared_memory(
+      ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
+  )
+  offset = arith.index_cast(ir.IndexType.get(), op.offset)
+  return [memref.view(op.result.type, smem_base, offset, [])]
 
 
 @_register_lowering(scf.ForOp)
@@ -372,6 +372,14 @@ MosaicGPU_WGMMALayout :
   let cppNamespace = "::mosaic_gpu";
 }
 
+
+def MosaicGPU_SliceSMEMOp : Op<MosaicGPU_Dialect, "slice_smem", []> {
+  let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address.";
+
+  let arguments = (ins I32:$offset);
+  let results = (outs MemRefOf<[AnyType]>);
+}
+
 def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
   let summary = "Multiply two matrices asyncronously using warpgroup level matrix multiply operations.";
   let description = [{
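To make the op's intended use concrete (again an editorial sketch, not part of the commit): through the generated Python bindings, the op takes the desired SMEM memref type as its result type and an i32 byte offset as its operand. The fragment below assumes an active MLIR context and insertion point, and that `mgpu` is `jax.experimental.mosaic.gpu`, whose `dialect` attribute exposes the op bindings, as in the test hunk that follows.

import jax.experimental.mosaic.gpu as mgpu  # assumed import; `mgpu.dialect` holds the bindings
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith

i32 = ir.IntegerType.get_signless(32)
f32 = ir.F32Type.get()
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")

# A 128x128 f32 scratch buffer that begins 1024 bytes into dynamic SMEM.
scratch_ty = ir.MemRefType.get((128, 128), f32, memory_space=smem)
offset = arith.constant(i32, 1024)
scratch = mgpu.dialect.slice_smem(scratch_ty, offset)

Lowering then rewrites this into `gpu.dynamic_shared_memory` plus `memref.view`, as in the dialect-lowering hunk above.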
@@ -802,6 +802,28 @@ class DialectLoweringTest(MosaicGpuTest):
     reg_vec_ty = ir.VectorType.get((2,), i32)
     self.assertSequenceEqual(result_types, [i32, reg_vec_ty, reg_vec_ty])
 
+  def test_lowering_slice_smem_op(self):
+    shift = 1234
+    offset = None
+
+    def body():
+      nonlocal offset
+      i32 = ir.IntegerType.get_signless(32)
+      offset = arith.constant(i32, shift)
+      mgpu.dialect.slice_smem(i32, offset)
+
+    with ir.InsertionPoint(self.module.body):
+      func.FuncOp.from_py_func()(body)
+
+    mgpu.lower_mgpu_dialect(self.module, None)
+    # Avoid making a change detector, only validate that lowering runs as
+    # expected.
+    self.assertEmpty(
+        find_if(
+            self.module, lambda op: isinstance(op, mgpu.dialect.SliceSMEMOp)
+        )
+    )
+
 
 if __name__ == "__main__":
   parameterized.absltest.main(testLoader=jtu.JaxTestLoader())