[Mosaic GPU] Introduce a more flexible layout system

So far all of our layouts have been tailored to the limited set of use
cases we've encountered, but they're still not general enough to
handle all of the register layouts needed for WGMMA or mixed-precision
matmuls (including intermediate steps during conversions). Instead of adding
more special cases, I decided to adopt XLA tiled layouts, and they do seem
to work quite well!

This change only lays the groundwork for the new layout system. Future
changes will build upon it to add new features and eventually replace
`WGMMA_LAYOUT` altogether.

PiperOrigin-RevId: 694105514
This commit is contained in:
Adam Paszke 2024-11-07 07:08:07 -08:00 committed by jax authors
parent f8dba3c8a4
commit de06584d98
3 changed files with 361 additions and 37 deletions

View File

@ -14,10 +14,12 @@
# ==============================================================================
"""Utilities for code generator."""
from __future__ import annotations
import dataclasses
import functools
import math
from typing import Callable
from typing import Sequence, TypeVar, Iterable
import jax
from jaxlib.mlir import ir
@ -35,10 +37,276 @@ from . import utils
# mypy: ignore-errors
T = TypeVar("T")
WARPGROUP_SIZE = utils.WARPGROUP_SIZE
WARP_SIZE = 32
WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE
c = utils.c
@dataclasses.dataclass(frozen=True)
class Tiling:
  """A tiling expression describing a permutation of elements of an nd-array.

  To apply one level of tiling to an array, each of the trailing dimensions (up
  to the rank of the tile) is unfolded into two dimensions: first equal to the
  ratio of the dimension size and the tile size, and second equal to the tile
  size. Then, all newly unfolded minor dimensions are transposed to appear at
  the end.

  This expression describes multi-level tiling, by applying each element of
  `tiles` in sequence to the array.

  See https://openxla.org/xla/tiled_layout for a more detailed explanation.
  """
  tiles: tuple[tuple[int, ...], ...]

  def __post_init__(self):
    # Every tile must be non-empty, no wider than its predecessor, and must
    # only contain positive sizes.
    prev_rank = math.inf
    for tile in self.tiles:
      if not tile:
        raise ValueError("Tiles must not be empty")
      if len(tile) > prev_rank:
        raise ValueError("Tile ranks must be non-increasing")
      prev_rank = len(tile)
      if min(tile) <= 0:
        raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}")

  def __str__(self):
    return "Tiling(" + "".join(str(tile) for tile in self.tiles) + ")"

  def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Computes the shape of an array after tiling."""
    for tile in self.tiles:
      rank = len(tile)
      # Each tile applies to a suffix of the current shape, and every tiled
      # dimension must be divisible by the corresponding tile size.
      if rank > len(shape) or any(s % t for s, t in zip(shape[-rank:], tile)):
        raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
      counts = tuple(s // t for s, t in zip(shape[-rank:], tile))
      shape = shape[:-rank] + counts + tile
    return shape

  def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Computes the shape of an array before tiling from its tiled shape."""
    for tile in reversed(self.tiles):
      rank = len(tile)
      # A tiled shape must literally end with the tile itself.
      if rank > len(shape) or shape[-rank:] != tile:
        raise ValueError("Shape does not look like it's been tiled?")
      merged = tuple(c * t for c, t in zip(shape[-2 * rank:-rank], tile))
      shape = shape[:-2 * rank] + merged
    return shape

  def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]:
    """Computes the strides of an array after tiling."""
    for tile in self.tiles:
      rank = len(tile)
      # The "count" dimensions stride over whole tiles, so their strides are
      # the original ones scaled by the tile sizes.
      scaled = tuple(s * t for s, t in zip(strides[-rank:], tile))
      strides = strides[:-rank] + scaled + strides[-rank:]
    return strides
def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]:
  """Like built-in enumerate, but returns negative indices into the sequence."""
  # Pair each element with its index counted from the end: -len, ..., -1.
  return zip(range(-len(elems), 0), elems)
@dataclasses.dataclass(frozen=True)
class TiledLayout:
  """A FragmentedArray layout derived from a tiling expression.

  A logical array is transformed according to the tiling expression, and then
  split across warps (within a warpgroup), lanes, and vectorized according to
  the dimension indices. All dimension indices must be negative and should refer
  to the dimensions after tiling is applied.

  Note that warp_dim and vector_dim could be sets as well, but we don't have a
  usecase for that yet.

  To better understand this layout, consider the example of WGMMA-related tiling
  from https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as
  applied to a 128x128 array. The corresponding TiledLayout has a tiling of:

      (64, 8)(16, 8)(8, 8)(1, 2)

  and warp_dim=-8, lane_dims={-4, -3}, vector_dim=-1.

  We begin by applying the tiling (note that it always applies to a suffix):

        Tiled shape                       Remaining tiling actions
    ===========================================================================
    128 128                               (64, 8)(16, 8)(8, 8)(1, 2)
    2 16 64 8                             (16, 8)(8, 8)(1, 2)
    2 16 4 1 16 8                         (8, 8)(1, 2)
    2 16 4 1 2 1 8 8                      (1, 2)
    2 16 4 1 2 1 8 4 1 2

  The last expression is our final shape. At this stage, we're ready to
  interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the
  end is partitioned over 4 warps in a warpgroup (and so it must be of size 4).
  lane_dims={-4, -3} indicate that those two dimensions are partitioned over
  the lanes within a warp (their product must be equal to 32, i.e. warp size).
  Finally, vector_dim=-1 indicates that each (logical) register is a vector
  containing 2 elements (there are no shape restrictions here).

  Given the above, the shape of the (logical) register array used to represent
  the array in each thread is: (2, 16, 1, 1, 2, 1, 1, 1, 1, 1). We have set all
  the dimensions above to 1, since each thread is a member of a single warp,
  a single lane, and the elements along the vectorized dimension are represented
  by a single (logical) register.
  """
  tiling: Tiling
  warp_dim: int  # Negative index of the tiled dim partitioned over warps.
  lane_dims: frozenset[int]  # Negative indices of dims partitioned over lanes.
  vector_dim: int  # Negative index of the dim held in a single vector register.

  def __post_init__(self):
    if not self.tiling.tiles:
      raise ValueError("Tiling must have at least one tile")
    # Applying the tiling to its own first tile produces the shape suffix that
    # warp_dim/lane_dims/vector_dim index into (see tiled_tiling_shape).
    min_shape = self.tiling.tiles[0]
    min_tiled_shape = self.tiling.tile_shape(min_shape)
    dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim}
    if len(dims_set) != len(self.lane_dims) + 2:
      raise ValueError("warp_dim, lane_dims and vector_dim must all be distinct")
    for d in dims_set:
      if d >= 0:
        raise ValueError("All dimensions must be negative")
      # Only dimensions introduced by the tiling (not the leading, untiled
      # ones) may be partitioned or vectorized.
      if d < -(len(min_tiled_shape) - len(min_shape)):
        raise ValueError("Dimension out of range")
    if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP:
      raise ValueError(
          f"warp_dim must have size {WARPS_IN_WARPGROUP}, got:"
          f" {min_tiled_shape[self.warp_dim]}"
      )
    if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE:
      raise ValueError(
          f"The product of lane_dims dimension sizes must equal {WARP_SIZE}"
      )

  @functools.cached_property
  def tiled_tiling_shape(self) -> tuple[int, ...]:
    """The shape of the suffix of the array after tiling.

    We only allow our repeated tiling actions to further subdivide the
    dimensions created by previous tiling actions (except for the first one),
    so the tiled shape always ends with this suffix, no matter what array shape
    it's applied to.
    """
    return self.tiling.tile_shape(self.tiling.tiles[0])

  @property
  def vector_length(self) -> int:
    # Number of elements represented by one (logical) vector register.
    return self.tiled_tiling_shape[self.vector_dim]

  def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Returns the shape of the register array needed to represent an array of the given logical shape."""
    # Dimensions distributed over warps/lanes or packed into a vector collapse
    # to 1 in the per-thread register array.
    tiled_shape = list(self.tiling.tile_shape(shape))
    tiled_shape[self.warp_dim] = 1
    for d in self.lane_dims:
      tiled_shape[d] = 1
    tiled_shape[self.vector_dim] = 1
    return tuple(tiled_shape)

  def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Returns the logical shape of an array given its register array shape.

    Inverse to `registers_shape`.
    """
    # Restore the sizes collapsed to 1 by registers_shape before untiling.
    tiled_tiling = self.tiled_tiling_shape
    shape = list(shape)
    shape[self.warp_dim] = WARPS_IN_WARPGROUP
    for d in self.lane_dims:
      shape[d] = tiled_tiling[d]
    shape[self.vector_dim] = tiled_tiling[self.vector_dim]
    return self.tiling.untile_shape(tuple(shape))

  def lane_indices(self) -> tuple[ir.Value, ...]:
    """Returns the current lane's index along each tiled dimension.

    Entries for dimensions not in `lane_dims` are always 0.
    """
    i32 = ir.IntegerType.get_signless(32)
    # Collapse all non-lane dimensions to 1; the remaining shape has exactly
    # WARP_SIZE elements, so the lane id can be delinearized over it.
    tiled_shape = tuple(
        d if i in self.lane_dims else 1
        for i, d in enumerate_negative(self.tiled_tiling_shape)
    )
    assert math.prod(tiled_shape) == WARP_SIZE
    lane_strides = utils.get_contiguous_strides(tiled_shape)
    lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32))
    # Row-major delinearization of lane_idx into tiled_shape.
    return tuple(
        arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32))
        for stride, size in zip(lane_strides, tiled_shape)
    )

  def warp_indices(self) -> tuple[ir.Value, ...]:
    """Returns the current warp's index along each tiled dimension.

    Entries for dimensions other than `warp_dim` are always 0.
    """
    i32 = ir.IntegerType.get_signless(32)
    tiled_shape = tuple(
        d if i == self.warp_dim else 1
        for i, d in enumerate_negative(self.tiled_tiling_shape)
    )
    assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP
    # Warp id within the warpgroup: (thread_idx // 32) % 4.
    warp_idx = arith.remui(
        arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)),
        c(WARPS_IN_WARPGROUP, i32),
    )
    indices = [arith.constant(i32, 0)] * len(tiled_shape)
    indices[self.warp_dim] = warp_idx
    return tuple(indices)
def _tiled_wgmma_layout(shape: tuple[int, ...]):
  """Returns the tiled layout relevant for WGMMA operations.

  The tiled layout is equivalent to one described here in PTX documentation:
  https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d

  This tiled layout is equivalent to WGMMAFragLayout and will subsume it.
  """
  if len(shape) != 2:
    raise ValueError(f"Shape {shape} is not 2D")
  rows, cols = shape
  if rows % 64 or cols % 8:
    raise ValueError(f"Shape {shape} is not a multiple of 64x8")
  return TiledLayout(
      tiling=Tiling(tiles=((64, 8), (16, 8), (8, 8), (1, 2))),
      warp_dim=-8,
      lane_dims=frozenset((-4, -3)),
      vector_dim=-1,
  )
@dataclasses.dataclass(frozen=True)
class WGMMAFragLayout:
  """[m, n] matrix, where m % 64 == 0 == n % 8."""

  def thread_idxs(self, shape):
    """Yields the (row, col) base coordinates owned by the current thread.

    One pair is yielded per 2-element register, in register order.
    """
    index = ir.IndexType.get()
    assert shape[0] % 64 == 0 and shape[1] % 8 == 0
    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
    # Position of this thread within its warpgroup and warp.
    tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
    warp_idx = arith.divui(tid_wg, c(32, index))
    tid_warp = arith.remui(tid_wg, c(32, index))
    # Lanes are arranged 8x4 within a warp; each lane covers 2 consecutive
    # columns, and each warp owns a 16-row band.
    col_base = arith.muli(arith.remui(tid_warp, c(4, index)), c(2, index))
    row_base = arith.addi(
        arith.divui(tid_warp, c(4, index)), arith.muli(warp_idx, c(16, index))
    )
    for row_group in range(0, shape[0], 64):
      for col_group in range(0, shape[1], 8):
        # Each lane touches two rows (8 apart) within its warp's 16-row band.
        for row_subgroup in range(0, 16, 8):
          row = arith.addi(row_base, c(row_group + row_subgroup, index))
          yield row, arith.addi(col_base, c(col_group, index))

  def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Returns the per-thread register array shape for a logical [m, n] array.

    Each 64x8 tile is represented by a (2, 1) array of registers per thread.
    """
    assert len(shape) == 2
    assert shape[0] % 64 == 0 and shape[1] % 8 == 0
    return (shape[0] // 64, shape[1] // 8, 2, 1)
@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
  """[m] matrix, where m % 64 == 0."""

  def thread_idxs(self, shape):
    # Per-thread index enumeration is not implemented for row layouts.
    raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class WGSplatFragLayout:
"""A fragmented array where all the values are equal represented as a register per thread.
@ -75,36 +343,6 @@ class WGSplatFragLayout:
raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class WGMMAFragLayout:
  """[m, n] matrix, where m % 64 == 0 == n % 8."""

  def thread_idxs(self, shape):
    """Yields the (row, col) base coordinates owned by the current thread.

    One pair is yielded per 2-element register, in register order.
    """
    index = ir.IndexType.get()
    assert shape[0] % 64 == 0 and shape[1] % 8 == 0
    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
    # Position of this thread within its warpgroup and warp.
    tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
    warp_idx = arith.divui(tid_wg, c(32, index))
    tid_warp = arith.remui(tid_wg, c(32, index))
    # Lanes are arranged 8x4 within a warp; each lane covers 2 consecutive
    # columns, and each warp owns a 16-row band.
    col_base = arith.muli(arith.remui(tid_warp, c(4, index)), c(2, index))
    row_base = arith.addi(
        arith.divui(tid_warp, c(4, index)), arith.muli(warp_idx, c(16, index))
    )
    for row_group in range(0, shape[0], 64):
      for col_group in range(0, shape[1], 8):
        # Each lane touches two rows (8 apart) within its warp's 16-row band.
        for row_subgroup in range(0, 16, 8):
          row = arith.addi(row_base, c(row_group + row_subgroup, index))
          yield row, arith.addi(col_base, c(col_group, index))
@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
  """[m] matrix, where m % 64 == 0."""

  def thread_idxs(self, shape):
    # Per-thread index enumeration is not implemented for row layouts.
    raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class WGStridedFragLayout:
"""Convert the array to 1D and then shard across threads."""
@ -162,7 +400,7 @@ class WGStridedFragLayout:
yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))
# Union of all register layouts a FragmentedArray can carry.
FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
# NOTE(review): the binding below supersedes the one above (this rendering
# juxtaposes the old and new lines of a diff); the effective union also
# includes TiledLayout.
FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout | TiledLayout
WGMMA_LAYOUT = WGMMAFragLayout()
@ -230,6 +468,14 @@ class FragmentedArray:
if _registers.size != 1:
raise ValueError(f"Invalid register array shape: {_registers.shape}")
case TiledLayout():
try:
self.layout.shape_from_registers_shape(_registers.shape)
except ValueError:
raise ValueError(
"Register array shape does not match the tiled layout"
) from None
case _:
raise NotImplementedError
@ -304,15 +550,21 @@ class FragmentedArray:
return shape
case WGSplatFragLayout(shape=shape):
return shape
case TiledLayout():
return self.layout.shape_from_registers_shape(self.registers.shape)
case _:
raise NotImplementedError
@property
def mlir_dtype(self):
reg_ty = self.registers.flat[0].type
match self.layout:
case WGMMAFragLayout() | WGStridedFragLayout():
case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout():
return ir.VectorType(reg_ty).element_type
case WGMMARowFragLayout() | WGSplatFragLayout():
return reg_ty
case _:
raise NotImplementedError
def to_layout(self, new_layout: FragmentedLayout):
"""Converts the fragmented array to the given layout.
@ -321,6 +573,17 @@ class FragmentedArray:
"""
if self.layout == new_layout:
return self
shape = self.shape
if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0:
tiled_layout = _tiled_wgmma_layout(shape)
if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or (
self.layout == tiled_layout and new_layout == WGMMA_LAYOUT
):
return FragmentedArray(
_registers=self.registers.reshape(new_layout.registers_shape(shape)),
_layout=new_layout,
_is_signed=self.is_signed,
)
if not isinstance(self.layout, WGSplatFragLayout):
raise NotImplementedError(
f"Cannot convert from {self.layout} to {new_layout}"
@ -745,10 +1008,9 @@ class FragmentedArray:
raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")
new_registers = np.empty_like(self.registers)
match self.layout:
case WGMMAFragLayout():
new_reg_ty = ir.VectorType.get((2,), new_dtype)
case WGStridedFragLayout(vec_size=vec_size):
new_reg_ty = ir.VectorType.get((vec_size,), new_dtype)
case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout():
shape = ir.VectorType(self.registers.flat[0].type).shape
new_reg_ty = ir.VectorType.get(shape, new_dtype)
case WGMMARowFragLayout() | WGSplatFragLayout():
new_reg_ty = new_dtype
case _:
@ -916,6 +1178,8 @@ class FragmentedArray:
self._store_untiled_splat(ref)
case WGStridedFragLayout():
self._store_untiled_wg_strided(ref)
case TiledLayout():
self._store_untiled_tiled(ref)
case _:
raise NotImplementedError(self.layout)
@ -982,6 +1246,32 @@ class FragmentedArray:
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
memref.store(value, ref, [row, col])
def _store_untiled_tiled(self, ref: ir.Value):
  """Stores an array with a tiled layout. Not optimized at the moment."""
  i32 = ir.IntegerType.get_signless(32)
  layout = self.layout
  assert isinstance(layout, TiledLayout)
  ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
  # Each register is written with one vector store, so the elements along the
  # vectorized dimension must be contiguous in memory.
  # NOTE(review): this indexes the *untiled* ref strides with vector_dim,
  # which matches the minormost stride only for vector_dim == -1 — confirm.
  if ref_strides[layout.vector_dim] != 1:
    raise NotImplementedError(
        "Can't use vector stores with non-unit minormost stride"
    )
  strides = layout.tiling.tile_strides(ref_strides)
  ptr = utils.memref_ptr(ref)
  # Fold warp and lane offsets into the pointer once, since they are dynamic.
  dyn_strides = [arith.constant(i32, s) for s in strides]
  def dyn_dot(x, y):
    # Inner product of index/stride vectors, built from MLIR arith ops.
    return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
  warp_offset = dyn_dot(layout.warp_indices(), dyn_strides)
  lane_offset = dyn_dot(layout.lane_indices(), dyn_strides)
  dyn_offset = arith.addi(warp_offset, lane_offset)
  ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
  # All warp tile offsets are static and can be fused into the store.
  for tile_idx, reg in np.ndenumerate(self.registers):
    # Linearize the (static) register index with the tiled strides.
    lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True))
    reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
    llvm.store(reg, reg_ptr)
