mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Handle TMEM allocation in the compiler
The code for allocation is uninteresting and it's the only set of primitives that is executed by a single warp (other TMA APIs have single-thread or warpgroup issue granularity). PiperOrigin-RevId: 725583720
This commit is contained in:
parent
5a2235163b
commit
0209eee185
@ -16,11 +16,15 @@
|
||||
from jax import ShapeDtypeStruct as ShapeDtypeStruct
|
||||
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401
|
||||
|
||||
# The imports below shadow the module, so we need to rename it.
|
||||
from . import wgmma as _wgmma # noqa: F401
|
||||
|
||||
from .core import (
|
||||
Barrier as Barrier,
|
||||
ClusterBarrier as ClusterBarrier,
|
||||
TMABarrier as TMABarrier,
|
||||
ThreadSemantics as ThreadSemantics,
|
||||
TMEM as TMEM,
|
||||
Union as Union,
|
||||
as_gpu_kernel as as_gpu_kernel,
|
||||
)
|
||||
@ -85,8 +89,6 @@ from .utils import (
|
||||
warpgroup_idx as warpgroup_idx,
|
||||
when as when,
|
||||
)
|
||||
# The import below shadows the module, so we need to rename it.
|
||||
from . import wgmma as _wgmma # noqa: F401
|
||||
from .wgmma import (
|
||||
WGMMAAccumulator as WGMMAAccumulator,
|
||||
wgmma as wgmma,
|
||||
|
@ -18,12 +18,13 @@ import contextlib
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import hashlib
|
||||
import math
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
import weakref
|
||||
|
||||
import jax
|
||||
@ -31,6 +32,7 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.lib import mosaic_gpu_dialect as dialect
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir import passmanager
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import builtin
|
||||
from jaxlib.mlir.dialects import func
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
@ -49,6 +51,7 @@ else:
|
||||
from . import profiler
|
||||
from . import utils
|
||||
from . import launch_context
|
||||
from . import tcgen05
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
@ -163,6 +166,19 @@ class ClusterBarrier:
|
||||
collective_dims: Sequence[gpu.Dimension]
|
||||
num_barriers: int = 1
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TMEM:
|
||||
shape: tuple[int, int]
|
||||
dtype: Any
|
||||
layout: tcgen05.TMEMLayout
|
||||
|
||||
def __post_init__(self):
|
||||
if self.shape[0] != self.layout.num_rows:
|
||||
raise ValueError(
|
||||
f"Row must match layout={self.layout} ({self.layout.num_rows}), but"
|
||||
f" got {self.shape[0]}"
|
||||
)
|
||||
|
||||
|
||||
def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
|
||||
return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize
|
||||
@ -179,10 +195,12 @@ def _construct_smem_reftree(
|
||||
cluster_shape: tuple[int, int, int],
|
||||
dynamic_smem: ir.Value,
|
||||
smem_buffers: ShapeTree,
|
||||
delayed_warp_init: list[Callable[[], None]], # Mutated by this function!
|
||||
dynamic_smem_offset: int = 0,
|
||||
) -> RefTree:
|
||||
) -> Callable[[], RefTree]:
|
||||
index = ir.IndexType.get()
|
||||
i8 = ir.IntegerType.get_signless(8)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
|
||||
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
|
||||
@ -205,13 +223,17 @@ def _construct_smem_reftree(
|
||||
return barrier_base_ptr
|
||||
match ref_ty:
|
||||
case Union(members):
|
||||
member_trees = [
|
||||
_construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset)
|
||||
member_thunks = [
|
||||
_construct_smem_reftree(
|
||||
cluster_shape, dynamic_smem, m,
|
||||
delayed_warp_init, dynamic_smem_offset,
|
||||
)
|
||||
for m in members
|
||||
]
|
||||
# TODO(apaszke): This is quadratic, but it shouldn't matter for now...
|
||||
dynamic_smem_offset += _smem_tree_size(ref_ty)
|
||||
ref = Union(member_trees)
|
||||
def ref(member_thunks=member_thunks):
|
||||
return Union([t() for t in member_thunks])
|
||||
case TMABarrier(num_barriers):
|
||||
ref = utils.BarrierRef.initialize(
|
||||
get_barrier_ptr(num_barriers), num_barriers, arrival_count=1
|
||||
@ -229,6 +251,20 @@ def _construct_smem_reftree(
|
||||
collective_dims,
|
||||
cluster_shape,
|
||||
)
|
||||
case TMEM(shape, dtype, layout):
|
||||
addr_ref = memref.view(
|
||||
ir.MemRefType.get([], i32, memory_space=smem),
|
||||
dynamic_smem, c(dynamic_smem_offset, index), [],
|
||||
)
|
||||
delayed_warp_init.append(
|
||||
functools.partial(tcgen05.tmem_alloc, addr_ref, shape[1], exact=False)
|
||||
)
|
||||
def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout):
|
||||
addr = memref.load(addr_ref, [])
|
||||
return tcgen05.TMEMRef(
|
||||
addr, layout, shape[1], utils.dtype_to_ir_type(dtype)
|
||||
)
|
||||
dynamic_smem_offset += 4 # i32 takes up 4 bytes
|
||||
case _:
|
||||
mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype)
|
||||
tile_smem = memref.view(
|
||||
@ -238,7 +274,14 @@ def _construct_smem_reftree(
|
||||
dynamic_smem_offset += _count_buffer_bytes(ref_ty)
|
||||
ref = tile_smem
|
||||
smem_refs.append(ref)
|
||||
return jax.tree.unflatten(smem_buffer_tree, smem_refs)
|
||||
def ref_tree_thunk():
|
||||
refs = []
|
||||
for ref in smem_refs:
|
||||
if callable(ref):
|
||||
ref = ref()
|
||||
refs.append(ref)
|
||||
return jax.tree.unflatten(smem_buffer_tree, refs)
|
||||
return ref_tree_thunk
|
||||
|
||||
|
||||
def _smem_tree_size(smem_buffers: ShapeTree) -> int:
|
||||
@ -258,6 +301,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
|
||||
if size % utils.MBARRIER_BYTES:
|
||||
raise NotImplementedError("Misaligned barrier allocation")
|
||||
size += num_barriers * utils.MBARRIER_BYTES
|
||||
case TMEM(_):
|
||||
size += 4 # i32 takes up 4 bytes
|
||||
case _:
|
||||
size += _count_buffer_bytes(l)
|
||||
return size
|
||||
@ -336,15 +381,25 @@ def _launch(
|
||||
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
|
||||
ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof)
|
||||
with ctx.named_region("Init"):
|
||||
smem_ref_tree = _construct_smem_reftree(
|
||||
cluster, dynamic_smem, smem_buffers
|
||||
delayed_warp_init = []
|
||||
smem_ref_tree_thunk = _construct_smem_reftree(
|
||||
cluster, dynamic_smem, smem_buffers, delayed_warp_init
|
||||
)
|
||||
# TODO(apaszke): Skip the following if no barriers were initialized.
|
||||
# TODO(apaszke): Skip fences if no barriers or TMEM is initialized.
|
||||
# TODO(apaszke): Only initialize cluster barriers before the cluster wait.
|
||||
nvvm.fence_mbarrier_init()
|
||||
if math.prod(cluster) != 1:
|
||||
nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
|
||||
nvvm.cluster_wait(aligned=ir.UnitAttr.get())
|
||||
gpu.barrier()
|
||||
if delayed_warp_init:
|
||||
eq = arith.CmpIPredicate.eq
|
||||
is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32))
|
||||
with utils.when(is_init_warp):
|
||||
for init in delayed_warp_init:
|
||||
init()
|
||||
tcgen05.tmem_relinquish_alloc_permit()
|
||||
gpu.barrier() # Make sure the init is visible to all threads.
|
||||
smem_ref_tree = smem_ref_tree_thunk()
|
||||
|
||||
yield ctx, smem_ref_tree
|
||||
if prof is not None:
|
||||
|
@ -21,7 +21,7 @@ from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import nvvm
|
||||
from jax.experimental.mosaic import gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import c, ds, utils
|
||||
from jax.experimental.mosaic.gpu import c, ds
|
||||
from jax.experimental.mosaic.gpu import tcgen05
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jr
|
||||
@ -65,7 +65,7 @@ def build_kernel(
|
||||
tma_tile_kn = 64
|
||||
|
||||
def kernel(ctx, a, b, d, smem):
|
||||
a_smem, b_smem, d_smem, barriers, mma_done_barrier, tmem_addr = smem
|
||||
a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
|
||||
(ab_full_barriers, ab_empty_barriers) = barriers
|
||||
|
||||
warp_idx = mgpu.warp_idx(sync=True)
|
||||
@ -109,18 +109,14 @@ def build_kernel(
|
||||
**common_args,
|
||||
)
|
||||
|
||||
with mgpu.when(is_warp(MMA_WARP)):
|
||||
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
|
||||
tcgen05.tmem_alloc(tmem_addr_addr, tile_n)
|
||||
tcgen05.tmem_relinquish_alloc_permit()
|
||||
tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
|
||||
with mgpu.when(arith.andi(is_warp(MMA_WARP), warp_leader)):
|
||||
with mgpu.when(warp_leader):
|
||||
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
|
||||
def _mma_body(ki, accumulate):
|
||||
slot = arith.remui(ki, c(max_concurrent_steps, index))
|
||||
ab_full_barriers[slot].wait()
|
||||
tcgen05.mma(
|
||||
tmem_ref,
|
||||
acc,
|
||||
mgpu.memref_slice(a_smem, slot),
|
||||
mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
|
||||
a_swizzle=swizzle,
|
||||
@ -142,9 +138,9 @@ def build_kernel(
|
||||
gpu.barrier()
|
||||
mma_done_barrier.wait()
|
||||
|
||||
tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
|
||||
tmem_ref[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
|
||||
acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
|
||||
mgpu.commit_shared()
|
||||
# TODO(apaszke): Free up TMEM?
|
||||
ctx.async_copy(
|
||||
src_ref=d_smem,
|
||||
dst_ref=d,
|
||||
@ -161,7 +157,7 @@ def build_kernel(
|
||||
jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
|
||||
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
|
||||
mgpu.Barrier(arrival_count=1),
|
||||
jax.ShapeDtypeStruct((1,), np.uint32), # TMEM address
|
||||
mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D),
|
||||
)
|
||||
return mgpu.as_gpu_kernel(
|
||||
kernel,
|
||||
|
@ -297,6 +297,12 @@ class TMEMLayout(enum.Enum):
|
||||
"""
|
||||
D = "D"
|
||||
|
||||
@property
|
||||
def num_rows(self) -> int:
|
||||
match self:
|
||||
case TMEMLayout.D:
|
||||
return 128
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TMEMRef:
|
||||
@ -327,11 +333,7 @@ class TMEMRef:
|
||||
|
||||
@property
|
||||
def num_rows(self):
|
||||
match self.layout:
|
||||
case TMEMLayout.D:
|
||||
return 128
|
||||
case _:
|
||||
raise NotImplementedError(self.layout)
|
||||
return self.layout.num_rows
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
@ -968,19 +968,17 @@ class TCGen05Test(TestCase):
|
||||
in_jax_dtype,
|
||||
out_jax_dtype,
|
||||
):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
|
||||
raise self.skipTest("Only f16 input is supported for f16 output.")
|
||||
|
||||
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
|
||||
out_mlir_dtype = utils.dtype_to_ir_type(out_jax_dtype)
|
||||
m_tile = 128
|
||||
nk_tile = swizzle // bytewidth(in_mlir_dtype)
|
||||
k = nk_tile * k_steps
|
||||
assert m % m_tile == 0 and n % nk_tile == 0
|
||||
|
||||
def kernel(ctx, lhs, rhs, out, scratch):
|
||||
lhs_smem, rhs_smem, barriers, tmem_addr_ref = scratch
|
||||
lhs_smem, rhs_smem, barriers, acc = scratch
|
||||
lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),)
|
||||
if lhs_transpose:
|
||||
assert nk_tile == m_tile # Make sure we didn't have to transpose tiling
|
||||
@ -1004,21 +1002,16 @@ class TCGen05Test(TestCase):
|
||||
)
|
||||
barriers[0].wait()
|
||||
barriers[1].wait()
|
||||
with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, mgpu.warp_idx(), c(0, i32))):
|
||||
tcgen05.tmem_alloc(tmem_addr_ref, n)
|
||||
tcgen05.tmem_relinquish_alloc_permit()
|
||||
acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
|
||||
with mgpu.single_thread():
|
||||
if lhs_transpose:
|
||||
lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
|
||||
if rhs_transpose:
|
||||
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
|
||||
tcgen05.mma(
|
||||
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
|
||||
)
|
||||
tcgen05.commit_arrive(barriers[2])
|
||||
with mgpu.single_thread():
|
||||
if lhs_transpose:
|
||||
lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
|
||||
if rhs_transpose:
|
||||
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
|
||||
tcgen05.mma(
|
||||
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
|
||||
)
|
||||
tcgen05.commit_arrive(barriers[2])
|
||||
barriers[2].wait()
|
||||
acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
|
||||
acc[:].store_untiled(out)
|
||||
|
||||
in_finfo = jnp.finfo(in_jax_dtype)
|
||||
@ -1036,7 +1029,7 @@ class TCGen05Test(TestCase):
|
||||
jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
|
||||
jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
|
||||
mgpu.TMABarrier(3),
|
||||
jax.ShapeDtypeStruct((), jnp.int32),
|
||||
mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D),
|
||||
]
|
||||
z = mgpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
|
||||
|
Loading…
x
Reference in New Issue
Block a user