[Mosaic GPU] Define TMEMLayout without referring to the PTX guide

The PTX guide talks about a few layouts by assigning them different
letters, which do not have an obvious meaning. We redefine the layout
by parameterizing it with a 2D tile size which, as far as I can tell,
is sufficient to represent all layouts we care about.

PiperOrigin-RevId: 726833412
This commit is contained in:
Adam Paszke 2025-02-14 02:05:39 -08:00 committed by jax authors
parent a0812cd57e
commit 4a8023fe1e
4 changed files with 127 additions and 58 deletions

View File

@ -170,15 +170,12 @@ class ClusterBarrier:
class TMEM:
shape: tuple[int, int]
dtype: Any
layout: tcgen05.TMEMLayout
layout: tcgen05.TMEMLayout | None = None
collective: bool = False
def __post_init__(self):
if self.shape[0] != self.layout.num_rows:
raise ValueError(
f"Row must match layout={self.layout} ({self.layout.num_rows}), but"
f" got {self.shape[0]}"
)
if self.layout is not None:
self.layout.check_shape(self.shape)
def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
@ -226,15 +223,20 @@ def _construct_smem_reftree(
case Union(members):
member_thunks = [
_construct_smem_reftree(
cluster_shape, dynamic_smem, m,
delayed_warp_init, dynamic_smem_offset,
cluster_shape,
dynamic_smem,
m,
delayed_warp_init,
dynamic_smem_offset,
)
for m in members
]
# TODO(apaszke): This is quadratic, but it shouldn't matter for now...
dynamic_smem_offset += _smem_tree_size(ref_ty)
def ref(member_thunks=member_thunks):
return Union([t() for t in member_thunks])
case TMABarrier(num_barriers):
ref = utils.BarrierRef.initialize(
get_barrier_ptr(num_barriers), num_barriers, arrival_count=1
@ -257,13 +259,19 @@ def _construct_smem_reftree(
ir.MemRefType.get([], i32, memory_space=smem),
dynamic_smem, c(dynamic_smem_offset, index), [],
)
if layout is None:
layout = tcgen05._infer_tmem_layout(shape)
num_cols = layout.cols_in_shape(shape)
delayed_warp_init.append(
functools.partial(tcgen05.tmem_alloc, addr_ref, shape[1], collective=collective, exact=False)
functools.partial(
tcgen05.tmem_alloc,
addr_ref, num_cols, collective=collective, exact=False,
)
)
def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout):
addr = memref.load(addr_ref, [])
return tcgen05.TMEMRef(
addr, layout, shape[1], utils.dtype_to_ir_type(dtype)
addr, shape, utils.dtype_to_ir_type(dtype), layout
)
dynamic_smem_offset += 4 # i32 takes up 4 bytes
case _:

View File

@ -191,11 +191,12 @@ def build_kernel(
mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)),
jnp.float16)
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
assert block_tile_m == 128
smem = (
smem_buffers,
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
mgpu.Barrier(arrival_count=1),
mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D, collective=collective),
mgpu.TMEM((128, tile_n), jnp.float32, collective=collective),
)
return mgpu.as_gpu_kernel(
kernel,
@ -212,27 +213,32 @@ def build_kernel(
def main(unused_argv):
m, k, n = 8192, 4096, 2048
m, k, n = 8192, 4096, 8192
ka, kb = jr.split(jr.key(0), 2)
a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16)
tile_m = tile_n = (128,)
tile_m = (128,)
tile_n = (128, 256, 512)
max_concurrent_steps = (2, 4, 5, 6)
grid_tile_m = (1, 2, 4, 8, 16)
collective = (False, True)
configs = itertools.product(tile_m, tile_n, max_concurrent_steps, grid_tile_m, collective)
names = ("tile_m", "tile_n", "max_concurrent_steps", "grid_tile_m", "collective")
configs = itertools.product(collective, tile_m, tile_n, grid_tile_m, max_concurrent_steps)
names = ("collective", "tile_m", "tile_n", "grid_tile_m", "max_concurrent_steps")
best_runtime = float("inf")
best_kwargs = {}
for config in configs:
kwargs = dict(zip(names, config))
tile_m = kwargs["tile_m"]
tile_n = kwargs["tile_n"]
if kwargs["collective"]:
tile_m *= 2
if m < tile_m or n < kwargs["tile_n"]:
tile_n *= 2
if m < tile_m or n < tile_n:
continue
if kwargs["collective"] and tile_n >= 512:
continue # TODO(apaszke): Support 512
if (m // tile_m) % kwargs["grid_tile_m"]:
continue
try:

View File

@ -16,7 +16,6 @@
from __future__ import annotations
import dataclasses
import enum
import math
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
@ -35,6 +34,7 @@ from .launch_context import LaunchContext
# mypy: ignore-errors
TMEM_ROWS = 128
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
def create_smem_descriptor(
@ -129,8 +129,12 @@ def mma(
)
# The sizes of instruction we'll be using
m_instr_tiling = m_mem_tiling
k_instr_tiling = kn_mem_tiling
if (m_instr_tiling := d.layout.elements_in_tile[0]) != m_mem_tiling:
raise ValueError(
f"A's row tiling must be equal to {m_instr_tiling} (inferred from"
f" accumulator's TMEM layout), got: {m_mem_tiling}"
)
if n * num_cta <= 256:
n_instr_tiling = n
elif n * num_cta == 512:
@ -140,14 +144,6 @@ def mma(
else:
raise NotImplementedError("The only supported N larger than 256 is 512")
# TODO(apaszke): It's enough to make this a multiple of d.num_rows, but it
# would need more code below.
if m_instr_tiling != d.num_rows:
raise ValueError(
f"A's row tiling must be a multiple of {d.num_rows} (inferred from"
f" accumulator's TMEM layout), got: {m_instr_tiling}"
)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)
b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
@ -336,29 +332,86 @@ def tmem_load(tmem_addr, shape, num):
return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
class TMEMLayout(enum.Enum):
"""Layout of the array in TMEM.
@dataclasses.dataclass(frozen=True)
class TMEMLayout:
"""Represents the way a shape is laid out in TMEM.
See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-organization
Only 2D shapes are supported. Row tiling must be between 32 and 128, and be
a power of 2. If the row tiling is smaller than 128 (the row count in TMEM),
the tiles are linearized in row-major order, but laid out in TMEM in a
column-major order.
Consider an array that is (128, 128) and we apply tiling of (64, 64):
+------------------+------------------+
| [0:64, 0:64] | [0:64, 64:128] |
+------------------+------------------+
| [64:128, 0:64] | [64:128, 64:128] |
+------------------+------------------+
In TMEM it will be laid out as follows:
+------------------+------------------+
| [0:64, 0:64] | [64:128, 0:64] |
+------------------+------------------+
| [0:64, 64:128] | [64:128, 64:128] |
+------------------+------------------+
"""
D = "D"
elements_in_tile: tuple[int, int]
@property
def num_rows(self) -> int:
match self:
case TMEMLayout.D:
return 128
def __post_init__(self):
row_tiling = self.elements_in_tile[0]
if not 32 <= row_tiling <= 128:
raise ValueError(
f"Row tiling must be between 32 and 128, got: {row_tiling}"
)
if row_tiling.bit_count() != 1:
raise ValueError(f"Row tiling must be a power of 2, got: {row_tiling}")
def check_shape(self, shape: tuple[int, ...]):
if len(shape) != 2:
raise ValueError(f"TMEM can only represent 2D shapes, got {shape}")
if any(s % t for s, t in zip(shape, self.elements_in_tile)):
raise ValueError(
f"{shape} is divisible into tiles of shape {self.elements_in_tile}"
)
def cols_in_shape(self, shape: tuple[int, int]):
cols_in_tile = self.elements_in_tile[1]
tiles_in_row = TMEM_ROWS // self.elements_in_tile[0]
num_tiles = math.prod(utils.tile_shape(shape, self.elements_in_tile)[:-2])
assert num_tiles % tiles_in_row == 0
return num_tiles // tiles_in_row * cols_in_tile
def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
if shape[0] > TMEM_ROWS:
raise ValueError(
"Can only infer TMEM layout for shapes with at most 128 rows, got:"
f" {shape[0]}"
)
if shape[0] < 32:
raise ValueError(
"Can only infer TMEM layout for shapes with at least 32 rows, got:"
f" {shape[0]}"
)
if shape[0].bit_count() != 1:
raise ValueError(
"Can only infer TMEM layout for shapes with row count that's a power of"
f" 2, got: {shape[0]}"
)
return TMEMLayout(elements_in_tile=(shape[0], 1))
@dataclasses.dataclass(frozen=True)
class TMEMRef:
address: ir.Value
layout: TMEMLayout
num_cols: int
shape: tuple[int, int]
dtype: ir.Type
layout: TMEMLayout
@classmethod
def from_alloc(cls, tmem_addr_ref: ir.Value, layout: TMEMLayout, num_cols: int, dtype: ir.Type):
def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None):
i32 = ir.IntegerType.get_signless(32)
if not ir.MemRefType.isinstance(tmem_addr_ref.type):
raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
@ -372,34 +425,34 @@ class TMEMRef:
raise ValueError(f"tmem_addr_ref must contain a single element, got: {addr_ref_ty}")
i0 = arith.ConstantOp.create_index(0)
tmem_addr = memref.load(tmem_addr_ref, [i0] * addr_ref_ty.rank)
if shape[0] < 32:
raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
if layout is None:
layout = _infer_tmem_layout(shape)
else:
layout.check_shape(shape)
# TODO: Do we have to do this??
# warp_idx = utils.warp_idx(sync=False)
# tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32)))
return cls(tmem_addr, layout, num_cols, dtype)
@property
def num_rows(self):
return self.layout.num_rows
@property
def shape(self):
return (self.num_rows, self.num_cols)
return cls(tmem_addr, shape, dtype, layout)
def slice(self, *idxs):
base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
if self.layout != TMEMLayout.D:
raise NotImplementedError(self.layout)
if any(is_squeezed):
raise ValueError("TMEM can only be sliced, not indexed")
if base_idx[0] != 0 or slice_shape[0] != self.num_rows:
if self.layout.elements_in_tile[0] != TMEM_ROWS:
raise NotImplementedError(
f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows"
)
if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
raise NotImplementedError("TMEM cannot be sliced along rows")
col_idx = base_idx[1]
if not isinstance(col_idx, ir.Value):
col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx)
return TMEMRef(
address=arith.addi(self.address, col_idx),
shape=tuple(slice_shape),
layout=self.layout,
num_cols=slice_shape[1],
dtype=self.dtype,
)
@ -410,22 +463,24 @@ class TMEMRef:
raise ValueError("TMEM loads only support slicing")
if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
raise NotImplementedError("Slicing of TMEM not impelmented yet")
if self.layout != TMEMLayout.D:
raise NotImplementedError(self.layout)
if self.num_cols % 8:
if self.layout.elements_in_tile[0] != TMEM_ROWS:
raise NotImplementedError(
f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows"
)
if self.shape[1] % 8:
raise NotImplementedError
if self.dtype != ir.F32Type.get():
raise NotImplementedError(self.dtype)
layout = _m128_256bit_32bit_layout(self.shape)
regs_shape = layout.registers_shape(self.shape)
num = self.num_cols // 8
num = self.shape[1] // 8
# TODO(apaszke): Make the tiling configurable through the args too.
if num <= 32:
num_tiling = num
elif num == 64:
num_tiling = 32
else:
raise NotImplementedError(f"num_cols={self.num_cols} is unsupported")
raise NotImplementedError(num)
registers = np.empty(regs_shape, dtype=object)
# We load 16 lanes at a time, but need 32 in total.
for row_group in range(2):

View File

@ -977,7 +977,7 @@ class TCGen05Test(TestCase):
jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
mgpu.TMABarrier(3),
mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D),
mgpu.TMEM((128, n), out_jax_dtype),
]
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
@ -1080,7 +1080,7 @@ class TCGen05Test(TestCase):
jax.ShapeDtypeStruct(tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape((k, n_block_tile), (nk_tma_tile, nk_tma_tile)), in_jax_dtype),
mgpu.TMABarrier(3),
mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D, collective=True),
mgpu.TMEM((128, n), out_jax_dtype, collective=True),
]
z = mgpu.as_gpu_kernel(
kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1)