[Mosaic GPU] Handle TMEM allocation in the compiler

The code for allocation is uninteresting, but it is the only set of primitives
issued at single-warp granularity (the other primitives we emit, such as TMA,
have single-thread or warpgroup issue granularity).

PiperOrigin-RevId: 725583720
Author: Adam Paszke, 2025-02-11 05:00:50 -08:00 (committed by jax authors)
parent 5a2235163b
commit 0209eee185
5 changed files with 94 additions and 46 deletions
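For context on the issue-granularity remark in the message: TMEM allocation must be issued by exactly one warp, so the compiler wraps it in a warp-0 guard. A minimal sketch of that pattern, using the same utils/tcgen05 helpers the diff below relies on (init_tmem_on_one_warp is a hypothetical name, and the code assumes an active MLIR insertion point inside a kernel body):

from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith, gpu
from jax.experimental.mosaic.gpu import c, utils, tcgen05

def init_tmem_on_one_warp(addr_ref, num_cols):
  i32 = ir.IntegerType.get_signless(32)
  # tmem_alloc has warp issue granularity: exactly one full warp executes it
  # (TMA is issued by a single thread, WGMMA by a whole warpgroup).
  is_init_warp = arith.cmpi(
      arith.CmpIPredicate.eq, utils.warp_idx(sync=False), c(0, i32))
  with utils.when(is_init_warp):
    tcgen05.tmem_alloc(addr_ref, num_cols, exact=False)
    tcgen05.tmem_relinquish_alloc_permit()
  gpu.barrier()  # Make the written TMEM address visible to all threads.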

File: jax/experimental/mosaic/gpu/__init__.py

@@ -16,11 +16,15 @@
from jax import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401
# The imports below shadow the module, so we need to rename it.
from . import wgmma as _wgmma # noqa: F401
from .core import (
Barrier as Barrier,
ClusterBarrier as ClusterBarrier,
TMABarrier as TMABarrier,
ThreadSemantics as ThreadSemantics,
TMEM as TMEM,
Union as Union,
as_gpu_kernel as as_gpu_kernel,
)
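Since TMEM is now re-exported from the package, kernels can declare tensor-memory scratch alongside SMEM buffers and barriers. A hedged sketch of such a declaration (shapes and dtypes are illustrative; the pattern mirrors the test and example updates later in this commit):

import jax
import jax.numpy as jnp
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import tcgen05

# One scratch tree mixing SMEM buffers, barriers, and a TMEM accumulator;
# the compiler now performs the tcgen05 allocation behind the scenes.
scratch_shape = [
    jax.ShapeDtypeStruct((128, 64), jnp.float16),             # SMEM buffer
    mgpu.TMABarrier(1),                                       # mbarrier in SMEM
    mgpu.TMEM((128, 64), jnp.float32, tcgen05.TMEMLayout.D),  # tensor memory
]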
@@ -85,8 +89,6 @@ from .utils import (
warpgroup_idx as warpgroup_idx,
when as when,
)
# The import below shadows the module, so we need to rename it.
from . import wgmma as _wgmma # noqa: F401
from .wgmma import (
WGMMAAccumulator as WGMMAAccumulator,
wgmma as wgmma,

File: jax/experimental/mosaic/gpu/core.py

@@ -18,12 +18,13 @@ import contextlib
import ctypes
import dataclasses
import enum
import functools
import hashlib
import math
import os
import pathlib
import time
from typing import Any, Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar
import weakref
import jax
@@ -31,6 +32,7 @@ from jax._src.interpreters import mlir
from jax._src.lib import mosaic_gpu_dialect as dialect
from jaxlib.mlir import ir
from jaxlib.mlir import passmanager
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import builtin
from jaxlib.mlir.dialects import func
from jaxlib.mlir.dialects import gpu
@@ -49,6 +51,7 @@ else:
from . import profiler
from . import utils
from . import launch_context
from . import tcgen05
# mypy: ignore-errors
@@ -163,6 +166,19 @@ class ClusterBarrier:
collective_dims: Sequence[gpu.Dimension]
num_barriers: int = 1
@dataclasses.dataclass(frozen=True)
class TMEM:
shape: tuple[int, int]
dtype: Any
layout: tcgen05.TMEMLayout
def __post_init__(self):
if self.shape[0] != self.layout.num_rows:
raise ValueError(
f"Row must match layout={self.layout} ({self.layout.num_rows}), but"
f" got {self.shape[0]}"
)
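The check above is easy to trip; a quick illustration (layout D pins the row count to 128, per the tcgen05 hunk further down):

import jax.numpy as jnp
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import tcgen05

mgpu.TMEM((128, 32), jnp.float32, tcgen05.TMEMLayout.D)  # OK: 128 rows
try:
  mgpu.TMEM((64, 32), jnp.float32, tcgen05.TMEMLayout.D)  # wrong row count
except ValueError as e:
  print(e)  # "Row must match layout=TMEMLayout.D (128), but got 64"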
def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize
@@ -179,10 +195,12 @@ def _construct_smem_reftree(
cluster_shape: tuple[int, int, int],
dynamic_smem: ir.Value,
smem_buffers: ShapeTree,
delayed_warp_init: list[Callable[[], None]], # Mutated by this function!
dynamic_smem_offset: int = 0,
) -> RefTree:
) -> Callable[[], RefTree]:
index = ir.IndexType.get()
i8 = ir.IntegerType.get_signless(8)
i32 = ir.IntegerType.get_signless(32)
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@@ -205,13 +223,17 @@ def _construct_smem_reftree(
return barrier_base_ptr
match ref_ty:
case Union(members):
member_trees = [
_construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset)
member_thunks = [
_construct_smem_reftree(
cluster_shape, dynamic_smem, m,
delayed_warp_init, dynamic_smem_offset,
)
for m in members
]
# TODO(apaszke): This is quadratic, but it shouldn't matter for now...
dynamic_smem_offset += _smem_tree_size(ref_ty)
ref = Union(member_trees)
def ref(member_thunks=member_thunks):
return Union([t() for t in member_thunks])
case TMABarrier(num_barriers):
ref = utils.BarrierRef.initialize(
get_barrier_ptr(num_barriers), num_barriers, arrival_count=1
@@ -229,6 +251,20 @@ def _construct_smem_reftree(
collective_dims,
cluster_shape,
)
case TMEM(shape, dtype, layout):
addr_ref = memref.view(
ir.MemRefType.get([], i32, memory_space=smem),
dynamic_smem, c(dynamic_smem_offset, index), [],
)
delayed_warp_init.append(
functools.partial(tcgen05.tmem_alloc, addr_ref, shape[1], exact=False)
)
def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout):
addr = memref.load(addr_ref, [])
return tcgen05.TMEMRef(
addr, layout, shape[1], utils.dtype_to_ir_type(dtype)
)
dynamic_smem_offset += 4 # i32 takes up 4 bytes
case _:
mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype)
tile_smem = memref.view(
@@ -238,7 +274,14 @@
dynamic_smem_offset += _count_buffer_bytes(ref_ty)
ref = tile_smem
smem_refs.append(ref)
return jax.tree.unflatten(smem_buffer_tree, smem_refs)
def ref_tree_thunk():
refs = []
for ref in smem_refs:
if callable(ref):
ref = ref()
refs.append(ref)
return jax.tree.unflatten(smem_buffer_tree, refs)
return ref_tree_thunk
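The function now runs in two phases: the tree walk assigns SMEM offsets and queues warp-level init actions into delayed_warp_init, while the returned thunk builds the actual refs only after those actions have executed. A self-contained Python sketch of the same thunk pattern, stripped of all MLIR (every name here is hypothetical):

from typing import Callable

def build_refs(names: list[str], delayed_init: list[Callable[[], None]]):
  """Phase 1: lay out the tree and queue init actions (mutates delayed_init)."""
  slots = {name: [None] for name in names}  # stand-ins for the i32 SMEM slots
  for name in names:
    # Each queued action later runs under the single-warp guard in _launch.
    delayed_init.append(lambda n=name: slots[n].__setitem__(0, f"tmem-addr({n})"))
  def thunk():
    """Phase 2: materialize refs; only valid once every init action has run."""
    return {name: slots[name][0] for name in names}
  return thunk

delayed: list[Callable[[], None]] = []
refs_thunk = build_refs(["acc"], delayed)
for init in delayed:   # in the compiler: executed by warp 0, then a barrier
  init()
print(refs_thunk())    # {'acc': 'tmem-addr(acc)'}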
def _smem_tree_size(smem_buffers: ShapeTree) -> int:
@@ -258,6 +301,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
if size % utils.MBARRIER_BYTES:
raise NotImplementedError("Misaligned barrier allocation")
size += num_barriers * utils.MBARRIER_BYTES
case TMEM(_):
size += 4 # i32 takes up 4 bytes
case _:
size += _count_buffer_bytes(l)
return size
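Worked example of the accounting above: a TMEM entry costs only the 4-byte i32 slot that will hold the returned address, since tensor memory itself is not carved out of SMEM (the tree below is illustrative):

import math
import numpy as np

buf = math.prod((128, 64)) * np.dtype(np.float16).itemsize  # plain buffer: 16384
barriers = 2 * 8   # TMABarrier(2): MBARRIER_BYTES == 8 per barrier -> 16
tmem_slot = 4      # TMEM entry: one i32 address slot, nothing more
print(buf + barriers + tmem_slot)  # 16404 bytes of dynamic SMEM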
@@ -336,15 +381,25 @@ def _launch(
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof)
with ctx.named_region("Init"):
smem_ref_tree = _construct_smem_reftree(
cluster, dynamic_smem, smem_buffers
delayed_warp_init = []
smem_ref_tree_thunk = _construct_smem_reftree(
cluster, dynamic_smem, smem_buffers, delayed_warp_init
)
# TODO(apaszke): Skip the following if no barriers were initialized.
# TODO(apaszke): Skip fences if no barriers or TMEM is initialized.
# TODO(apaszke): Only initialize cluster barriers before the cluster wait.
nvvm.fence_mbarrier_init()
if math.prod(cluster) != 1:
nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
nvvm.cluster_wait(aligned=ir.UnitAttr.get())
gpu.barrier()
if delayed_warp_init:
eq = arith.CmpIPredicate.eq
is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32))
with utils.when(is_init_warp):
for init in delayed_warp_init:
init()
tcgen05.tmem_relinquish_alloc_permit()
gpu.barrier() # Make sure the init is visible to all threads.
smem_ref_tree = smem_ref_tree_thunk()
yield ctx, smem_ref_tree
if prof is not None:
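Putting the hunk above together, the emitted init sequence orders as follows (a pseudocode sketch of the contract, not code from the commit):

# Init sequence emitted by _launch after this change:
#   construct SMEM views; queue TMEM allocs into delayed_warp_init
#   fence.mbarrier_init; cluster arrive/wait if clustered; gpu.barrier
#   if warp_idx == 0:                      # warp issue granularity
#       for init in delayed_warp_init:     # each is a tcgen05.tmem_alloc(...)
#           init()                         # writes TMEM address to an i32 SMEM slot
#       tcgen05.tmem_relinquish_alloc_permit()
#   gpu.barrier                            # addresses visible to all threads
#   smem_ref_tree = smem_ref_tree_thunk()  # memref.load each slot -> TMEMRef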

File: jax/experimental/mosaic/gpu/examples/matmul_blackwell.py

@@ -21,7 +21,7 @@ from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import nvvm
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import c, ds, utils
from jax.experimental.mosaic.gpu import c, ds
from jax.experimental.mosaic.gpu import tcgen05
import jax.numpy as jnp
import jax.random as jr
@@ -65,7 +65,7 @@ def build_kernel(
tma_tile_kn = 64
def kernel(ctx, a, b, d, smem):
a_smem, b_smem, d_smem, barriers, mma_done_barrier, tmem_addr = smem
a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
(ab_full_barriers, ab_empty_barriers) = barriers
warp_idx = mgpu.warp_idx(sync=True)
@@ -109,18 +109,14 @@
**common_args,
)
with mgpu.when(is_warp(MMA_WARP)):
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
tcgen05.tmem_alloc(tmem_addr_addr, tile_n)
tcgen05.tmem_relinquish_alloc_permit()
tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
with mgpu.when(arith.andi(is_warp(MMA_WARP), warp_leader)):
with mgpu.when(warp_leader):
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
def _mma_body(ki, accumulate):
slot = arith.remui(ki, c(max_concurrent_steps, index))
ab_full_barriers[slot].wait()
tcgen05.mma(
tmem_ref,
acc,
mgpu.memref_slice(a_smem, slot),
mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
a_swizzle=swizzle,
@@ -142,9 +138,9 @@
gpu.barrier()
mma_done_barrier.wait()
tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
tmem_ref[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
mgpu.commit_shared()
# TODO(apaszke): Free up TMEM?
ctx.async_copy(
src_ref=d_smem,
dst_ref=d,
@@ -161,7 +157,7 @@
jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
mgpu.Barrier(arrival_count=1),
jax.ShapeDtypeStruct((1,), np.uint32), # TMEM address
mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D),
)
return mgpu.as_gpu_kernel(
kernel,
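With the accumulator handed to the kernel as a TMEMRef, the epilogue needs no manual address plumbing. A sketch of that contract, assuming the kernel arguments from the diff above (store_acc_to_smem is a hypothetical helper, not part of this commit):

from jaxlib.mlir import ir
from jax.experimental.mosaic import gpu as mgpu

def store_acc_to_smem(acc, d_smem):
  # acc[:] reads the accumulator tile out of TMEM into registers; astype
  # converts f32 -> f16 before the tiled store that stages it in SMEM.
  acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
  mgpu.commit_shared()  # make the SMEM writes visible before ctx.async_copy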

File: jax/experimental/mosaic/gpu/tcgen05.py

@@ -297,6 +297,12 @@ class TMEMLayout(enum.Enum):
"""
D = "D"
@property
def num_rows(self) -> int:
match self:
case TMEMLayout.D:
return 128
@dataclasses.dataclass(frozen=True)
class TMEMRef:
@@ -327,11 +333,7 @@
@property
def num_rows(self):
match self.layout:
case TMEMLayout.D:
return 128
case _:
raise NotImplementedError(self.layout)
return self.layout.num_rows
@property
def shape(self):
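Hoisting num_rows onto the layout lets code query it without a TMEMRef in hand, which is exactly what TMEM.__post_init__ in core.py now does. For example:

from jax.experimental.mosaic.gpu import tcgen05

assert tcgen05.TMEMLayout.D.num_rows == 128  # layout D spans the full 128 TMEM rows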

File: tests/mosaic/gpu_test.py

@@ -968,19 +968,17 @@ class TCGen05Test(TestCase):
in_jax_dtype,
out_jax_dtype,
):
i32 = ir.IntegerType.get_signless(32)
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
raise self.skipTest("Only f16 input is supported for f16 output.")
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
out_mlir_dtype = utils.dtype_to_ir_type(out_jax_dtype)
m_tile = 128
nk_tile = swizzle // bytewidth(in_mlir_dtype)
k = nk_tile * k_steps
assert m % m_tile == 0 and n % nk_tile == 0
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers, tmem_addr_ref = scratch
lhs_smem, rhs_smem, barriers, acc = scratch
lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),)
if lhs_transpose:
assert nk_tile == m_tile # Make sure we didn't have to transpose tiling
@@ -1004,21 +1002,16 @@ class TCGen05Test(TestCase):
)
barriers[0].wait()
barriers[1].wait()
with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, mgpu.warp_idx(), c(0, i32))):
tcgen05.tmem_alloc(tmem_addr_ref, n)
tcgen05.tmem_relinquish_alloc_permit()
acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
with mgpu.single_thread():
if lhs_transpose:
lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
tcgen05.mma(
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
)
tcgen05.commit_arrive(barriers[2])
with mgpu.single_thread():
if lhs_transpose:
lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
tcgen05.mma(
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
)
tcgen05.commit_arrive(barriers[2])
barriers[2].wait()
acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
acc[:].store_untiled(out)
in_finfo = jnp.finfo(in_jax_dtype)
@@ -1036,7 +1029,7 @@ class TCGen05Test(TestCase):
jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
mgpu.TMABarrier(3),
jax.ShapeDtypeStruct((), jnp.int32),
mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D),
]
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape