mirror of https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

This allows us to significantly simplify the generated PTX/SASS, which is currently cluttered with LLVM trying to align slices to start at bit 0 and failing to CSE the right shifts. PiperOrigin-RevId: 737967890
2423 lines · 96 KiB · Python
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for code generator."""

from __future__ import annotations

import dataclasses
import functools
import math
from collections.abc import Callable
from typing import Iterable, Protocol, Sequence, TypeVar

import itertools
import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import math as mlir_math
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
import numpy as np

import jax.experimental.mosaic.gpu as mgpu
from . import utils

# mypy: ignore-errors

T = TypeVar("T")
WARPGROUP_SIZE = utils.WARPGROUP_SIZE
WARP_SIZE = 32
WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE
SMEM_BANKS = 32
SMEM_BANK_BYTES = 4
c = utils.c


@dataclasses.dataclass(frozen=True)
class Tiling:
  """A tiling expression describing a permutation of elements of an nd-array.

  To apply one level of tiling to an array, each of the trailing dimensions (up
  to the rank of the tile) is unfolded into two dimensions: first equal to the
  ratio of the dimension size and the tile size, and second equal to the tile
  size. Then, all newly unfolded minor dimensions are transposed to appear at
  the end.

  This expression describes multi-level tiling, by applying each element of
  `tiles` in sequence to the array.

  See https://openxla.org/xla/tiled_layout for a more detailed explanation.
  """
  tiles: tuple[tuple[int, ...], ...]

  def __post_init__(self):
    if not self.tiles:
      return
    tiled_rank = len(self.tiles[0])
    for tile in self.tiles:
      if len(tile) > tiled_rank:
        raise ValueError("Only the first tile can refer to value dimensions")
      if not tile:
        raise ValueError("Tiles must not be empty")
      if any(d <= 0 for d in tile):
        raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}")
      tiled_rank += len(tile)

  def __str__(self):
    return f"Tiling({''.join(map(str, self.tiles))})"

  def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Computes the shape of an array after tiling."""
    def fail():
      raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
    for tile in self.tiles:
      if len(tile) > len(shape):
        fail()
      untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):]
      if any(s % t != 0 for s, t in zip(tiled_dims, tile)):
        fail()
      shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile)
    return shape

  def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Computes the shape of an array before tiling from its tiled shape."""
    def fail():
      raise ValueError(
          f"shape {shape} is not a valid result of applying tiling {self}."
      )
    for tile in reversed(self.tiles):
      if len(tile) > len(shape):
        fail()
      untiled_dims = shape[:-2 * len(tile)]
      tiled_dims = shape[-2 * len(tile):-len(tile)]
      tiling_dims = shape[-len(tile):]
      if tiling_dims != tile:
        fail()
      shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile)))
    return shape

  def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]:
    """Computes the strides of an array after tiling."""
    for tile in self.tiles:
      untiled, tiled = strides[:-len(tile)], strides[-len(tile):]
      strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
    return strides

  def tile_nested_shape_strides(
      self,
      shape: tuple[tuple[int, ...], ...],
      strides: tuple[tuple[int, ...], ...],
  ) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]:
    """A fused version of `tile_shape` and `tile_strides` for nested shapes.

    By nested shape we mean that each logical dimension (i.e. each element of
    shape/strides) is actually composed out of multiple physical dimensions.
    For example, a row-major array of logical shape (128, 128) that is tiled
    into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each
    dim is split into two sub-dims) and nested strides of
    ((2 * 64 * 64, 64), (64 * 64, 1)).
    """
    if len(shape) != len(strides):
      raise ValueError(
          f"Shape {shape} and strides {strides} must have the same length"
      )
    def fail_if(cond, shape=shape):  # Capture shape now.
      if cond:
        raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
    for tile in self.tiles:
      fail_if(len(tile) > len(shape))
      untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):]
      untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):]
      major_dim_shapes, major_dim_strides = [], []
      minor_dim_shapes, minor_dim_strides = [], []
      for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides):
        major_dim_shape_rev, major_dim_stride_rev = [], []
        minor_dim_shape_rev, minor_dim_stride_rev = [], []
        for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True):
          if d < t:  # We will need to tile more dims
            fail_if(t % d != 0)
            t //= d
            minor_dim_shape_rev.append(d)
            minor_dim_stride_rev.append(s)
          elif t != 1:  # Last dim to tile!
            fail_if(d % t != 0)
            minor_dim_shape_rev.append(t)
            minor_dim_stride_rev.append(s)
            if d != t:  # No need to insert singleton dims.
              major_dim_shape_rev.append(d // t)
              major_dim_stride_rev.append(s * t)
            t = 1
          else:  # Done tiling!
            major_dim_shape_rev.append(d)
            major_dim_stride_rev.append(s)
        fail_if(t != 1)
        major_dim_shapes.append(major_dim_shape_rev[::-1])
        minor_dim_shapes.append(minor_dim_shape_rev[::-1])
        major_dim_strides.append(major_dim_stride_rev[::-1])
        minor_dim_strides.append(minor_dim_stride_rev[::-1])
      shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes)
      strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides)
    return (
        tuple(tuple(d) if d else (1,) for d in shape),
        tuple(tuple(d) if d else (1,) for d in strides),
    )

  def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
    for tile in self.tiles:
      untiled, tiled = indices[:-len(tile)], indices[-len(tile):]
      indices = (
          *untiled,
          *(i // t for i, t in zip(tiled, tile)),
          *(i % t for i, t in zip(tiled, tile)),
      )
    return indices

  def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
    for tile in reversed(self.tiles):
      untiled = indices[:-2 * len(tile)]
      outer = indices[-2 * len(tile):-len(tile)]
      inner = indices[-len(tile):]
      indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile)))
    return indices
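
# Worked example added for illustration (not part of the original module):
#
#   t = Tiling(((64, 8), (16, 8)))
#   t.tile_shape((128, 128))              == (2, 16, 4, 1, 16, 8)
#   t.untile_shape((2, 16, 4, 1, 16, 8))  == (128, 128)
#   t.tile_indices((70, 13))              == (1, 1, 0, 0, 6, 5)
#
# i.e. element (70, 13) falls into major tile (1, 1), sub-tile (0, 0), at
# offset (6, 5) within that sub-tile.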


def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]:
  """Like built-in enumerate, but returns negative indices into the sequence."""
  offset = len(elems)
  for i, e in enumerate(elems):
    yield i - offset, e


@dataclasses.dataclass(frozen=True)
class TiledLayout:
  """A FragmentedArray layout derived from a tiling expression.

  A logical array is transformed according to the tiling expression, and then
  split across warps (within a warpgroup), lanes, and vectorized according to
  the dimension indices. All dimension indices must be negative and should refer
  to the dimensions after tiling is applied.

  Note that warp_dim and vector_dim could be sets as well, but we don't have a
  use case for that yet.

  To better understand this layout, consider the example of WGMMA-related tiling
  from https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as
  applied to a 128x128 array. The corresponding TiledLayout has a tiling of:

      (64, 8)(16, 8)(8, 8)(1, 2)

  and warp_dim=-8, lane_dims=(-4, -3), vector_dim=-1.

  We begin by applying the tiling (note that it always applies to a suffix):

    Tiled shape                     Remaining tiling actions
    =========================================================================
    128 128                         (64, 8)(16, 8)(8, 8)(1, 2)
    2 16 64 8                       (16, 8)(8, 8)(1, 2)
    2 16 4 1 16 8                   (8, 8)(1, 2)
    2 16 4 1 2 1 8 8                (1, 2)
    2 16 4 1 2 1 8 4 1 2

  The last expression is our final shape. At this stage, we're ready to
  interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the
  end is partitioned over 4 warps in a warpgroup (and so it must be of size 4).
  lane_dims=(-4, -3) indicate that those two dimensions are partitioned over
  the lanes within a warp (their product must be equal to 32, i.e. warp size).
  Finally, vector_dim=-1 indicates that each (logical) register is a vector
  containing 2 elements (there are no shape restrictions here).

  Given the above, the shape of the (logical) register array used to represent
  the array in each thread is: (2, 16, 1, 1, 2, 1, 1, 1, 1, 1). We have set all
  the dimensions above to 1, since each thread is a member of a single warp,
  a single lane, and the elements along the vectorized dimension are represented
  by a single (logical) register.
  """
  tiling: Tiling
  warp_dim: int
  lane_dims: tuple[int, ...]  # major-to-minor
  vector_dim: int

  def __post_init__(self):
    if not self.tiling.tiles:
      raise ValueError("Tiling must have at least one tile")
    min_shape = self.tiling.tiles[0]
    min_tiled_shape = self.tiling.tile_shape(min_shape)
    dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim}
    if len(dims_set) != len(self.lane_dims) + 2:
      raise ValueError
    for d in dims_set:
      if d >= 0:
        raise ValueError("All dimensions must be negative")
      if d < -(len(min_tiled_shape) - len(min_shape)):
        raise ValueError("Dimension out of range")
    if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP:
      raise ValueError
    if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE:
      raise ValueError

  def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]:
    # We first find the linear index and then divide by the shape to
    # get the index.
    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    contig_strides = utils.get_contiguous_strides(shape)
    tile_strides = self.tiling.tile_strides(contig_strides)
    dyn_tile_strides = [c(s, i32) for s in tile_strides[-self.tiled_tiling_rank:]]
    warp_offset = utils.dyn_dot(self.warp_indices(), dyn_tile_strides)
    lane_offset = utils.dyn_dot(self.lane_indices(), dyn_tile_strides)
    dyn_offset = arith.addi(warp_offset, lane_offset)
    register_shape = self.registers_shape(shape)
    for tile_idx in np.ndindex(register_shape):
      tile_lin_idx = sum(i * s for i, s in zip(tile_idx, tile_strides))
      dyn_lin_idx = arith.addi(dyn_offset, c(tile_lin_idx, i32))
      idx = []
      for stride in contig_strides:
        idx.append(arith.index_castui(index, arith.divui(dyn_lin_idx, c(stride, i32))))
        dyn_lin_idx = arith.remui(dyn_lin_idx, c(stride, i32))
      yield tuple(idx)

  @property
  def base_tile_shape(self) -> tuple[int, ...]:
    """The shape of the first tile in the tiling expression.

    This tile acts as the divisibility constraint for a suffix of arrays to
    which this layout applies.
    """
    return self.tiling.tiles[0]

  @functools.cached_property
  def tiled_tiling_shape(self) -> tuple[int, ...]:
    """The shape of the suffix of the array after tiling.

    We only allow our repeated tiling actions to further subdivide the
    dimensions created by previous tiling actions (except for the first one),
    so the tiled shape always ends with this suffix, no matter what array shape
    it's applied to.
    """
    base_tile_shape = self.base_tile_shape
    return self.tiling.tile_shape(base_tile_shape)[len(base_tile_shape):]

  @functools.cached_property
  def tiled_tiling_rank(self) -> int:
    return len(self.tiled_tiling_shape)

  @property
  def vector_length(self) -> int:
    return self.tiled_tiling_shape[self.vector_dim]

  def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Returns the shape of the register array needed to represent an array of the given logical shape."""
    tiled_shape = list(self.tiling.tile_shape(shape))
    tiled_shape[self.warp_dim] = 1
    for d in self.lane_dims:
      tiled_shape[d] = 1
    tiled_shape[self.vector_dim] = 1
    return tuple(tiled_shape)

  def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Returns the logical shape of an array given its register array shape.

    Inverse to `registers_shape`.
    """
    tiled_tiling = self.tiled_tiling_shape
    shape = list(shape)
    shape[self.warp_dim] = WARPS_IN_WARPGROUP
    for d in self.lane_dims:
      shape[d] = tiled_tiling[d]
    shape[self.vector_dim] = tiled_tiling[self.vector_dim]
    return self.tiling.untile_shape(tuple(shape))

  def lane_indices(self) -> tuple[ir.Value, ...]:
    i32 = ir.IntegerType.get_signless(32)
    tiled_shape = self.tiled_tiling_shape
    lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims)
    assert math.prod(lanes_shape) == WARP_SIZE
    lane_strides = utils.get_contiguous_strides(lanes_shape)
    lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32))
    lane_indices = tuple(
        arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32))
        for stride, size in zip(lane_strides, lanes_shape)
    )
    full_indices = [arith.constant(i32, 0)] * len(tiled_shape)
    for d, i in zip(self.lane_dims, lane_indices):
      full_indices[d] = i
    return tuple(full_indices)

  def warp_indices(self) -> tuple[ir.Value, ...]:
    i32 = ir.IntegerType.get_signless(32)
    tiled_shape_rank = len(self.tiled_tiling_shape)
    warp_idx = arith.remui(
        arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)),
        c(WARPS_IN_WARPGROUP, i32),
    )
    indices = [arith.constant(i32, 0)] * tiled_shape_rank
    indices[self.warp_dim] = warp_idx
    return tuple(indices)
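

# Worked example added for illustration (not part of the original module):
# applying the WGMMA tiling from the docstring above to a (128, 128) array
# gives the tiled shape (2, 16, 4, 1, 2, 1, 8, 4, 1, 2). Setting the warp
# dimension (-8, size 4), the lane dimensions (-4 and -3, sizes 8 and 4) and
# the vector dimension (-1, size 2) to 1 yields the register array shape
# (2, 16, 1, 1, 2, 1, 1, 1, 1, 1): 64 vectors of 2 elements per thread, which
# is 128 * 128 / 128 threads = 128 elements each, as expected.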


def _tiled_wgmma_layout(shape: tuple[int, ...]):
  """Returns the tiled layout relevant for WGMMA operations.

  The tiled layout is equivalent to one described here in PTX documentation:
  https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d
  """
  if len(shape) != 2:
    raise ValueError(f"Shape {shape} is not 2D")
  if shape[0] % 64 != 0 or shape[1] % 8 != 0:
    raise ValueError(f"Shape {shape} is not a multiple of 64x8")
  return WGMMA_LAYOUT


@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
  """[m] matrix, where m % 64 == 0."""

  def thread_idxs(self, shape):
    raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class WGSplatFragLayout:
  """A fragmented array where all values are equal, represented as one register per thread.

  FragmentedArrays in this layout are always the result of a splat: each
  thread in the warpgroup has a single copy of the value, while the
  FragmentedArray pretends it has whatever shape the user wants. This means
  we can trivially broadcast, reshape and do elementwise operations with all
  other layouts.

  Examples:

  To splat a loaded value into an array that pretends to have shape (10, 20, 2):
  ```
  FragmentedArray.splat(memref.load(ref_1d, [1]), (10, 20, 2))
  ```

  A shape is always provided for sanity check reasons.
  """

  shape: tuple[int, ...] = ()

  def can_broadcast_to(self, shape) -> bool:
    """Check that the shape can be broadcast.

    Only dimensions of size 1 can be broadcast. All other dimensions
    must be the same as the argument shape.
    """
    return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1]))
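
  # Example added for illustration (not part of the original module):
  # WGSplatFragLayout((1, 20)).can_broadcast_to((10, 20)) is True, and so is
  # broadcasting to (4, 10, 20) (extra leading dims are fine), but
  # can_broadcast_to((10, 21)) is False, since only size-1 dimensions may be
  # expanded.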

  def thread_idxs(self, shape):
    assert shape == self.shape
    raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class WGStridedFragLayout:
  """Convert the array to 1D and then shard across threads."""

  shape: tuple[int, ...]
  vec_size: int

  def __post_init__(self):
    if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0:
      raise ValueError((self, WARPGROUP_SIZE))

  @classmethod
  def from_shaped_type(cls, shaped_ty: ir.Type):
    if not ir.ShapedType.isinstance(shaped_ty):
      raise TypeError(shaped_ty)

    shaped_ty = ir.ShapedType(shaped_ty)
    bw = mgpu.bytewidth(shaped_ty.element_type)
    assert 8 % bw == 0 and 8 // bw != 0, bw
    if math.prod(shaped_ty.shape) % WARPGROUP_SIZE != 0:
      raise ValueError(
          f"{shaped_ty} must have a number of elements that is a multiple of"
          f" {WARPGROUP_SIZE} (got {math.prod(shaped_ty.shape)})"
      )
    max_vec_size = np.prod(shaped_ty.shape) // WARPGROUP_SIZE
    return cls(
        shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size)
    )

  def thread_idxs(self, shape):
    assert shape == self.shape
    index = ir.IndexType.get()
    for v in self.linear_thread_idxs():
      res = []
      for dim in reversed(self.shape):
        dim = c(dim, index)
        res.append(arith.remui(v, dim))
        v = arith.divui(v, dim)
      res.reverse()
      yield res

  def linear_thread_idxs(self):
    """The indices to be used for vector load/store in WGStridedFragLayout.

    Yields:
      The indices of the vector that correspond to the current thread.
    """
    index = ir.IndexType.get()
    cardinality = np.prod(self.shape)
    assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0
    reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size)
    tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE, index))
    off = arith.muli(tidx, c(self.vec_size, tidx.type))
    for i in range(reg_num):
      yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))
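
# For illustration (not part of the original module): with shape (128, 64) and
# f32 elements, from_shaped_type picks vec_size=2, so there are
# 8192 / (128 * 2) = 32 vectors per thread; thread 5 of the warpgroup loads the
# vectors starting at linear element offsets 10, 266, 522, ..., i.e.
# 5 * 2 + k * 128 * 2 for k in range(32).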


FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | TiledLayout


WGMMA_ROW_LAYOUT = WGMMARowFragLayout()

# The tiled layout is equivalent to one described here in PTX documentation:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d
# In this layout, we partition the 64x8 tiles over 4 warps into 16x8 tiles.
# Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit
# of data that is split across a warp. Since 8*8 = 64, but a warp has only 32
# threads, we vectorize pairs of elements along columns.
# The assignment of elements to warp lanes is as follows:
#
#    0  0  1  1  2  2  3  3
#    4  4  5  5  6  6  7  7
#    8  8  9  9 10 10 11 11
#   12 12 13 13 14 14 15 15
#   ...
WGMMA_LAYOUT = TiledLayout(
    Tiling(((64, 8), (16, 8), (8, 8), (1, 2))),
    warp_dim=-8,
    lane_dims=(-4, -3),
    vector_dim=-1,
)
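# For illustration (not part of the original module): within the 8x8 submatrix
# owned by a single warp in WGMMA_LAYOUT, element (r, c) is held by lane
# 4 * r + c // 2, in slot c % 2 of that lane's 2-element vector. For example,
# element (2, 5) lives in lane 10, slot 1, matching row 2 of the diagram above.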

# This tiled layout is similar to the WGMMA layout, only the unit at which we
# assign submatrices to warps grows from 8x8 to 8x16. The elements within each
# submatrix are assigned to threads in the following way:
#
#    0  0  0  0  2  2  2  2  1  1  1  1  3  3  3  3
#    4  4  4  4  6  6  6  6  5  5  5  5  7  7  7  7
#   ...
#
# Our vector length is twice that of WGMMA_LAYOUT, which lets us use
# 32-bit SMEM loads/stores when dealing with 8-bit values. The conversion
# to the WGMMA layout only requires communication between threads whose indices
# differ in their lowest bit (i.e. 0 and 1, 2 and 3), so the conversion to
# WGMMA_LAYOUT only requires a single warp shuffle (plus permutes local to
# each thread).
WGMMA_LAYOUT_UPCAST_2X = TiledLayout(
    Tiling(((64, 16), (16, 16), (8, 16), (8,), (4,))),
    warp_dim=-8,
    lane_dims=(-4, -2, -3),
    vector_dim=-1,
)
# This layout should be used when upcasting 4-bit elements to 16-bit, for the
# purpose of passing them into WGMMA later. The core matrices stored by a warp
# are 8x32, because each of the 4 threads in a row holds 8 elements in a single
# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each
# group of 4 threads in order (as opposed to the swapping between 1 and 2,
# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
    Tiling(((64, 32), (16, 32), (8, 32), (8,))),
    warp_dim=-7,
    lane_dims=(-3, -2),
    vector_dim=-1,
)
# This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8
# submatrix in the following way (we only show the first 4 rows for brevity):
#
#    0  0  1  1  2  2  3  3
#    4  4  5  5  6  6  7  7
#    8  8  9  9 10 10 11 11
#   12 12 13 13 14 14 15 15
#   ...
#
# This tiled layout stores the same 8x8 submatrix in the following way:
#
#    0  4  1  5  2  6  3  7
#    0  4  1  5  2  6  3  7
#    8 12  9 13 10 14 11 15
#    8 12  9 13 10 14 11 15
#   ...
#
# You can see that we have taken 2x2 submatrices from the above layout and
# transposed them. The assignment of lanes to elements is such that in both
# layouts the same two lanes map to a single 2x2 submatrix, making the transpose
# very cheap (one shuffle and permute suffices to change between those layouts).
WGMMA_TRANSPOSED_LAYOUT = TiledLayout(
    Tiling(((64, 8), (16, 8), (8, 8), (2, 2), (2, 1))),
    warp_dim=-10,
    lane_dims=(-6, -3, -5),
    vector_dim=-2,
)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True)
class FragmentedArray:
  # An array of ir.Value, see checks in init for shapes.
  registers: np.ndarray = dataclasses.field(repr=False)
  layout: FragmentedLayout
  is_signed: bool | None

  def __init__(
      self,
      *,
      _registers: np.ndarray,
      _layout: FragmentedLayout,
      _is_signed: bool | None,
  ):
    """Initializes a fragmented array.

    This is a low-level API. Prefer using classmethods to construct fragmented
    arrays instead.
    """
    # We need to use ``object.__setattr__`` here because of ``frozen=True``.
    object.__setattr__(self, "registers", _registers)
    object.__setattr__(self, "layout", _layout)
    object.__setattr__(self, "is_signed", _is_signed)

    if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype):
      raise TypeError(
          "is_signed must be non-None if and only if the MLIR type is an"
          f" integer type, got {_is_signed=} for {self.mlir_dtype}"
      )

    match self.layout:
      # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout
      # Each element is a dtype scalar
      case WGMMARowFragLayout():
        if _registers.ndim != 2 or _registers.shape[-1] != 2:
          raise ValueError(f"Invalid register array shape: {_registers.shape}")

      # Registers are flat
      case WGStridedFragLayout(shape):
        [reg_size] = ir.VectorType(_registers.flat[0].type).shape
        if (
            math.prod(shape)
            != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size
        ):
          raise ValueError(
              f"Invalid register array shape: math.prod({_registers.shape}) *"
              f" {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})"
          )

      # Just a single register
      case WGSplatFragLayout():
        if _registers.size != 1:
          raise ValueError(f"Invalid register array shape: {_registers.shape}")

      case TiledLayout():
        try:
          self.layout.shape_from_registers_shape(_registers.shape)
        except ValueError:
          raise ValueError(
              "Register array shape does not match the tiled layout"
          ) from None

      case _:
        raise NotImplementedError

  @classmethod
  def load_strided(
      cls,
      ref: ir.Value,
      *,
      is_signed: bool | None = None,
      vec_size: int | None = None,
  ):
    if not ir.MemRefType.isinstance(ref.type):
      raise TypeError(ref.type)

    ref_ty = ir.MemRefType(ref.type)
    shape = tuple(ref_ty.shape)
    if vec_size is None:
      layout = WGStridedFragLayout.from_shaped_type(ref_ty)
    else:
      layout = WGStridedFragLayout(shape=shape, vec_size=vec_size)
    vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
    try:
      # Flattening the reference potentially produces simpler PTX but
      # if the ref is not already 1D and has strided dimensions
      # flattening won't work.
      ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
      vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_idxs()]
    except NotImplementedError:
      vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
    return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)

  @classmethod
  def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
    layout = layout or WGSplatFragLayout(shape)
    match layout:
      case WGMMARowFragLayout():
        if len(shape) != 1:
          raise ValueError("WGMMARowFragLayout requires a 1D shape")
        if shape[0] % 64:
          raise ValueError(
              "WGMMARowFragLayout requires shape[0] to be a multiple of 64"
          )
        reg_shape = (shape[0] // 64, 2)
      case WGStridedFragLayout(vec_size=vec_size):
        assert shape == layout.shape
        elems = np.prod(shape)
        reg_shape = (elems // (WARPGROUP_SIZE * vec_size),)
        value = vector.splat(ir.VectorType.get((vec_size,), value.type), value)
      case WGSplatFragLayout():
        assert shape == layout.shape
        reg_shape = ()
      case TiledLayout():
        value = vector.splat(ir.VectorType.get((layout.vector_length,), value.type), value)
        reg_shape = layout.registers_shape(shape)
      case _:
        raise NotImplementedError(layout)

    return cls(
        _registers=np.full(reg_shape, value, dtype=object),
        _layout=layout,
        _is_signed=is_signed,
    )

  @property
  def shape(self):
    match self.layout:
      case WGMMARowFragLayout():
        row_tiles = self.registers.shape[0]
        return (row_tiles * 64,)
      case WGStridedFragLayout(shape):
        return shape
      case WGSplatFragLayout(shape=shape):
        return shape
      case TiledLayout():
        return self.layout.shape_from_registers_shape(self.registers.shape)
      case _:
        raise NotImplementedError

  @property
  def mlir_dtype(self):
    reg_ty = self.registers.flat[0].type
    match self.layout:
      case WGStridedFragLayout() | TiledLayout():
        return ir.VectorType(reg_ty).element_type
      case WGMMARowFragLayout() | WGSplatFragLayout():
        return reg_ty
      case _:
        raise NotImplementedError

  def to_layout(self, new_layout: FragmentedLayout):
    """Converts the fragmented array to the given layout.

    At the moment, only conversions from ``WGSplatFragLayout`` are supported.
    """
    i32 = ir.IntegerType.get_signless(32)
    c = lambda x: arith.constant(i32, x)
    if self.layout == new_layout:
      return self
    shape = self.shape
    if (
        self.layout == WGMMA_LAYOUT
        and new_layout == WGMMA_TRANSPOSED_LAYOUT
        and utils.bitwidth(self.mlir_dtype) == 16
    ):
      is_even_row = arith.cmpi(
          arith.CmpIPredicate.eq,
          arith.remui(arith.divui(utils.thread_idx(), c(4)), c(2)),
          c(0),
      )
      perm = arith.select(is_even_row, c(0x5410), c(0x3276))
      new_regs = []
      for reg in self.registers.flat:
        reg_ty = reg.type
        reg = utils.bitcast(reg, i32)
        reg_shfl = utils.shfl_bfly(reg, 4)
        new_reg = utils.prmt(reg, reg_shfl, perm)
        new_regs.append(utils.bitcast(new_reg, reg_ty))
      return FragmentedArray(
          _registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)),
          _layout=new_layout,
          _is_signed=self.is_signed,
      )
    if (
        self.layout == WGMMA_LAYOUT_UPCAST_2X
        and new_layout == WGMMA_LAYOUT
        and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
    ):
      assert shape[1] % 16 == 0  # Should be implied by the layout
      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
      is_even = arith.cmpi(
          arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
      )
      registers = self.registers
      if dtype_bitwidth == 4:
        if registers.shape[1] % 2:
          raise NotImplementedError(
              "This relayout implementation requires an even number of column"
              " tiles (to pack pairs of them for efficiency)"
          )
        # We pair up the consecutive column tiles, so each register is 32-bit.
        # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
        # LLVM will realize that the paired up vectors actually came from the
        # same 32-bit register and it will become a no-op.
        col_minor_registers = np.moveaxis(registers, 1, -1)
        flat_registers = [
            utils.vector_concat((l, h))
            for l, h in zip(
                col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
            )
        ]
        registers = np.asarray(flat_registers, dtype=object).reshape(
            *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
        )
        registers = np.moveaxis(registers, -1, 1)
      for idx, reg in np.ndenumerate(registers):
        if dtype_bitwidth == 16:
          assert reg.type.shape == [4]
          # A single vector is 64-bits, but shuffles are only 32-bit wide.
          # We only shuffle the half that needs to go to other thread.
          low = utils.vector_slice(reg, slice(0, 2))
          high = utils.vector_slice(reg, slice(2, 4))
          to_exchange = arith.select(is_even, high, low)
          # Exchange values between even and odd threads.
          exchanged = utils.shfl_bfly(to_exchange, 1)
          low = arith.select(is_even, low, exchanged)
          high = arith.select(is_even, exchanged, high)
          new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
          new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
        elif dtype_bitwidth == 8:
          assert reg.type.shape == [4]
          # The vector is 32-bits, so we just shuffle the whole thing and
          # use prmt to blend it with the local register.
          exchanged = utils.shfl_bfly(reg, 1)
          # Consider lanes 0 and 1, because the situation is symmetric for
          # each pair. If we feed reg[lane] and exchanged[lane] (which is
          # really the same as reg of the other lane) to prmt, we can index
          # the elements of the result using the following indices:
          #    reg[0]:  0 1 2 3      reg[1]:  8 9 10 11
          #   prmt[0]:  0 1 2 3     prmt[1]:  4 5  6  7
          #   prmt[0]:  4 5 6 7     prmt[1]:  0 1  2  3
          # The expected outputs and their respective permutations are:
          #    out[0]:  0 1 8 9      out[1]:  2 3 10 11
          #   prmt[0]:  0 1 4 5     prmt[1]:  6 7  2  3
          # Note that the patterns still need to be flipped, since we listed
          # bytes with LSB on the left, which is the opposite of how the
          # numeric constants are spelled in Python (LSB on the right).
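          # (Note added for clarity, not in the original: PTX prmt reads its
          # selector one nibble at a time starting from the least significant
          # nibble; nibble k picks which of the 8 source bytes (0-3 from the
          # first operand, 4-7 from the second) becomes byte k of the result.
          # So 0x5410 selects bytes 0, 1, 4, 5 and 0x3276 selects 6, 7, 2, 3,
          # matching the prmt rows listed above.)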
          perm = arith.select(is_even, c(0x5410), c(0x3276))
          blend = utils.prmt(reg, exchanged, perm)
          for i in range(2):
            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
            new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
        else:
          assert dtype_bitwidth == 4
          assert reg.type.shape == [8]  # We paired up the registers above.
          exchanged = utils.shfl_bfly(reg, 1)
          # See comment above for a more complete explanation.
          #    reg[0]:  0 1  2 3  16 17  18 19    reg[1]:  8 9  10 11  24 25  26 27
          #   prmt[0]: -0- -1- --2-- --3--  -4- --5-- --6-- --7--
          #   prmt[1]: -4- -5- --6-- --7--  -0- --1-- --2-- --3--
          # The expected outputs and their respective permutations are:
          #    out[0]:  0 1  8 9  16 17  24 25    out[1]:  2 3  10 11  18 19  26 27
          #   prmt[0]: -0- -4- --2-- --6--       prmt[1]: -5- --1-- --7-- --3--
          perm = arith.select(is_even, c(0x6240), c(0x3715))
          blend = utils.prmt(reg, exchanged, perm)
          for i in range(4):
            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
            new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
      assert all(r is not None for r in new_registers)
      return FragmentedArray(
          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
      )
    if (
        self.layout == WGMMA_LAYOUT_UPCAST_4X
        and new_layout == WGMMA_LAYOUT_UPCAST_2X
        and utils.bitwidth(self.mlir_dtype) == 4
    ):
      assert shape[0] % 64 == 0  # Should be implied by the layout
      assert shape[1] % 32 == 0  # Should be implied by the layout
      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
      i32 = ir.IntegerType.get_signless(32)
      c = lambda x: arith.constant(i32, x)
      is_01 = arith.cmpi(
          arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
      )
      for idx, reg in np.ndenumerate(self.registers):
        assert ir.VectorType(reg.type).shape == [8]
        # The vector is 32-bits, so we just shuffle the whole thing and
        # use prmt to blend it with the local register.
        exchanged = utils.shfl_bfly(reg, 2)
        # See comments above for conventions. Here we exchange data between
        # threads with lane index related by flipping the 2nd bit (e.g. 0 and 2).
        #    reg[0]:  0 1  2 3  4 5  6 7    reg[2]:  16 17  18 19  20 21  22 23
        #   prmt[0]: -0-  -1-  -2-  -3-              --4--  --5--  --6--  --7--
        #   prmt[1]: -4-  -5-  -6-  -7-              --0--  --1--  --2--  --3--
        # The expected outputs and their respective permutations are:
        #    out[0]:  0 1  2 3  16 17  18 19   out[2]:  4 5  6 7  20 21  22 23
        #   prmt[0]: -0-  -1-  --4--  --5--   prmt[2]: -6-  -7-  --2--  --3--
        perm = arith.select(is_01, c(0x5410), c(0x3276))
        blend = utils.prmt(reg, exchanged, perm)
        for i in range(2):
          reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
          new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
      assert all(r is not None for r in new_registers)
      return FragmentedArray(
          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
      )
    if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
      return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
    if not isinstance(self.layout, WGSplatFragLayout):
      raise NotImplementedError(
          f"Cannot convert from {self.layout} to {new_layout}"
      )
    [reg] = self.registers.flat
    return type(self).splat(
        reg, self.shape, new_layout, is_signed=self.is_signed
    )

  def _pointwise(self, op, *other, output_is_signed: bool | None = None):
    # If our layout is a splat, then we should either dispatch to a non-splat
    # layout, or broadcast ourselves to the output shape first.
    if isinstance(self.layout, WGSplatFragLayout):
      output_shape = self.shape
      for i, o in enumerate(other):
        if not isinstance(o, FragmentedArray):
          continue
        elif not isinstance(o.layout, WGSplatFragLayout):
          return o._pointwise(
              lambda o, this, *args: op(this, *args[:i], o, *args[i:]),
              self,
              *other[:i],
              *other[i + 1 :],
              output_is_signed=output_is_signed,
          )
        else:
          output_shape = np.broadcast_shapes(output_shape, o.shape)
      # If we get here then we haven't found any non-splat layout.
      if self.shape != output_shape:
        return self.broadcast(output_shape)._pointwise(
            op, *other, output_is_signed=output_is_signed
        )

    other_arrs = []
    for o in other:
      if not isinstance(o, FragmentedArray):
        if isinstance(o, (float, int)):
          o = utils.c(o, self.mlir_dtype)
        elif not isinstance(o, ir.Value):
          raise NotImplementedError(o)

        o = FragmentedArray.splat(
            o, shape=self.shape, layout=self.layout, is_signed=self.is_signed
        )

      if isinstance(o.layout, WGSplatFragLayout):
        if not o.layout.can_broadcast_to(self.shape):
          raise ValueError(
              f"Cannot broadcast shape {self.shape} to layout {o.layout}")
        o = FragmentedArray.splat(
            o.registers.flat[0],
            shape=self.shape,
            layout=self.layout,
            is_signed=o.is_signed,
        )
      else:
        if self.layout != o.layout:
          raise ValueError("Incompatible FragmentedArray layouts")
        if self.registers.shape != o.registers.shape:
          raise ValueError("Incompatible FragmentedArray shapes")

      other_arrs.append(o)
    new_regs = np.empty_like(self.registers)

    for idx, reg in np.ndenumerate(self.registers):
      new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs))
    reg_ty = new_regs.flat[0].type
    if ir.VectorType.isinstance(reg_ty):
      reg_ty = ir.VectorType(reg_ty).element_type
    if output_is_signed is None and ir.IntegerType.isinstance(reg_ty):
      output_is_signed = self.is_signed
    return FragmentedArray(
        _registers=new_regs, _layout=self.layout, _is_signed=output_is_signed
    )

  def __pos__(self):
    return self

  def __neg__(self):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(arith.negf)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      return 0 - self
    else:
      return NotImplemented

  def __add__(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(addf, other)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      return self._pointwise(arith.addi, other)
    else:
      return NotImplemented

  def __radd__(self, other):
    return self + other

  def __mul__(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(mulf, other)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      return self._pointwise(arith.muli, other)
    else:
      return NotImplemented

  def __rmul__(self, other):
    return self * other

  def __sub__(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(subf, other)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      return self._pointwise(arith.subi, other)
    else:
      return NotImplemented

  def __rsub__(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(lambda s, o: subf(o, s), other)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      return self._pointwise(lambda s, o: arith.subi(o, s), other)
    else:
      return NotImplemented

  def __truediv__(self, other):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      return NotImplemented
    return self._pointwise(arith.divf, other)

  def __rtruediv__(self, other):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      return NotImplemented
    return self._pointwise(lambda s, o: arith.divf(o, s), other)

  def __floordiv__(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(
          lambda s, o: mlir_math.floor(arith.divf(s, o)), other
      )
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      if self.is_signed:
        return self._pointwise(arith.floordivsi, other)
      else:
        return self._pointwise(arith.divui, other)
    else:
      return NotImplemented

  def __rfloordiv__(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(
          lambda s, o: mlir_math.floor(arith.divf(o, s)), other
      )
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      if self.is_signed:
        return self._pointwise(lambda s, o: arith.floordivsi(o, s), other)
      else:
        return self._pointwise(lambda s, o: arith.divui(o, s), other)
    else:
      return NotImplemented

  def __mod__(self, other):
    if not ir.IntegerType.isinstance(self.mlir_dtype):
      return NotImplemented
    if self.is_signed:
      return self._pointwise(arith.remsi, other)
    else:
      return self._pointwise(arith.remui, other)

  def __rmod__(self, other):
    if not ir.IntegerType.isinstance(self.mlir_dtype):
      return NotImplemented
    if self.is_signed:
      return self._pointwise(lambda s, o: arith.remsi(o, s), other)
    else:
      return self._pointwise(lambda s, o: arith.remui(o, s), other)

  def __invert__(self):
    if not ir.IntegerType.isinstance(self.mlir_dtype):
      return NotImplemented
    return self ^ ~0

  def __or__(self, other):
    if not ir.IntegerType.isinstance(self.mlir_dtype):
      return NotImplemented
    return self._pointwise(arith.ori, other)

  def __ror__(self, other):
    return self | other

  def __and__(self, other):
    if not ir.IntegerType.isinstance(self.mlir_dtype):
      return NotImplemented
    return self._pointwise(arith.andi, other)

  def __rand__(self, other):
    return self & other

  def __xor__(self, other):
    if not ir.IntegerType.isinstance(self.mlir_dtype):
      return NotImplemented
    return self._pointwise(arith.xori, other)

  def __rxor__(self, other):
    return self ^ other

  def __eq__(self, other):
    return self._compare(
        other,
        f_pred=arith.CmpFPredicate.OEQ,
        si_pred=arith.CmpIPredicate.eq,
        ui_pred=arith.CmpIPredicate.eq,
    )

  def __ne__(self, other):
    return self._compare(
        other,
        f_pred=arith.CmpFPredicate.UNE,
        si_pred=arith.CmpIPredicate.ne,
        ui_pred=arith.CmpIPredicate.ne,
    )

  def __lt__(self, other):
    return self._compare(
        other,
        f_pred=arith.CmpFPredicate.OLT,
        si_pred=arith.CmpIPredicate.slt,
        ui_pred=arith.CmpIPredicate.ult,
    )

  def __le__(self, other):
    return self._compare(
        other,
        f_pred=arith.CmpFPredicate.OLE,
        si_pred=arith.CmpIPredicate.sle,
        ui_pred=arith.CmpIPredicate.ule,
    )

  def __gt__(self, other):
    return self._compare(
        other,
        f_pred=arith.CmpFPredicate.OGT,
        si_pred=arith.CmpIPredicate.sgt,
        ui_pred=arith.CmpIPredicate.ugt,
    )

  def __ge__(self, other):
    return self._compare(
        other,
        f_pred=arith.CmpFPredicate.OGE,
        si_pred=arith.CmpIPredicate.sge,
        ui_pred=arith.CmpIPredicate.uge,
    )

  def _compare(self, other, *, f_pred, si_pred, ui_pred):
    if ir.FloatType.isinstance(self.mlir_dtype):
      pred = functools.partial(arith.cmpf, f_pred)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      if ir.IntegerType(self.mlir_dtype).is_signed:
        pred = functools.partial(arith.cmpi, si_pred)
      else:
        pred = functools.partial(arith.cmpi, ui_pred)
    else:
      raise NotImplementedError
    return self._pointwise(pred, other, output_is_signed=False)

  def max(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      maximumf = arith.maximumf
      if ir.F32Type.isinstance(self.mlir_dtype):
        maximumf = self._lift_fast_instr("max.NaN.f32")
      return self._pointwise(maximumf, other)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      return self._pointwise(
          arith.maxsi if self.is_signed else arith.maxui, other
      )
    else:
      raise NotImplementedError

  def min(self, other):
    if ir.FloatType.isinstance(self.mlir_dtype):
      return self._pointwise(arith.minimumf, other)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      return self._pointwise(
          arith.minsi if self.is_signed else arith.minui, other
      )
    else:
      raise NotImplementedError

  def exp(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError
    if approx:
      dtype = self.mlir_dtype
      log2e = arith.constant(dtype, ir.FloatAttr.get(dtype, 1.4426950408889634))
      return (self * log2e).exp2()
    return self._pointwise(mlir_math.exp)

  def exp2(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError
    if approx:
      if not ir.F32Type.isinstance(self.mlir_dtype):
        raise NotImplementedError(self.mlir_dtype)
      return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32"))
    return self._pointwise(mlir_math.exp2)

  def log(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError
    if approx:
      dtype = self.mlir_dtype
      ln2 = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.6931471805599453))
      return self.log2(approx=True) * ln2
    return self._pointwise(mlir_math.log)

  def log2(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError(self.mlir_dtype)
    if approx:
      if not ir.F32Type.isinstance(self.mlir_dtype):
        raise NotImplementedError(self.mlir_dtype)
      return self._pointwise(self._lift_fast_instr("lg2.approx.ftz.f32"))
    return self._pointwise(mlir_math.log2)

  def sin(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError
    if approx and self.mlir_dtype != ir.F32Type.get():
      raise NotImplementedError
    return self._pointwise(
        self._lift_fast_instr("sin.approx.f32") if approx else mlir_math.sin
    )

  def cos(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError
    if approx and self.mlir_dtype != ir.F32Type.get():
      raise NotImplementedError
    return self._pointwise(
        self._lift_fast_instr("cos.approx.f32") if approx else mlir_math.cos
    )

  def tanh(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError
    if approx and self.mlir_dtype != ir.F32Type.get():
      raise NotImplementedError
    return self._pointwise(
        self._lift_fast_instr("tanh.approx.f32") if approx else mlir_math.tanh
    )

  def rsqrt(self, *, approx: bool = False):
    if not ir.FloatType.isinstance(self.mlir_dtype):
      raise NotImplementedError
    if approx and self.mlir_dtype != ir.F32Type.get():
      raise NotImplementedError
    return self._pointwise(
        self._lift_fast_instr("rsqrt.approx.f32") if approx else mlir_math.rsqrt
    )

  @staticmethod
  def _lift_fast_instr(
      instr: str | Callable[[ir.Value], ir.Value],
  ) -> Callable[[ir.Value], ir.Value]:
    def fast_instr(*args):
      f32 = ir.F32Type.get()
      arg_ty = args[0].type
      assert all(a.type == arg_ty for a in args)
      if arg_ty == f32:
        if isinstance(instr, str):
          args_ptx = ", ".join(f"${i}" for i in range(len(args) + 1))
          return llvm.inline_asm(
              f32, args, f"{instr} {args_ptx};", "=f" + ",f" * len(args)
          )
        else:
          return instr(*args)
      elif ir.VectorType.isinstance(arg_ty):
        index = ir.IndexType.get()
        result = llvm.mlir_undef(arg_ty)
        [vec_len] = ir.VectorType(arg_ty).shape
        for i in range(vec_len):
          vs = [vector.extractelement(a, position=c(i, index)) for a in args]
          vr = fast_instr(*vs)
          result = vector.insertelement(vr, result, position=c(i, index))
        return result
      else:
        raise NotImplementedError(arg_ty)
    return fast_instr

  def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
    if (output_is_signed is not None) != ir.IntegerType.isinstance(elt):
      raise TypeError(
          "output_is_signed must be non-None if and only if the MLIR type is an"
          f" integer type, got {output_is_signed=} for {elt}"
      )

    if elt == self.mlir_dtype:
      return self
    reg_type = self.registers.flat[0].type
    if ir.VectorType.isinstance(reg_type):
      reg_shape = ir.VectorType(reg_type).shape
      ty = ir.VectorType.get(reg_shape, elt)
    else:
      ty = elt

    return self._pointwise(
        lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed
    )

  def __getitem__(self, idx):
    if self.layout != WGMMA_LAYOUT:
      raise NotImplementedError("Only WGMMA layouts support slicing")
    base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
    if any(is_squeezed):
      raise NotImplementedError("Only slicing implemented")
    if (
        base_idx[0] % 64
        or slice_shape[0] % 64
        or base_idx[1] % 8
        or slice_shape[1] % 8
    ):
      raise NotImplementedError("Only tile aligned slicing supported")
    base_idx[0] //= 64
    slice_shape[0] //= 64
    base_idx[1] //= 8
    slice_shape[1] //= 8
    new_regs = self.registers[
        base_idx[0] : base_idx[0] + slice_shape[0],
        base_idx[1] : base_idx[1] + slice_shape[1],
    ]
    return FragmentedArray(
        _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed
    )

  # TODO(apaszke): Support JAX dtypes here as well?
  def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None):
    i4 = ir.IntegerType.get_signless(4)
    i8 = ir.IntegerType.get_signless(8)
    i16 = ir.IntegerType.get_signless(16)
    i32 = ir.IntegerType.get_signless(32)
    bf16 = ir.BF16Type.get()

    cur_dtype = self.mlir_dtype
    if cur_dtype == new_dtype:
      if self.is_signed == is_signed:
        return self
      return FragmentedArray(
          _registers=self.registers, _layout=self.layout, _is_signed=is_signed
      )
    reg_type = self.registers.flat[0].type
    is_vector_reg = ir.VectorType.isinstance(reg_type)
    reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
    [vector_len] = reg_shape  # This is meant to be a 1D assertion.
    if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8:
      raise ValueError(
          "Register bitwidth in target type must be divisible by 8, got"
          f" {new_reg_bitwidth}"
      )
    if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
      new_registers = np.empty_like(self.registers)
      out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
      for idx, reg in np.ndenumerate(self.registers):
        # The algorithm here is largely the same as CUTLASS's
        # NumericArrayConverter specialization for int4 -> bf16 casts.
        # We modify it slightly, because we only extract 2 values.
        # We first shift the value by 4 bits, to put the high int4 in low bits.
        # The prmt then blends the two values together, by putting them into the
        # low bits of each 16-bit subword of our register. Then, we use the lop3
        # to zero any bits that don't belong to our int4s, and finally use the
        # XOR to: (1) set the exponent bits to 0x43 (at which point the mantissa
        # represents integer increments) and (2) flip the sign bit. If we
        # interpret the 4 bits as uint4 after the flip, then we'll see that
        # positive int4s will end up larger than negative int4s, with a bias of
        # 8. We use the sub to subtract the base (our initial exponent) and the
        # bias coming from flipping the sign bit which is 136 (0x4308 as bits).
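        # Numeric example added for clarity (not in the original): for v = -3
        # (0b1101 as an int4), flipping the sign bit gives the uint4 value 5.
        # Placing 5 in the low mantissa bits of a bf16 with exponent byte 0x43
        # produces the bit pattern 0x4305, i.e. 133.0, and 133.0 - 136.0
        # recovers -3.0. sub.bf16x2 does this for two packed values at once.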
|
|
def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
|
|
assert 0 <= part < 4
|
|
return llvm.inline_asm(
|
|
i32,
|
|
[reg, reg_shr],
|
|
f"""
|
|
{{
|
|
.reg .b32 s<4>;
|
|
prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
|
|
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
|
|
mov.b32 s3, 0x43084308;
|
|
sub.bf16x2 $0, s2, s3;
|
|
}}
|
|
""",
|
|
"=r,r,r",
|
|
)
|
|
offset = 0
|
|
out_int_regs = []
|
|
for group_size in (8, 4, 2):
|
|
int_ty = ir.IntegerType.get_signless(group_size * 4)
|
|
while vector_len - offset >= group_size:
|
|
# If the vector originates from a slice (common after relayouts), we
|
|
# can fuse the slicing into the conversion and prevent LLVM from
|
|
# generating a bunch of shifts to align the vector data to the LSB.
|
|
# This also lets us share the right shift among more vectors.
|
|
if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
|
|
and utils.bitwidth(slice_op.vector.type) == 32
|
|
and slice_op.strides[0].value == 1):
|
|
slice_offset = slice_op.offsets[0].value + offset
|
|
reg_int = utils.bitcast(slice_op.vector, i32)
|
|
reg_int_shr = arith.shrui(reg_int, c(4, i32))
|
|
out_int_regs.extend(
|
|
upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
|
|
for part in range(group_size // 2)
|
|
)
|
|
else:
|
|
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
|
|
reg_slice_int = utils.bitcast(reg_slice, int_ty)
|
|
if int_ty != i32:
|
|
reg_slice_int = arith.extsi(i32, reg_slice_int)
|
|
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
|
|
out_int_regs.extend(
|
|
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
|
|
for part in range(group_size // 2)
|
|
)
|
|
offset += group_size
|
|
assert offset == vector_len
|
|
out_vec_int = utils.vector_concat([
|
|
vector.splat(ir.VectorType.get((1,), i32), reg)
|
|
for reg in out_int_regs
|
|
])
|
|
new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
|
|
return FragmentedArray(
|
|
_registers=new_registers, _layout=self.layout, _is_signed=None
|
|
)
|
|
if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}:
|
|
new_registers = np.empty_like(self.registers)
|
|
def upcast_to_bf16(reg, high):
|
|
# We first embed the s8 into a bf16 with the exponent equal to
|
|
# bias + mantissa bits. Then, we zero the msb that didn't fit into the
|
|
# mantissa, zero out all bits other than msb, and subtract the last
|
|
# two values from each other. This takes advantage of the fact that the
|
|
# lsb of the exponent (msb of the second byte) is zero, which allows us
|
|
# to losslesly pack the msb there. When 1, it doubles the value of s2,
|
|
# making the result negative.
|
|
return llvm.inline_asm(
|
|
i32,
|
|
[reg],
|
|
f"""
|
|
{{
|
|
.reg .b32 s<3>;
|
|
prmt.b32 s0, $1, 0x43, {0x4342 if high else 0x4140};
|
|
and.b32 s1, s0, 0xff7fff7f;
|
|
and.b32 s2, s0, 0xff80ff80;
|
|
sub.bf16x2 $0, s1, s2;
|
|
}}
|
|
""",
|
|
"=r,r",
|
|
)
|
|
empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32))
|
|
for idx, reg in np.ndenumerate(self.registers):
|
|
if vector_len == 2:
|
|
reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
|
|
new_reg_32 = upcast_to_bf16(reg_16, high=False)
|
|
new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
|
|
elif vector_len == 4:
|
|
reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg)
|
|
low = upcast_to_bf16(reg_32, high=False)
|
|
high = upcast_to_bf16(reg_32, high=True)
|
|
new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32))
|
|
new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32))
|
|
else:
|
|
raise NotImplementedError(vector_len)
|
|
new_registers[idx] = vector.bitcast(
|
|
ir.VectorType.get((vector_len,), new_dtype), new_vec_32
|
|
)
|
|
return FragmentedArray(
|
|
_registers=new_registers, _layout=self.layout, _is_signed=is_signed
|
|
)
|
|
# Generic path.
|
|
from_float = ir.FloatType.isinstance(cur_dtype)
|
|
to_float = ir.FloatType.isinstance(new_dtype)
|
|
from_integer = ir.IntegerType.isinstance(cur_dtype)
|
|
to_integer = ir.IntegerType.isinstance(new_dtype)
|
|
if from_float and to_float:
|
|
cur_ty_width = ir.FloatType(cur_dtype).width
|
|
new_ty_width = ir.FloatType(new_dtype).width
|
|
if cur_ty_width == new_ty_width:
|
|
# There is no instruction to perform conversions between two float types
|
|
# of the same width. Go through the next-larger standard type.
|
|
# TODO(bchetioui): support conversions between float types of width 8.
|
|
# Which larger type to pick will depend on the number of bits in the
|
|
# smallest exponent.
        if cur_ty_width != 16:
          raise NotImplementedError(
              "Conversion between float types of width other than 16 not"
              " supported"
          )
        larger_ty = ir.F32Type.get()
        match self.layout:
          case WGStridedFragLayout() | TiledLayout():
            shape = ir.VectorType(self.registers.flat[0].type).shape
            upcast_ty = ir.VectorType.get(shape, larger_ty)
          case WGMMARowFragLayout() | WGSplatFragLayout():
            upcast_ty = larger_ty
          case _:
            raise NotImplementedError(f"Unsupported layout {self.layout}")
        convert = lambda ty, x: arith.truncf(ty, arith.extf(upcast_ty, x))
      elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
        convert = arith.truncf
      else:
        convert = arith.extf
    elif from_integer and to_integer:
      if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
        convert = arith.trunci
      else:
        convert = arith.extsi if self.is_signed else arith.extui
    elif from_integer and to_float:
      convert = arith.sitofp if self.is_signed else arith.uitofp
    elif from_float and to_integer:
      convert = arith.fptosi if is_signed else arith.fptoui
    else:
      raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")
    new_registers = np.empty_like(self.registers)
    match self.layout:
      case WGStridedFragLayout() | TiledLayout():
        shape = ir.VectorType(self.registers.flat[0].type).shape
        new_reg_ty = ir.VectorType.get(shape, new_dtype)
      case WGMMARowFragLayout() | WGSplatFragLayout():
        new_reg_ty = new_dtype
      case _:
        raise NotImplementedError(f"Unsupported layout {self.layout}")
    for idx, reg in np.ndenumerate(self.registers):
      new_registers[idx] = convert(new_reg_ty, reg)
    return FragmentedArray(
        _registers=new_registers, _layout=self.layout, _is_signed=is_signed
    )

  # NOTE: scratch can be reused immediately once this function returns.
  def reduce_sum(self, scratch: ir.Value | None = None):
    if isinstance(self.layout, WGSplatFragLayout):
      [reg] = self.registers.flat
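      # A splat carries one value for the whole logical shape, so its sum is
      # just that value times the element count (e.g. a (128, 64) splat of 1.0
      # reduces to 8192.0).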
      if ir.FloatType.isinstance(self.mlir_dtype):
        op = mulf
      elif ir.IntegerType.isinstance(self.mlir_dtype):
        op = arith.muli
      else:
        raise NotImplementedError(self.mlir_dtype)
      return FragmentedArray.splat(
          op(reg, utils.c(math.prod(self.shape), self.mlir_dtype)),
          (),
          is_signed=self.is_signed,
      )

    if not isinstance(self.layout, WGStridedFragLayout):
      raise NotImplementedError(f"Unsupported layout {self.layout}")

    if scratch is None:
      raise ValueError("scratch must be provided")

    if ir.FloatType.isinstance(self.mlir_dtype):
      op = addf
    elif ir.IntegerType.isinstance(self.mlir_dtype):
      op = arith.addi
    else:
      raise NotImplementedError(self.mlir_dtype)

    result = c(0, self.mlir_dtype)
    for reg in self.registers:
      result = op(
          result,
          vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
      )
    scratch_ty = ir.MemRefType(scratch.type)
    if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
      raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")

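    # The reduction runs in three steps: each warp tree-reduces its own partial
    # sum, the four per-warp partials are staged in scratch, and a single
    # thread adds them up and writes the result to scratch[0], from which all
    # threads read it back.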
    index = ir.IndexType.get()
    warp_result = utils.warp_tree_reduce(result, op, 32)
    warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
    memref.store(warp_result, scratch, [warp_id])
    utils.warpgroup_barrier()
    zero_index = c(0, index)
    with mgpu.single_thread(per_block=False):
      scratch_vec = vector.load(
          ir.VectorType.get((4,), self.mlir_dtype),
          scratch,
          [zero_index],
      )
      scratch_sum = vector.reduction(
          self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec
      )
      memref.store(scratch_sum, scratch, [zero_index])
    utils.warpgroup_barrier()
    result = memref.load(scratch, [zero_index])
    utils.warpgroup_barrier()  # Make sure everyone is done using scratch.
    return FragmentedArray.splat(result, (), is_signed=self.is_signed)

  def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis):
    if isinstance(op, str):
      match op:
        case "add":
          if ir.FloatType.isinstance(self.mlir_dtype):
            op = addf
          elif ir.IntegerType.isinstance(self.mlir_dtype):
            op = arith.addi
          else:
            raise NotImplementedError(self.mlir_dtype)
        case "max":
          if ir.F32Type.isinstance(self.mlir_dtype):
            op = self._lift_fast_instr("max.NaN.f32")
          elif ir.FloatType.isinstance(self.mlir_dtype):
            op = arith.maximumf
          elif ir.IntegerType.isinstance(self.mlir_dtype):
            op = arith.maxsi if self.is_signed else arith.maxui
          else:
            raise NotImplementedError(self.mlir_dtype)
        case _:
          raise ValueError(f"Unrecognized reduction operator: {op}")
    if self.layout != WGMMA_LAYOUT:
      raise NotImplementedError(self.layout)
    if axis != 1:
      raise NotImplementedError
    index = ir.IndexType.get()
    i32 = ir.IntegerType.get_signless(32)
    row_tile_dim = self.registers.shape[0]
    row_subtile_dim = self.registers.shape[4]
    new_regs = np.empty((row_tile_dim, row_subtile_dim), dtype=object)
    assert self.registers.shape[-1] == 1
    for row_tile, row_subtile in np.ndindex(new_regs.shape):
      # Reduce the registers owned by the current thread over n tiles
      reg_index = [0] * self.registers.ndim
      reg_index[0] = row_tile
      reg_index[4] = row_subtile
      thread_result_vec = self.registers[tuple(reg_index)]
      for n_tile in range(1, self.registers.shape[1]):
        reg_index[1] = n_tile
        thread_result_vec = op(
            thread_result_vec, self.registers[tuple(reg_index)]
        )

      thread_result = vector.extractelement(thread_result_vec, position=c(0, index))
      for i in range(1, self.layout.vector_length):
        thread_result = op(
            thread_result,
            vector.extractelement(thread_result_vec, position=c(i, index)),
        )

      # Do a shuffle to reduce in groups of 4 consecutive threads.
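      # In WGMMA_LAYOUT each row is co-owned by 4 consecutive lanes, so two
      # butterfly shuffles (lane xor 1, then xor 2) leave every one of those
      # lanes holding the reduction of the full row.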
      result = thread_result
      for i in (1, 2):
        other_result = nvvm.shfl_sync(
            result.type,
            c(0xFFFFFFFF, i32),
            result,
            c(i, i32),
            c(0x1F, i32),
            nvvm.ShflKind.bfly,
        )
        result = op(result, other_result)
      new_regs[row_tile, row_subtile] = result
    return FragmentedArray(
        _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed
    )

  def broadcast(self, shape):
    if not isinstance(self.layout, WGSplatFragLayout):
      raise NotImplementedError(self.layout)

    if self.shape == shape:
      return self

    if not self.layout.can_broadcast_to(shape):
      raise ValueError(f"Can't broadcast {self.shape} to {shape}")

    return FragmentedArray(
        _registers=self.registers,
        _layout=WGSplatFragLayout(shape),
        _is_signed=self.is_signed,
    )

  def reshape(self, shape):
    if self.shape == shape:
      return self
    if math.prod(shape) != math.prod(self.shape):
      raise ValueError(f"Can't reshape {self.shape} to {shape}")

    match self.layout:
      case WGSplatFragLayout() | WGStridedFragLayout():
        new_layout = dataclasses.replace(self.layout, shape=shape)
      case _:
        raise NotImplementedError(self.layout)

    return FragmentedArray(
        _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed
    )

  def broadcast_minor(self, n):
    if self.layout != WGMMA_ROW_LAYOUT:
      raise NotImplementedError
    if n % 8:
      raise ValueError("Number of columns must be divisible by 8")
    reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n))
    new_regs = np.empty(reg_shape, dtype=object)
    dtype = self.mlir_dtype
    for (row_tile, row_subtile), reg in np.ndenumerate(self.registers):
      tile = [slice(None)] * len(new_regs.shape)
      tile[0] = row_tile
      tile[4] = row_subtile
      new_regs[tuple(tile)] = vector.splat(
          ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), reg
      )
    return FragmentedArray(
        _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed
    )
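  # For example (illustrative): a WGMMA_ROW_LAYOUT array of shape (64,) becomes
  # a WGMMA_LAYOUT array of shape (64, 128) after broadcast_minor(128), with
  # every element of a row equal to that row's original value.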

  def select(self, on_true, on_false):
    if (
        not ir.IntegerType.isinstance(self.mlir_dtype)
        or ir.IntegerType(self.mlir_dtype).width != 1
    ):
      raise NotImplementedError
    # We change the receiver here, because the return type is defined by
    # `on_true` and `on_false` and not the predicate `self`.
    return on_true._pointwise(
        lambda t, p, f: arith.select(p, t, f), self, on_false,
    )

  def foreach(
      self,
      fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None],
      *,
      create_array=False,
      is_signed=None,
  ):
    """Call a function for each value and index."""
    index = ir.IndexType.get()
    new_regs = None
    if create_array:
      new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type))
    for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True):
      reg = self.registers[reg_idx]
      assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape)
      [elems] = ir.VectorType(reg.type).shape
      for i in range(elems):
        i = c(i, index)
        val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i)))
        if create_array:
          new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i)

    if create_array:
      return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
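  # `debug_print` below is a minimal user of `foreach`. An illustrative sketch
  # of building a new array from an existing float array `arr` (hypothetical
  # variable names):
  #
  #   plus_one = arr.foreach(
  #       lambda v, idx: arith.addf(v, c(1.0, arr.mlir_dtype)),
  #       create_array=True,
  #   )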

  def debug_print(self, fmt: str):
    idx_fmt = ", ".join(["{}"] * len(self.shape))
    @self.foreach
    def _(val, idx):
      fmt_str = fmt.format(f"[{idx_fmt}]: {{}}")
      utils.debug_print(fmt_str, *idx, val, uniform=False)

  def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
    if not ir.MemRefType.isinstance(ref.type):
      raise ValueError(ref)

    def vs_unsupported():
      if not vector_store:
        raise NotImplementedError(
            f"Can't use non-vector stores with layout {self.layout}"
        )

    match self.layout:
      case WGSplatFragLayout():
        vs_unsupported()
        self._store_untiled_splat(ref)
      case WGStridedFragLayout():
        vs_unsupported()
        self._store_untiled_wg_strided(ref)
      case TiledLayout():
        self._store_untiled_tiled(ref, vector_store=vector_store)
      case _:
        raise NotImplementedError(self.layout)

  def _store_untiled_splat(self, ref: ir.Value):
    vec_size = 64 // mgpu.bitwidth(self.mlir_dtype)
    if np.prod(self.shape) < vec_size * WARPGROUP_SIZE:
      vec_size = 1

    if np.prod(self.shape) % (WARPGROUP_SIZE * vec_size):
      raise ValueError(self.shape, WARPGROUP_SIZE, vec_size)

    fa = FragmentedArray.splat(
        self.registers.flat[0],
        self.shape,
        layout=WGStridedFragLayout(shape=self.shape, vec_size=vec_size),
        is_signed=self.is_signed,
    )
    fa.store_untiled(ref)

  def _store_untiled_wg_strided(self, ref: ir.Value):
    ref_ty = ir.MemRefType(ref.type)
    try:
      # Flattening the reference potentially produces simpler PTX but
      # if the ref is not already 1D and has strided dimensions
      # flattening won't work. We use a different variable for ref in
      # case `NotImplementedError` is thrown by
      # .linear_thread_idxs().
      ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
      idxs = ([i] for i in self.layout.linear_thread_idxs())
    except NotImplementedError:
      ref_ = ref
      idxs = self.layout.thread_idxs()
    ref_shape = tuple(ref_ty.shape)
    if ref_shape != self.shape:
      raise ValueError((ref_shape, self.shape))
    for idx, reg in zip(idxs, self.registers.flat):
      vector.store(reg, ref_, idx)

  def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
    """Stores an array with a tiled layout. Not optimized at the moment."""
    if utils.bitwidth(self.mlir_dtype) < 8:
      raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})")
    i32 = ir.IntegerType.get_signless(32)
    layout = self.layout
    assert isinstance(layout, TiledLayout)
    ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
    if vector_store and ref_strides[layout.vector_dim] != 1:
      raise NotImplementedError(
          "Can't use vector stores with non-unit minormost stride"
      )
    strides = layout.tiling.tile_strides(ref_strides)
    smem_space = ir.Attribute.parse("#gpu.address_space<workgroup>")
    ref_space = ir.MemRefType(ref.type).memory_space
    memory_space = None
    if str(ref_space) == str(smem_space):
      memory_space = 3
    elif ref_space:
      raise NotImplementedError(f"Unexpected ref space {ref_space}")
    ptr = utils.memref_ptr(ref, memory_space=memory_space)
    # Fold warp and lane offsets into the pointer once, since they are dynamic.
    dyn_strides = [
        arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :]
    ]
    warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides)
    lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides)
    dyn_offset = arith.addi(warp_offset, lane_offset)
    ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
    # All warp tile offsets are static and can be fused into the store.
    for tile_idx, reg in np.ndenumerate(self.registers):
      if vector_store:
        elems = [reg]
      else:
        index = ir.IndexType.get()
        elems = [
            vector.extractelement(reg, position=c(i, index))
            for i in range(ir.VectorType(reg.type).shape[0])
        ]
      for i, e in enumerate(elems):
        tile_idx_local = list(tile_idx)
        tile_idx_local[layout.vector_dim] += i
        tile_idx_local = list(tile_idx_local)
        lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True))
        reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
        llvm.store(e, reg_ptr)

  def store_tiled(self, ref, swizzle: int | None):
    if not isinstance(self.layout, TiledLayout):
      raise NotImplementedError(self.layout)
    layout, shape = self.layout, self.shape
    for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape):
      llvm.store(get(self.registers), ptr)

  @classmethod
  def load_tiled(
      cls,
      ref,
      swizzle: int | None,
      *,
      is_signed: bool | None = None,
      layout: FragmentedLayout = WGMMA_LAYOUT,
  ):
    ref_ty = ir.MemRefType(ref.type)
    dtype = ref_ty.element_type
    match layout:
      case TiledLayout():
        ref_ty = ir.MemRefType(ref.type)
        tiled_shape = ref_ty.shape
        if len(tiled_shape) % 2:
          raise ValueError("Tiled reference must have even rank")
        tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],))
        shape = tiling.untile_shape(tiled_shape)
        zero = (
            vector.splat(
                ir.VectorType.get((layout.vector_length,), dtype), c(0, dtype)
            ),
        )
        registers = np.full(layout.registers_shape(shape), zero, dtype=object)
        reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type)
        for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape):
          update(registers, llvm.load(reg_ty, ptr))
      case _:
        raise NotImplementedError(layout)
    return cls(_registers=registers, _layout=layout, _is_signed=is_signed)

  @staticmethod
  def transfer_tiled(shape, dtype, swizzle: int | None):
    # TODO(apaszke): We could use ldmatrix/stmatrix for 16-bit types.
    bw = mgpu.bitwidth(dtype)
    m, n = shape
    assert m % 64 == 0 and n % 8 == 0  # Implied by the layout.
    cols_per_tile = swizzle_elems = (swizzle * 8) // bw
    if n < swizzle_elems:
      cols_per_tile = n
    else:
      assert n % swizzle_elems == 0, (n, swizzle_elems)
    if swizzle not in {32, 64, 128}:
      raise NotImplementedError("Only swizzled stores supported")

    c = arith.ConstantOp.create_index
    tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
    lane_id = arith.remui(tidx, c(32))  # {0, 1, ..., 31}
    warp_id = arith.divui(tidx, c(32))  # {0, 1, 2, 3}
    sub_row_base = arith.divui(lane_id, c(4))  # {0, 1, ..., 7}
    if bw > 16:  # Stagger is only necessary for values larger than 16bit.
      # We split the rows into two groups (left/right) and change the order in
      # which they perform accesses to avoid bank conflicts.
      # It seems that the STS.64 is 2x faster (and the hardware reports no
      # conflicts) when the conflicts are split between half-warps, as
      # opposed to having them within the half-warp. This requires a
      # little more work for the selects, but is ultimately worth it.
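      # For example, with swizzle=128 the stagger amount below is 2: "left"
      # rows visit the 8-column sub-tiles in order 0, 1, 2, 3, ..., while
      # "right" rows visit them in order 2, 3, 0, 1, ... (col_subidx ^ 2).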
      match swizzle:
        case 128:
          is_stagger_left = arith.cmpi(
              arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0)
          )
        case 64:
          is_stagger_left = arith.cmpi(
              arith.CmpIPredicate.eq,
              arith.remui(arith.divui(sub_row_base, c(2)), c(2)),
              c(0),
          )
        case 32:
          # 32-byte tiles of 4-byte types have only 8 columns so there is no way
          # to stagger the memory accesses within a single tile. We could do it
          # across tiles, but that would be a completely different scheme.
          raise NotImplementedError
        case _:
          raise AssertionError(swizzle)
      stagger_amount = swizzle // 64
      if (cols_per_tile // 8) % (stagger_amount * 2):
        raise NotImplementedError
    else:
      # We rely on canonicalization to clean up the selects.
      i1 = ir.IntegerType.get_signless(1)
      is_stagger_left = arith.constant(i1, ir.BoolAttr.get(True))
      stagger_amount = 0
    row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
    col_base = arith.muli(arith.remui(lane_id, c(4)), c(2))  # {0, 2, 4, 6}
    # The swizzle pattern is constant for a given thread.
    col_swizzle_bits = arith.muli(
        arith.divui(sub_row_base, c(128 // swizzle)), c(128 // bw),
    )
    for row_group in range(m // 64):
      for col_group in range(n // cols_per_tile):
        for row_subidx in range(2):
          row = arith.addi(row_base, c(row_subidx * 8))
          for col_subidx in range(cols_per_tile // 8):
            col_subidx_left = col_subidx
            col_subidx_right = col_subidx ^ stagger_amount
            col_off = arith.select(
                is_stagger_left, c(col_subidx_left * 8), c(col_subidx_right * 8)
            )
            col = arith.addi(col_base, col_off)
            col = arith.xori(col, col_swizzle_bits)
            reg_idx_left = col_subidx_left + col_group * (cols_per_tile // 8)
            reg_idx_right = col_subidx_right + col_group * (cols_per_tile // 8)
            left_idx = row_group, reg_idx_left, row_subidx, 0
            right_idx = row_group, reg_idx_right, row_subidx, 0
            idx = c(row_group), c(col_group), row, col
            def get_register(regs, left_idx=left_idx, right_idx=right_idx):
              value_left = regs[left_idx]
              value_right = regs[right_idx]
              return arith.select(is_stagger_left, value_left, value_right)
            def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx):
              regs[left_idx] = arith.select(is_stagger_left, new, regs[left_idx])
              regs[right_idx] = arith.select(is_stagger_left, regs[right_idx], new)
            yield get_register, update_registers, idx

  @staticmethod
  def transfer_tiled2(
      ref: ir.Value,
      swizzle: int | None,
      layout: TiledLayout,
      shape: tuple[int, ...],
  ):
    """Generate a transfer schedule for a tiled layout.

    Given a ref with one level of tiling applied to it (we assume all dimensions
    have been tiled), this function generates an iterable describing a good
    schedule for swizzled SMEM loads/stores.

    At each step, the iterable yields a tuple of three values:
    * a function that takes a register array and returns the register to be
      stored at the current address
    * a function that takes a register array and a register loaded from the
      current address, and updates the register array with that register
    * the current address for load/store instructions
    """
    # TODO(apaszke): Use ldmatrix/stmatrix when possible.
    c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x)
    tiling = layout.tiling

    ref_ty = ir.MemRefType(ref.type)
    dtype = ref_ty.element_type
    if ref_ty.rank % 2:
      raise ValueError("Tiled reference must have even rank")
    ref_logical_rank = ref_ty.rank // 2
    ref_tiling_shape = tuple(ref_ty.shape[ref_logical_rank:])
    ref_tiling = Tiling((ref_tiling_shape,))
    ref_strides, _ = ref_ty.get_strides_and_offset()
    if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape:
      raise ValueError()
    nested_ref_shape = tuple(
        (ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank])
        for i in range(ref_logical_rank)
    )
    nested_ref_strides = tuple(
        (ref_strides[i], ref_strides[i + ref_logical_rank])
        for i in range(ref_logical_rank)
    )
    tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides(
        nested_ref_shape, nested_ref_strides
    )

    # We could technically handle this case, but it would be quite complicated.
    # If tiling dimensions would have to be expanded into multiple, we'd have to
    # adjust the dimension indices in layouts, including expanding some of them
    # into multiple indices. Note that for non-tiling dims, we allow the shape
    # to be arbitrary, which is why we fix it up below in mem_idx_to_reg_idx.
    if any(
        len(dim_shape) != 1 for dim_shape in tiled_nested_shape[-layout.tiled_tiling_rank :]
    ):
      raise NotImplementedError("Memory and register tiling incompatible")
    tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape))
    elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides))
    elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims]
    lane_shape = [tiled_shape[d] for d in layout.lane_dims]
    if elem_tiled_strides[layout.vector_dim] != 1:
      raise ValueError("Stride of the vectorized dimension should be 1")
    for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
      tiled_shape[d] = 1

    element_bits = mgpu.bitwidth(dtype)
    if (layout.vector_length * element_bits) % 8 != 0:
      raise ValueError(
          f"Vector length ({layout.vector_length}) must span a whole number of"
          f" bytes, but has {layout.vector_length * element_bits} bits"
      )
    transfer_bytes = (layout.vector_length * element_bits) // 8
    # Not sure if this is strictly required for all data types, but it certainly
    # is for sub-byte types (else we might not increment the pointer by whole bytes).
    if any(
        s % layout.vector_length and i != layout.vector_dim and d != 1
        for i, (s, d) in enumerate_negative(
            list(zip(elem_tiled_strides, tiled_shape))
        )
    ):
      raise ValueError(
          "Tiled strides must be a multiple of the vector length, except for the"
          " vector dimension"
      )

    if swizzle not in {16, 32, 64, 128}:
      raise ValueError("Only swizzled transfers supported")
    # We will be computing the offsets in units of vectors, not elements,
    # to better support sub-byte types.
    swizzle_tile_transfers = 16 // transfer_bytes
    swizzle_group_transfers = 128 // transfer_bytes
    swizzle_groups_per_block = swizzle // 16
    swizzle_block_transfers = swizzle_groups_per_block * swizzle_group_transfers
    # Technically we should keep the vector_dim set to 1, but its shape is 1
    # so it does not matter.
    transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides]
    transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype)

    plan = plan_tiled_transfer(
        tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout,
        element_bits, swizzle
    )

    # All offsets are in units of transfer_dtype.
    dyn_tiled_strides = [
        c(s) for s in transfer_tiled_strides[-layout.tiled_tiling_rank :]
    ]
    lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides)
    warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides)
    dyn_offset = arith.addi(lane_offset, warp_offset)
    if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
      raise ValueError("Tiled stores can only be performed into SMEM")
    ptr = utils.memref_ptr(ref, memory_space=3)
    _as_consts = lambda consts: [c(const) for const in consts.tolist()]
    # This has bits set only for the offset bits that influence swizzling.
    swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers
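    # Worked example (illustrative), assuming 16-byte transfers and swizzle=128:
    # swizzle_tile_transfers == 1, swizzle_group_transfers == 8 and
    # swizzle_groups_per_block == 8, so swizzle_mask == 63 keeps the part of the
    # offset that falls within a 1024-byte swizzle block, and the XOR below
    # combines the 16-byte tile index within each 128-byte group with the group
    # index.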
    for tile_idx in np.ndindex(*tiled_shape):
      indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms])
      const_offset = np.dot(indices, transfer_tiled_strides)
      # We split the offset into a part that interacts with swizzling and a
      # part that doesn't. This lets us generate better code because constant
      # offsets can be fused into load and store instructions.
      const_offset_swizzle = const_offset & swizzle_mask
      const_offset_no_swizzle = const_offset - const_offset_swizzle
      offset_pre_swizzle = arith.addi(
          dyn_offset, plan.select(_as_consts(const_offset_swizzle))
      )
      swizzle_group = arith.remui(
          arith.divui(offset_pre_swizzle, c(swizzle_group_transfers)),
          c(swizzle_groups_per_block),
      )
      swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_transfers))
      offset = arith.xori(offset_pre_swizzle, swizzle_bits)
      reg_ptr = utils.getelementptr(ptr, [offset], transfer_dtype)
      offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle))
      reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], transfer_dtype)
      # Here, registers are organized in an array with shape obtained by tiling
      # the logical data bounds. But, the reference was tiled and so each
      # logical tiled dimension can map to multiple dims in tiled_shape.
      # The transform below maps this potentially higher-rank representation
      # back to the lower-rank representation used by the register arrays.
      def mem_idx_to_reg_idx(idx):
        reg_tiled_idx = []
        base_idx = 0
        for dim_shape in tiled_nested_shape[:ref_logical_rank]:
          dim_strides = utils.get_contiguous_strides(dim_shape)
          dim_idxs = idx[base_idx:base_idx + len(dim_shape)]
          base_idx += len(dim_shape)
          reg_tiled_idx.append(sum(i * s for i, s in zip(dim_idxs, dim_strides)))
        # We should have fixed up all but the tiling dims.
        assert base_idx == len(idx) - layout.tiled_tiling_rank
        return (*reg_tiled_idx, *idx[base_idx:])
      reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()]
      def get_register(regs, reg_idxs=reg_idxs):
        return plan.select([regs[reg_idx] for reg_idx in reg_idxs])
      def update_registers(regs, new, reg_idxs=reg_idxs):
        # TODO(apaszke): If the staggering forms a permutation with a small
        # cycle length, then instead of blending at each step we could construct
        # a small routing network (kind of like a sorting network) to fix up
        # each cycle separately after all the loads are performed.
        # This would be especially useful for dims that are powers of two and
        # staggered by another power of 2, since all cycles are of length 2 (and
        # we could save half the selects).
        for i, reg_idx in enumerate(reg_idxs):
          regs[reg_idx] = plan.select_if_group(i, regs[reg_idx], new)
      yield get_register, update_registers, reg_ptr

  def tree_flatten(self):
    aux = self.layout, self.registers.shape, self.is_signed
    return list(self.registers.flat), aux

  @classmethod
  def tree_unflatten(cls, aux, flat_registers):
    layout, reg_shape, is_signed = aux
    registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape)
    return cls(_registers=registers, _layout=layout, _is_signed=is_signed)


class TransferPlan(Protocol):
  IndexTransform = Callable[[tuple[int, ...]], tuple[int, ...]]
  tile_index_transforms: tuple[IndexTransform, ...]

  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
    """Selects the value corresponding to the group of the current thread.

    The argument must be of the same length as tile_index_transforms.
    """
    raise NotImplementedError

  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
    """Returns `new` if the current thread belongs to the given group and `old` otherwise.

    group_idx must be between 0 and len(tile_index_transforms) - 1.
    """
    raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class TrivialTransferPlan(TransferPlan):
  @property
  def tile_index_transforms(self):
    return (lambda x: x,)

  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
    assert len(group_elems) == 1
    return group_elems[0]

  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
    assert group_idx == 0
    return new


@dataclasses.dataclass(frozen=True)
class StaggeredTransferPlan(TransferPlan):
  stagger: int
  dim: int
  size: int
  group_pred: ir.Value

  @property
  def tile_index_transforms(self):
    dim = self.dim
    def rotate(idx: tuple[int, ...]) -> tuple[int, ...]:
      return (
          *idx[:dim], (idx[dim] + self.stagger) % self.size, *idx[dim + 1 :],
      )
    return (lambda x: x, rotate)

  def select(self, group_elems: Sequence[ir.Value]) -> ir.Value:
    assert len(group_elems) == 2
    return arith.select(self.group_pred, group_elems[1], group_elems[0])

  def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value:
    assert 0 <= group_idx <= 1
    sides = [old, new] if group_idx == 0 else [new, old]
    return arith.select(self.group_pred, *sides)


def plan_tiled_transfer(
    tiled_shape: Sequence[int],
    tiled_strides: Sequence[int],
    lane_shape: Sequence[int],
    lane_strides: Sequence[int],
    layout: TiledLayout,
    element_bits: int,
    swizzle: int,
) -> TransferPlan:
  i32 = ir.IntegerType.get_signless(32)
  c = lambda x: arith.constant(i32, x)
  # TODO(apaszke): Rewrite this function in terms of transfer_bytes (that we get
  # from the caller).
  swizzle_tile_elems = (16 * 8) // element_bits
  swizzle_group_elems = (128 * 8) // element_bits
  # Should be checked at the call site.
  assert layout.vector_length * element_bits % 8 == 0
  transfer_bytes = (layout.vector_length * element_bits) // 8
  # Below, all calculations are in elements, not in bytes, since it should
  # generalize better to sub-byte types.
  # Here, we verify two conditions:
  # 1. Each vector transfer only accesses addresses that fall within a single
  # swizzle tile (if not we'd need to split it and swizzle parts differently).
  transfer_alignment = math.gcd(*(
      s
      for i, (s, d) in enumerate_negative(list(zip(tiled_strides, tiled_shape)))
      if d > 1 or i in {layout.warp_dim, *layout.lane_dims}
  ))
  if (
      swizzle_tile_elems % transfer_alignment
      and layout.vector_length <= transfer_alignment
  ):
    raise ValueError(
        "Failed to prove that vector transfers don't cross swizzle tile"
        " boundaries. This check is incomplete, and does not guarantee that"
        " this is a user error, but it might be." + str(transfer_alignment)
    )

  # 2. The transfer pattern does not cause bank conflicts.
  # TODO(apaszke): For now, when performing transfers narrower than a bank,
  # we simply narrow each bank to the transfer width. The truth is more likely
  # that bank conflicts only don't occur if the addresses mapping to the same
  # bank are contiguous, but that's a more complicated check to perform.
  if transfer_bytes > SMEM_BANK_BYTES * 4:
    raise NotImplementedError
  if element_bits > SMEM_BANK_BYTES * 8:
    raise NotImplementedError
  smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes)
  num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes)
  elems_per_bank = (smem_bank_bytes * 8) // element_bits
  num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1)
  wavefront_lanes = WARP_SIZE // num_wavefronts
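  # For example, for 16-byte transfers of f32 elements: smem_bank_bytes == 4,
  # num_banks == 32, elems_per_bank == 1, num_wavefronts == 4 and
  # wavefront_lanes == 8, i.e. a warp-wide transfer is modeled as 4 wavefronts
  # of 8 lanes each, and conflicts are checked within each wavefront.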

  lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides)
  def has_bank_conflicts(tile_idx_transform):
    tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape)
    tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1)  # [#tiles, 1, #dims]
    lane_tile_idx = tile_idx_transform(tile_idxs)  # [#tiles, #lanes/1, #dims]
    assert lane_tile_idx.shape[1] in {1, WARP_SIZE}
    lane_tile_offsets = np.dot(lane_tile_idx, tiled_strides)
    offsets = lane_tile_offsets + lane_offsets_in_tile  # [#tiles, #lanes]
    assert offsets.shape[-1] == WARP_SIZE
    swizzle_groups = (offsets // swizzle_group_elems) % (swizzle // 16)
    swizzle_bits = swizzle_groups * swizzle_tile_elems
    lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks
    wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes)
    # Order of threads within the wavefront is unimportant.
    wavefront_banks = np.sort(wavefront_banks, axis=-1)
    # There are no conflicts if each wavefront only contains unique banks.
    return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1])

  # We don't need any special treatment if there are no conflicts when each lane
  # transfers the same tile at a time.
  if not has_bank_conflicts(lambda tile_idx: tile_idx):
    return TrivialTransferPlan()

  # Otherwise, we will try to partition the lanes into two groups and have
  # each group store to a different tile. The only tile dimensions that can help
  # us with bank conflicts are those that have multiple elements and a stride
  # that's not a multiple of the number of banks.
  #
  # Note that the code is set up so that we could also consider partitioning
  # the lanes into more groups, but the selects will become more expensive if
  # we do that. It's a possibility we have if we need it.
  candidate_dims = (
      i for i, (s, d) in enumerate(zip(tiled_strides, tiled_shape))
      if d > 1 and s % (SMEM_BANKS * elems_per_bank)
  )
  for dim in candidate_dims:
    for group_stride in (1, 2, 4, 8, 16):
      # We change the group assignment each group_stride lanes.
      lane_id = np.arange(WARP_SIZE)[:, None]
      lane_group = (lane_id // group_stride) % 2
      # We only consider a transformation where the second group stores to a
      # tile that's a constant offset (modulo dim size) from the first one.
      for stagger in range(1, tiled_shape[dim]):
        offset = np.zeros(len(tiled_shape), np.int64)
        offset[dim] = stagger
        transform = lambda idx: (idx + offset * lane_group) % tiled_shape
        if not has_bank_conflicts(transform):
          # We've found a strategy that avoids bank conflicts!
          lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE))
          group_idx = arith.remui(arith.divui(lane_idx, c(group_stride)), c(2))
          group_pred = arith.cmpi(arith.CmpIPredicate.ne, group_idx, c(0))
          return StaggeredTransferPlan(
              stagger, dim, tiled_shape[dim], group_pred
          )
  raise ValueError(
      "Failed to synthesize a transfer pattern that avoids bank conflicts"
  )

# We allow contractions, to potentially take advantage of FMA instructions.
# They can change the results, but the precision should only increase.
def addf(a: ir.Value, b: ir.Value):
  return arith.addf(a, b, fastmath=arith.FastMathFlags.contract)

def subf(a: ir.Value, b: ir.Value):
  return arith.subf(a, b, fastmath=arith.FastMathFlags.contract)

def mulf(a: ir.Value, b: ir.Value):
  return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract)


def optimization_barrier(*arrays: mgpu.FragmentedArray):
  """Acts as an optimization barrier for LLVM.

  Passing arrays through this function will make sure that they are computed
  before any side-effecting operations that follow this barrier.
  """
  index = ir.IndexType.get()
  i32 = ir.IntegerType.get_signless(32)

  regs = []
  reg_dtypes = []
  reg_constraints = []
  repack_fns = []
  # We unpack each array into a flat list of registers, and prepare the
  # functions that invert the transform in repack_fns.
  for array in arrays:
    reg_ty = array.registers.flat[0].type
    dtype = array.mlir_dtype
    if ir.F32Type.isinstance(dtype):
      if ir.VectorType.isinstance(reg_ty):
        [vec_len] = ir.VectorType(reg_ty).shape
        array_regs = [  # pylint: disable=g-complex-comprehension
            vector.extractelement(reg, position=c(pos, index))
            for reg in array.registers.flat
            for pos in range(vec_len)
        ]
        def _repack(regs, reg_ty=reg_ty):
          reg = llvm.mlir_undef(reg_ty)
          [vec_len] = ir.VectorType(reg_ty).shape
          for i_elem in range(vec_len):
            reg = llvm.insertelement(
                reg, next(regs), arith.constant(i32, i_elem)
            )
          return reg
        repack_fns.append(_repack)
      else:
        array_regs = list(array.registers.flat)
        repack_fns.append(lambda regs: next(regs))
      reg_constraint = "f"
    elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype):
      if not ir.VectorType.isinstance(reg_ty):
        raise NotImplementedError(array.mlir_dtype)
      [vec_len] = ir.VectorType(reg_ty).shape
      if vec_len != 2:
        raise NotImplementedError(vec_len)
      i32_reg_ty = ir.VectorType.get((1,), i32)
      array_regs = [
          vector.extractelement(
              vector.bitcast(i32_reg_ty, reg), position=c(0, index)
          )
          for reg in array.registers.flat
      ]
      reg_constraint = "r"
      def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty):
        return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs)))
      repack_fns.append(_repack)
    else:
      raise NotImplementedError(array.mlir_dtype)
    regs += array_regs
    reg_dtypes += [array_regs[0].type] * len(array_regs)
    reg_constraints += [reg_constraint] * len(array_regs)
  ptx_lines = [
      f"mov.b32 ${i}, ${len(reg_constraints)+i}"
      for i in range(len(reg_constraints))
  ]
  ptx = ";\n\t".join(ptx_lines) + ";"
  all_reg_constraints = ",".join(
      [*("=" + c for c in reg_constraints), *reg_constraints]
  )
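  # For example, for two f32 registers the assembly is
  #   "mov.b32 $0, $2;\n\tmov.b32 $1, $3;"
  # with constraints "=f,=f,f,f": a register-to-register identity copy that
  # LLVM has to treat as opaque, which keeps the values materialized.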
  struct_ty = ir.Type.parse(
      f"!llvm.struct<({','.join(map(str, reg_dtypes))})>"
  )
  result_struct = llvm.inline_asm(
      struct_ty, regs, ptx, all_reg_constraints,
      asm_dialect=0, has_side_effects=True,
  )
  regs = [
      llvm.extractvalue(dtype, result_struct, [i])
      for i, dtype in enumerate(reg_dtypes)
  ]
  i32 = ir.IntegerType.get_signless(32)
  results = []
  regs_it = iter(regs)
  for array, repack_fn in zip(arrays, repack_fns, strict=True):
    num_regs = array.registers.size
    reg_ty = array.registers.flat[0].type
    if ir.VectorType.isinstance(reg_ty):
      reg_ty = ir.VectorType(reg_ty)
    new_registers = np.empty((num_regs,), dtype=object)
    for i_vreg in range(num_regs):
      reg = repack_fn(regs_it)
      assert reg.type == reg_ty, (reg.type, reg_ty)
      new_registers[i_vreg] = reg
    results.append(
        FragmentedArray(
            _registers=new_registers.reshape(array.registers.shape),
            _layout=array.layout,
            _is_signed=array.is_signed,
        )
    )
  return results[0] if len(arrays) == 1 else results