mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic GPU] Define TMEMLayout without referring to the PTX guide
The PTX guide talks about a few layouts by assigning them different letters, which do not have an obvious meaning. We redefine the layout by parameterizing it with a 2D tile size which, as far as I can tell, is sufficient to represent all layouts we care about. PiperOrigin-RevId: 726833412
This commit is contained in:
parent
a0812cd57e
commit
4a8023fe1e
@ -170,15 +170,12 @@ class ClusterBarrier:
|
||||
class TMEM:
|
||||
shape: tuple[int, int]
|
||||
dtype: Any
|
||||
layout: tcgen05.TMEMLayout
|
||||
layout: tcgen05.TMEMLayout | None = None
|
||||
collective: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.shape[0] != self.layout.num_rows:
|
||||
raise ValueError(
|
||||
f"Row must match layout={self.layout} ({self.layout.num_rows}), but"
|
||||
f" got {self.shape[0]}"
|
||||
)
|
||||
if self.layout is not None:
|
||||
self.layout.check_shape(self.shape)
|
||||
|
||||
|
||||
def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
|
||||
@ -226,15 +223,20 @@ def _construct_smem_reftree(
|
||||
case Union(members):
|
||||
member_thunks = [
|
||||
_construct_smem_reftree(
|
||||
cluster_shape, dynamic_smem, m,
|
||||
delayed_warp_init, dynamic_smem_offset,
|
||||
cluster_shape,
|
||||
dynamic_smem,
|
||||
m,
|
||||
delayed_warp_init,
|
||||
dynamic_smem_offset,
|
||||
)
|
||||
for m in members
|
||||
]
|
||||
# TODO(apaszke): This is quadratic, but it shouldn't matter for now...
|
||||
dynamic_smem_offset += _smem_tree_size(ref_ty)
|
||||
|
||||
def ref(member_thunks=member_thunks):
|
||||
return Union([t() for t in member_thunks])
|
||||
|
||||
case TMABarrier(num_barriers):
|
||||
ref = utils.BarrierRef.initialize(
|
||||
get_barrier_ptr(num_barriers), num_barriers, arrival_count=1
|
||||
@ -257,13 +259,19 @@ def _construct_smem_reftree(
|
||||
ir.MemRefType.get([], i32, memory_space=smem),
|
||||
dynamic_smem, c(dynamic_smem_offset, index), [],
|
||||
)
|
||||
if layout is None:
|
||||
layout = tcgen05._infer_tmem_layout(shape)
|
||||
num_cols = layout.cols_in_shape(shape)
|
||||
delayed_warp_init.append(
|
||||
functools.partial(tcgen05.tmem_alloc, addr_ref, shape[1], collective=collective, exact=False)
|
||||
functools.partial(
|
||||
tcgen05.tmem_alloc,
|
||||
addr_ref, num_cols, collective=collective, exact=False,
|
||||
)
|
||||
)
|
||||
def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout):
|
||||
addr = memref.load(addr_ref, [])
|
||||
return tcgen05.TMEMRef(
|
||||
addr, layout, shape[1], utils.dtype_to_ir_type(dtype)
|
||||
addr, shape, utils.dtype_to_ir_type(dtype), layout
|
||||
)
|
||||
dynamic_smem_offset += 4 # i32 takes up 4 bytes
|
||||
case _:
|
||||
|
@ -191,11 +191,12 @@ def build_kernel(
|
||||
mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)),
|
||||
jnp.float16)
|
||||
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
|
||||
assert block_tile_m == 128
|
||||
smem = (
|
||||
smem_buffers,
|
||||
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
|
||||
mgpu.Barrier(arrival_count=1),
|
||||
mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D, collective=collective),
|
||||
mgpu.TMEM((128, tile_n), jnp.float32, collective=collective),
|
||||
)
|
||||
return mgpu.as_gpu_kernel(
|
||||
kernel,
|
||||
@ -212,27 +213,32 @@ def build_kernel(
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
m, k, n = 8192, 4096, 2048
|
||||
m, k, n = 8192, 4096, 8192
|
||||
|
||||
ka, kb = jr.split(jr.key(0), 2)
|
||||
a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
|
||||
b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16)
|
||||
|
||||
tile_m = tile_n = (128,)
|
||||
tile_m = (128,)
|
||||
tile_n = (128, 256, 512)
|
||||
max_concurrent_steps = (2, 4, 5, 6)
|
||||
grid_tile_m = (1, 2, 4, 8, 16)
|
||||
collective = (False, True)
|
||||
configs = itertools.product(tile_m, tile_n, max_concurrent_steps, grid_tile_m, collective)
|
||||
names = ("tile_m", "tile_n", "max_concurrent_steps", "grid_tile_m", "collective")
|
||||
configs = itertools.product(collective, tile_m, tile_n, grid_tile_m, max_concurrent_steps)
|
||||
names = ("collective", "tile_m", "tile_n", "grid_tile_m", "max_concurrent_steps")
|
||||
best_runtime = float("inf")
|
||||
best_kwargs = {}
|
||||
for config in configs:
|
||||
kwargs = dict(zip(names, config))
|
||||
tile_m = kwargs["tile_m"]
|
||||
tile_n = kwargs["tile_n"]
|
||||
if kwargs["collective"]:
|
||||
tile_m *= 2
|
||||
if m < tile_m or n < kwargs["tile_n"]:
|
||||
tile_n *= 2
|
||||
if m < tile_m or n < tile_n:
|
||||
continue
|
||||
if kwargs["collective"] and tile_n >= 512:
|
||||
continue # TODO(apaszke): Support 512
|
||||
if (m // tile_m) % kwargs["grid_tile_m"]:
|
||||
continue
|
||||
try:
|
||||
|
@ -16,7 +16,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import math
|
||||
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
|
||||
@ -35,6 +34,7 @@ from .launch_context import LaunchContext
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
TMEM_ROWS = 128
|
||||
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
|
||||
|
||||
def create_smem_descriptor(
|
||||
@ -129,8 +129,12 @@ def mma(
|
||||
)
|
||||
|
||||
# The sizes of instruction we'll be using
|
||||
m_instr_tiling = m_mem_tiling
|
||||
k_instr_tiling = kn_mem_tiling
|
||||
if (m_instr_tiling := d.layout.elements_in_tile[0]) != m_mem_tiling:
|
||||
raise ValueError(
|
||||
f"A's row tiling must be equal to {m_instr_tiling} (inferred from"
|
||||
f" accumulator's TMEM layout), got: {m_mem_tiling}"
|
||||
)
|
||||
if n * num_cta <= 256:
|
||||
n_instr_tiling = n
|
||||
elif n * num_cta == 512:
|
||||
@ -140,14 +144,6 @@ def mma(
|
||||
else:
|
||||
raise NotImplementedError("The only supported N larger than 256 is 512")
|
||||
|
||||
# TODO(apaszke): It's enough to make this a multiple of d.num_rows, but it
|
||||
# would need more code below.
|
||||
if m_instr_tiling != d.num_rows:
|
||||
raise ValueError(
|
||||
f"A's row tiling must be a multiple of {d.num_rows} (inferred from"
|
||||
f" accumulator's TMEM layout), got: {m_instr_tiling}"
|
||||
)
|
||||
|
||||
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
|
||||
a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)
|
||||
b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
|
||||
@ -336,29 +332,86 @@ def tmem_load(tmem_addr, shape, num):
|
||||
return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
|
||||
|
||||
|
||||
class TMEMLayout(enum.Enum):
|
||||
"""Layout of the array in TMEM.
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TMEMLayout:
|
||||
"""Represents the way a shape is laid out in TMEM.
|
||||
|
||||
See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-organization
|
||||
Only 2D shapes are supported. Row tiling must be between 32 and 128, and be
|
||||
a power of 2. If the row tiling is smaller than 128 (the row count in TMEM),
|
||||
the tiles are linearized in row-major order, but laid out in TMEM in a
|
||||
column-major order.
|
||||
|
||||
Consider an array that is (128, 128) and we apply tiling of (64, 64):
|
||||
|
||||
+------------------+------------------+
|
||||
| [0:64, 0:64] | [0:64, 64:128] |
|
||||
+------------------+------------------+
|
||||
| [64:128, 0:64] | [64:128, 64:128] |
|
||||
+------------------+------------------+
|
||||
|
||||
In TMEM it will be laid out as follows:
|
||||
|
||||
+------------------+------------------+
|
||||
| [0:64, 0:64] | [64:128, 0:64] |
|
||||
+------------------+------------------+
|
||||
| [0:64, 64:128] | [64:128, 64:128] |
|
||||
+------------------+------------------+
|
||||
"""
|
||||
D = "D"
|
||||
elements_in_tile: tuple[int, int]
|
||||
|
||||
@property
|
||||
def num_rows(self) -> int:
|
||||
match self:
|
||||
case TMEMLayout.D:
|
||||
return 128
|
||||
def __post_init__(self):
|
||||
row_tiling = self.elements_in_tile[0]
|
||||
if not 32 <= row_tiling <= 128:
|
||||
raise ValueError(
|
||||
f"Row tiling must be between 32 and 128, got: {row_tiling}"
|
||||
)
|
||||
if row_tiling.bit_count() != 1:
|
||||
raise ValueError(f"Row tiling must be a power of 2, got: {row_tiling}")
|
||||
|
||||
def check_shape(self, shape: tuple[int, ...]):
|
||||
if len(shape) != 2:
|
||||
raise ValueError(f"TMEM can only represent 2D shapes, got {shape}")
|
||||
if any(s % t for s, t in zip(shape, self.elements_in_tile)):
|
||||
raise ValueError(
|
||||
f"{shape} is divisible into tiles of shape {self.elements_in_tile}"
|
||||
)
|
||||
|
||||
def cols_in_shape(self, shape: tuple[int, int]):
|
||||
cols_in_tile = self.elements_in_tile[1]
|
||||
tiles_in_row = TMEM_ROWS // self.elements_in_tile[0]
|
||||
num_tiles = math.prod(utils.tile_shape(shape, self.elements_in_tile)[:-2])
|
||||
assert num_tiles % tiles_in_row == 0
|
||||
return num_tiles // tiles_in_row * cols_in_tile
|
||||
|
||||
|
||||
def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
|
||||
if shape[0] > TMEM_ROWS:
|
||||
raise ValueError(
|
||||
"Can only infer TMEM layout for shapes with at most 128 rows, got:"
|
||||
f" {shape[0]}"
|
||||
)
|
||||
if shape[0] < 32:
|
||||
raise ValueError(
|
||||
"Can only infer TMEM layout for shapes with at least 32 rows, got:"
|
||||
f" {shape[0]}"
|
||||
)
|
||||
if shape[0].bit_count() != 1:
|
||||
raise ValueError(
|
||||
"Can only infer TMEM layout for shapes with row count that's a power of"
|
||||
f" 2, got: {shape[0]}"
|
||||
)
|
||||
return TMEMLayout(elements_in_tile=(shape[0], 1))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TMEMRef:
|
||||
address: ir.Value
|
||||
layout: TMEMLayout
|
||||
num_cols: int
|
||||
shape: tuple[int, int]
|
||||
dtype: ir.Type
|
||||
layout: TMEMLayout
|
||||
|
||||
@classmethod
|
||||
def from_alloc(cls, tmem_addr_ref: ir.Value, layout: TMEMLayout, num_cols: int, dtype: ir.Type):
|
||||
def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
if not ir.MemRefType.isinstance(tmem_addr_ref.type):
|
||||
raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
|
||||
@ -372,34 +425,34 @@ class TMEMRef:
|
||||
raise ValueError(f"tmem_addr_ref must contain a single element, got: {addr_ref_ty}")
|
||||
i0 = arith.ConstantOp.create_index(0)
|
||||
tmem_addr = memref.load(tmem_addr_ref, [i0] * addr_ref_ty.rank)
|
||||
if shape[0] < 32:
|
||||
raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
|
||||
if layout is None:
|
||||
layout = _infer_tmem_layout(shape)
|
||||
else:
|
||||
layout.check_shape(shape)
|
||||
# TODO: Do we have to do this??
|
||||
# warp_idx = utils.warp_idx(sync=False)
|
||||
# tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32)))
|
||||
return cls(tmem_addr, layout, num_cols, dtype)
|
||||
|
||||
@property
|
||||
def num_rows(self):
|
||||
return self.layout.num_rows
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return (self.num_rows, self.num_cols)
|
||||
return cls(tmem_addr, shape, dtype, layout)
|
||||
|
||||
def slice(self, *idxs):
|
||||
base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
|
||||
if self.layout != TMEMLayout.D:
|
||||
raise NotImplementedError(self.layout)
|
||||
if any(is_squeezed):
|
||||
raise ValueError("TMEM can only be sliced, not indexed")
|
||||
if base_idx[0] != 0 or slice_shape[0] != self.num_rows:
|
||||
if self.layout.elements_in_tile[0] != TMEM_ROWS:
|
||||
raise NotImplementedError(
|
||||
f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows"
|
||||
)
|
||||
if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
|
||||
raise NotImplementedError("TMEM cannot be sliced along rows")
|
||||
col_idx = base_idx[1]
|
||||
if not isinstance(col_idx, ir.Value):
|
||||
col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx)
|
||||
return TMEMRef(
|
||||
address=arith.addi(self.address, col_idx),
|
||||
shape=tuple(slice_shape),
|
||||
layout=self.layout,
|
||||
num_cols=slice_shape[1],
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@ -410,22 +463,24 @@ class TMEMRef:
|
||||
raise ValueError("TMEM loads only support slicing")
|
||||
if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
|
||||
raise NotImplementedError("Slicing of TMEM not impelmented yet")
|
||||
if self.layout != TMEMLayout.D:
|
||||
raise NotImplementedError(self.layout)
|
||||
if self.num_cols % 8:
|
||||
if self.layout.elements_in_tile[0] != TMEM_ROWS:
|
||||
raise NotImplementedError(
|
||||
f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows"
|
||||
)
|
||||
if self.shape[1] % 8:
|
||||
raise NotImplementedError
|
||||
if self.dtype != ir.F32Type.get():
|
||||
raise NotImplementedError(self.dtype)
|
||||
layout = _m128_256bit_32bit_layout(self.shape)
|
||||
regs_shape = layout.registers_shape(self.shape)
|
||||
num = self.num_cols // 8
|
||||
num = self.shape[1] // 8
|
||||
# TODO(apaszke): Make the tiling configurable through the args too.
|
||||
if num <= 32:
|
||||
num_tiling = num
|
||||
elif num == 64:
|
||||
num_tiling = 32
|
||||
else:
|
||||
raise NotImplementedError(f"num_cols={self.num_cols} is unsupported")
|
||||
raise NotImplementedError(num)
|
||||
registers = np.empty(regs_shape, dtype=object)
|
||||
# We load 16 lanes at a time, but need 32 in total.
|
||||
for row_group in range(2):
|
||||
|
@ -977,7 +977,7 @@ class TCGen05Test(TestCase):
|
||||
jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
|
||||
jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
|
||||
mgpu.TMABarrier(3),
|
||||
mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D),
|
||||
mgpu.TMEM((128, n), out_jax_dtype),
|
||||
]
|
||||
z = mgpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
|
||||
@ -1080,7 +1080,7 @@ class TCGen05Test(TestCase):
|
||||
jax.ShapeDtypeStruct(tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)), in_jax_dtype),
|
||||
jax.ShapeDtypeStruct(tile_shape((k, n_block_tile), (nk_tma_tile, nk_tma_tile)), in_jax_dtype),
|
||||
mgpu.TMABarrier(3),
|
||||
mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D, collective=True),
|
||||
mgpu.TMEM((128, n), out_jax_dtype, collective=True),
|
||||
]
|
||||
z = mgpu.as_gpu_kernel(
|
||||
kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user