def store_tiled(self, ref, swizzle: int | None):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError

View File

@ -40,6 +40,7 @@ import numpy as np
# Number of threads in a warpgroup (4 warps of 32 lanes).
WARPGROUP_SIZE: int = 128
# Sentinel values marking a dynamic (runtime) index: INT64_MIN and INT32_MIN.
# DYNAMIC32 is used e.g. by getelementptr as the static placeholder for
# indices supplied as SSA values.
DYNAMIC = -9223372036854775808
DYNAMIC32 = -2147483648
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
@ -1036,3 +1037,11 @@ def is_signed(dtype: jax.typing.DTypeLike) -> bool | None:
elif jnp.issubdtype(dtype, jnp.integer):
return jnp.issubdtype(dtype, jnp.signedinteger)
return None
def getelementptr(
    ptr: ir.Value, indices: Sequence[ir.Value | int], dtype: ir.Type
) -> ir.Value:
  """Emits an llvm.getelementptr with a mix of static and dynamic indices.

  Plain ints are encoded statically in the op; ir.Value indices are passed as
  dynamic operands, with DYNAMIC32 as their static placeholder.
  """
  static_indices = []
  dyn_indices = []
  for idx in indices:
    if isinstance(idx, int):
      static_indices.append(idx)
    else:
      static_indices.append(DYNAMIC32)
      dyn_indices.append(idx)
  return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype)

View File

@ -29,6 +29,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax.experimental.mosaic.gpu import fragmented_array as fa
import jax.numpy as jnp
import numpy as np
try:
@ -1583,5 +1584,29 @@ class TorchTest(TestCase):
del y # Make sure the destructor runs successfully.
class LayoutTest(TestCase):
  """Tests for the tiled register layout system."""

  @parameterized.product(
      shape=((128, 128), (64, 8), (64, 256)),
      dtype=(jnp.int32, jnp.int16, jnp.int8),
  )
  def test_wgmma_tiled_layout(self, shape, dtype):
    # Round-trips an iota array through the tiled WGMMA layout and verifies
    # that storing it back produces the original values.
    def kernel(ctx, dst, _):
      iota = iota_tensor(*shape, dtype)
      tiled = iota.to_layout(fa._tiled_wgmma_layout(shape))
      # Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1)
      self.assertEqual(
          tiled.registers.shape,
          (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1),
      )
      self.assertEqual(tiled.shape, shape)
      self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype)
      tiled.store_untiled(dst)
    ty = jax.ShapeDtypeStruct(shape, dtype)
    # Single block of one warpgroup (128 threads).
    f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ())
    expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
    np.testing.assert_array_equal(f(), expected)
if __name__ == "__main__":
  # Run with JAX's test loader so JAX-specific test configuration applies.
  absltest.main(testLoader=jtu.JaxTestLoader())