# Copyright 2024 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utilities for code generator.""" from __future__ import annotations import dataclasses import functools import math from collections.abc import Callable from typing import Iterable, Protocol, Sequence, TypeVar import itertools import jax from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import math as mlir_math from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector import numpy as np import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors T = TypeVar("T") WARPGROUP_SIZE = utils.WARPGROUP_SIZE WARP_SIZE = 32 WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE SMEM_BANKS = 32 SMEM_BANK_BYTES = 4 c = utils.c @dataclasses.dataclass(frozen=True) class Tiling: """A tiling expression describing a permutation of elements of an nd-array. To apply one level of tiling to an array, each of the trailing dimensions (up to the rank of the tile) is unfolded into two dimensions: first equal to the ratio of the dimension size and the tile size, and second equal to the tile size. Then, all newly unfolded minor dimensions are transposed to appear at the end. 
This expression describes multi-level tiling, by applying each element of `tiles` in sequence to the array. See https://openxla.org/xla/tiled_layout for a more detailed explanation. """ tiles: tuple[tuple[int, ...], ...] def __post_init__(self): if not self.tiles: return tiled_rank = len(self.tiles[0]) for tile in self.tiles: if len(tile) > tiled_rank: raise ValueError("Only the first tile can refer to value dimensions") if not tile: raise ValueError("Tiles must not be empty") if any(d <= 0 for d in tile): raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") tiled_rank += len(tile) def __str__(self): return f"Tiling({''.join(map(str, self.tiles))})" def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Computes the shape of an array after tiling.""" def fail(): raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") for tile in self.tiles: if len(tile) > len(shape): fail() untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):] if any(s % t != 0 for s, t in zip(tiled_dims, tile)): fail() shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile) return shape def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Computes the shape of an array before tiling from its tiled shape.""" def fail(): raise ValueError( f"shape {shape} is not a valid result of applying tiling {self}." 
) for tile in reversed(self.tiles): if len(tile) > len(shape): fail() untiled_dims = shape[:-2 * len(tile)] tiled_dims = shape[-2 * len(tile):-len(tile)] tiling_dims = shape[-len(tile):] if tiling_dims != tile: fail() shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) return shape def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: """Computes the strides of an array after tiling.""" for tile in self.tiles: untiled, tiled = strides[:-len(tile)], strides[-len(tile):] strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) return strides def tile_nested_shape_strides( self, shape: tuple[tuple[int, ...], ...], strides: tuple[tuple[int, ...], ...], ) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]: """A fused version of `tile_shape` and `tile_strides` for nested shapes. By nested shape we mean that each logical dimension (i.e. each element of shape/strides) is actually composed out of multiple physical dimensions. For example, a row-major array of logical shape (128, 128) that is tiled into (64, 64) tiles would have a nested shape ((2, 64), (2, 64)) (i.e. each dim is split into two sub-dims) and nested strides of ((2 * 64 * 64, 64), (64 * 64, 1)). """ if len(shape) != len(strides): raise ValueError( f"Shape {shape} and strides {strides} must have the same length" ) def fail_if(cond, shape=shape): # Capture shape now. 
if cond: raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") for tile in self.tiles: fail_if(len(tile) > len(shape)) untiled_shape, tiled_shape = shape[:-len(tile)], shape[-len(tile):] untiled_strides, tiled_strides = strides[:-len(tile)], strides[-len(tile):] major_dim_shapes, major_dim_strides = [], [] minor_dim_shapes, minor_dim_strides = [], [] for t, dim_shape, dim_strides in zip(tile, tiled_shape, tiled_strides): major_dim_shape_rev, major_dim_stride_rev = [], [] minor_dim_shape_rev, minor_dim_stride_rev = [], [] for d, s in zip(reversed(dim_shape), reversed(dim_strides), strict=True): if d < t: # We will need to tile more dims fail_if(t % d != 0) t //= d minor_dim_shape_rev.append(d) minor_dim_stride_rev.append(s) elif t != 1: # Last dim to tile! fail_if(d % t != 0) minor_dim_shape_rev.append(t) minor_dim_stride_rev.append(s) if d != t: # No need to insert singleton dims. major_dim_shape_rev.append(d // t) major_dim_stride_rev.append(s * t) t = 1 else: # Done tiling! 
major_dim_shape_rev.append(d) major_dim_stride_rev.append(s) fail_if(t != 1) major_dim_shapes.append(major_dim_shape_rev[::-1]) minor_dim_shapes.append(minor_dim_shape_rev[::-1]) major_dim_strides.append(major_dim_stride_rev[::-1]) minor_dim_strides.append(minor_dim_stride_rev[::-1]) shape = (*untiled_shape, *major_dim_shapes, *minor_dim_shapes) strides = (*untiled_strides, *major_dim_strides, *minor_dim_strides) return ( tuple(tuple(d) if d else (1,) for d in shape), tuple(tuple(d) if d else (1,) for d in strides), ) def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: for tile in self.tiles: untiled, tiled = indices[:-len(tile)], indices[-len(tile):] indices = ( *untiled, *(i // t for i, t in zip(tiled, tile)), *(i % t for i, t in zip(tiled, tile)), ) return indices def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: for tile in reversed(self.tiles): untiled = indices[:-2 * len(tile)] outer = indices[-2 * len(tile):-len(tile)] inner = indices[-len(tile):] indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile))) return indices def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: """Like built-in enumerate, but returns negative indices into the sequence.""" offset = len(elems) for i, e in enumerate(elems): yield i - offset, e @dataclasses.dataclass(frozen=True) class TiledLayout: """A FragmentedArray layout derived from a tiling expression. A logical array is transformed according to the tiling expression, and then split across warps (within a warpgroup), lanes, and vectorized according to the dimension indices. All dimension indices must be negative and should refer to the dimensions after tiling is applied. Note that warp_dim and vector_dim could be sets as well, but we don't have a usecase for that yet. To better understand this layout, consider the example of WGMMA-related tiling from https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as applied to a 128x128 array. 
The corresponding TiledLayout has a tiling of: (64, 8)(16, 8)(8, 8)(1, 2) and warp_dim=-8, lane_dims=(-4, -3), vector_dim=-1. We begin by applying the tiling (note that it always applies to a suffix): Tiled shape Remaining tiling actions =========================================================================== 128 128 (64, 8)(16, 8)(8, 8)(1, 2) 2 16 64 8 (16, 8)(8, 8)(1, 2) 2 16 4 1 16 8 (8, 8)(1, 2) 2 16 4 1 2 1 8 8 (1, 2) 2 16 4 1 2 1 8 4 1 2 The last expression is our final shape. At this stage, we're ready to interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the end is partitioned over 4 warps in a warpgroup (and so it must be of size 4). lane_dims=(-4, -3) indicate that those two dimensions are partitioned over the lanes within a warp (their product must be equal to 32, i.e. warp size). Finally, vector_dim=-1 indicates that each (logical) register is a vector containing 2 elements (there are no shape restrictions here). Given the above, the shape of the (logical) register array used to represent the array in each thread is: (2, 16, 1, 1, 2, 1, 1, 1, 1, 1). We have set all the dimensions above to 1, since each thread is a member of a single warp, a single lane, and the elements along the vectorized dimension are represented by a single (logical) register. """ tiling: Tiling warp_dim: int lane_dims: tuple[int, ...] 
# major-to-minor vector_dim: int def __post_init__(self): if not self.tiling.tiles: raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} if len(dims_set) != len(self.lane_dims) + 2: raise ValueError for d in dims_set: if d >= 0: raise ValueError("All dimensions must be negative") if d < -(len(min_tiled_shape) - len(min_shape)): raise ValueError("Dimension out of range") if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: raise ValueError if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: raise ValueError def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]: # We first find the linear index and then divide by the shape to # get the index. i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() contig_strides = utils.get_contiguous_strides(shape) tile_strides = self.tiling.tile_strides(contig_strides) dyn_tile_strides = [c(s, i32) for s in tile_strides[-self.tiled_tiling_rank:]] warp_offset = utils.dyn_dot(self.warp_indices(), dyn_tile_strides) lane_offset = utils.dyn_dot(self.lane_indices(), dyn_tile_strides) dyn_offset = arith.addi(warp_offset, lane_offset) register_shape = self.registers_shape(shape) for tile_idx in np.ndindex(register_shape): tile_lin_idx = sum(i * s for i, s in zip(tile_idx, tile_strides)) dyn_lin_idx = arith.addi(dyn_offset, c(tile_lin_idx, i32)) idx = [] for stride in contig_strides: idx.append(arith.index_castui(index, arith.divui(dyn_lin_idx, c(stride, i32)))) dyn_lin_idx = arith.remui(dyn_lin_idx, c(stride, i32)) yield tuple(idx) @property def base_tile_shape(self) -> int: """The shape of the first tile in the tiling expression. This tile acts as the divisibility constraint for a suffix of arrays to which this layout applies. 
""" return self.tiling.tiles[0] @functools.cached_property def tiled_tiling_shape(self) -> tuple[int, ...]: """The shape of the suffix of the array after tiling. We only allow our repeated tiling actions to further subdivide the dimensions created by previous tiling actions (except for the first one), so the tiled shape always ends with this suffix, no matter what array shape it's applied to. """ base_tile_shape = self.base_tile_shape return self.tiling.tile_shape(base_tile_shape)[len(base_tile_shape):] @functools.cached_property def tiled_tiling_rank(self) -> int: return len(self.tiled_tiling_shape) @property def vector_length(self) -> int: return self.tiled_tiling_shape[self.vector_dim] def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) tiled_shape[self.warp_dim] = 1 for d in self.lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 return tuple(tiled_shape) def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the logical shape of an array given its register array shape. Inverse to `registers_shape`. 
""" tiled_tiling = self.tiled_tiling_shape shape = list(shape) shape[self.warp_dim] = WARPS_IN_WARPGROUP for d in self.lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] return self.tiling.untile_shape(tuple(shape)) def lane_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape = self.tiled_tiling_shape lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims) assert math.prod(lanes_shape) == WARP_SIZE lane_strides = utils.get_contiguous_strides(lanes_shape) lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) lane_indices = tuple( arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) for stride, size in zip(lane_strides, lanes_shape) ) full_indices = [arith.constant(i32, 0)] * len(tiled_shape) for d, i in zip(self.lane_dims, lane_indices): full_indices[d] = i return tuple(full_indices) def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape_rank = len(self.tiled_tiling_shape) warp_idx = arith.remui( arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), c(WARPS_IN_WARPGROUP, i32), ) indices = [arith.constant(i32, 0)] * tiled_shape_rank indices[self.warp_dim] = warp_idx return tuple(indices) def _tiled_wgmma_layout(shape: tuple[int, ...]): """Returns the tiled layout relevant for WGMMA operations. 
The tiled layout is equivalent to one described here in PTX documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d """ if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") if shape[0] % 64 != 0 or shape[1] % 8 != 0: raise ValueError(f"Shape {shape} is not a multiple of 64x8") return WGMMA_LAYOUT @dataclasses.dataclass(frozen=True) class WGMMARowFragLayout: """[m] matrix, where m % 64 == 0.""" def thread_idxs(self, shape): raise NotImplementedError @dataclasses.dataclass(frozen=True) class WGSplatFragLayout: """A fragmented array where all the values are equal represented as a register per thread. FragmentedArrays in this layout can be are always the result of a splat, each thread in the warpgroup has a single copy of the value, while the FragmentedArray pretends it has whatever shape the user wants. This means we can trivially broadcast, reshape and do elementwise operations with all other layouts. Examples: To load a value in ``` FragmentedArray.splat(memref.load(ref_1d, [1]), (10,20,2)) ``` A shape is always provided for sanity check reasons. """ shape: tuple[int, ...] = () def can_broadcast_to(self, shape) -> bool: """Check that the shape can be broadcast. Only dimensions of size 1 can be broadcast. All other dimensions must be the same as the argument shape. """ return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) def thread_idxs(self, shape): assert shape == self.shape raise NotImplementedError @dataclasses.dataclass(frozen=True) class WGStridedFragLayout: """Convert the array to 1D and then shard across threads.""" shape: tuple[int, ...] 
vec_size: int def __post_init__(self): if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0: raise ValueError((self, WARPGROUP_SIZE)) @classmethod def from_shaped_type(cls, shaped_ty: ir.Type): if not ir.ShapedType.isinstance(shaped_ty): raise TypeError(shaped_ty) shaped_ty = ir.ShapedType(shaped_ty) bw = mgpu.bytewidth(shaped_ty.element_type) assert 8 % bw == 0 and 8 // bw != 0, bw if math.prod(shaped_ty.shape) % WARPGROUP_SIZE != 0: raise ValueError( f"{shaped_ty} must have a number of elements that is a multiple of" f" {WARPGROUP_SIZE} (got {math.prod(shaped_ty.shape)})" ) max_vec_size = np.prod(shaped_ty.shape) // WARPGROUP_SIZE return cls( shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size) ) def thread_idxs(self, shape): assert shape == self.shape index = ir.IndexType.get() for v in self.linear_thread_idxs(): res = [] for dim in reversed(self.shape): dim = c(dim, index) res.append(arith.remui(v, dim)) v = arith.divui(v, dim) res.reverse() yield res def linear_thread_idxs(self): """The indexes to be used for vector load/store WGStridedFragLayout. Yields: The indices of the vector that correspond to the current thread. """ index = ir.IndexType.get() cardinality = np.prod(self.shape) assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0 reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size) tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE, index)) off = arith.muli(tidx, c(self.vec_size, tidx.type)) for i in range(reg_num): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | TiledLayout WGMMA_ROW_LAYOUT = WGMMARowFragLayout() # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d # In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles. 
# Then, we further split the 16x8 tiles into 8x8 submatrices, which are the
# unit of data distributed across a single warp. An 8x8 submatrix has 64
# elements but a warp only has 32 threads, so pairs of adjacent elements along
# the columns are vectorized together. Lanes are assigned to elements like so:
#
#   0  0  1  1  2  2  3  3
#   4  4  5  5  6  6  7  7
#   8  8  9  9 10 10 11 11
#  12 12 13 13 14 14 15 15
#  ...
WGMMA_LAYOUT = TiledLayout(
    tiling=Tiling(
        tiles=(
            (64, 8),
            (16, 8),
            (8, 8),
            (1, 2),
        )
    ),
    warp_dim=-8,
    lane_dims=(-4, -3),
    vector_dim=-1,
)

# This tiled layout resembles WGMMA_LAYOUT, except that the submatrix handed to
# each warp grows from 8x8 to 8x16. Within each submatrix, lanes own elements
# as follows:
#
#   0 0 0 0 2 2 2 2 1 1 1 1 3 3 3 3
#   4 4 4 4 6 6 6 6 5 5 5 5 7 7 7 7
#   ...
#
# The vector length is double that of WGMMA_LAYOUT, which allows 32-bit SMEM
# loads/stores for 8-bit values. Converting to WGMMA_LAYOUT only needs data
# exchanged between lanes whose indices differ in the lowest bit (e.g. 0 and 1,
# 2 and 3), so the conversion costs a single warp shuffle plus permutes local
# to each thread.
WGMMA_LAYOUT_UPCAST_2X = TiledLayout(
    tiling=Tiling(
        tiles=(
            (64, 16),
            (16, 16),
            (8, 16),
            (8,),
            (4,),
        )
    ),
    warp_dim=-8,
    lane_dims=(-4, -2, -3),
    vector_dim=-1,
)

# Intended for upcasting 4-bit elements to 16-bit before feeding them to WGMMA.
# Each warp holds 8x32 core matrices: every one of the 4 threads sharing a row
# keeps 8 elements in one vector. Unlike WGMMA_LAYOUT_UPCAST_2X, columns are
# handed to each group of 4 threads in order (no swapping of positions 1 and 2,
# 5 and 6, etc. as WGMMA_LAYOUT_UPCAST_2X does).
WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
    tiling=Tiling(
        tiles=(
            (64, 32),
            (16, 32),
            (8, 32),
            (8,),
        )
    ),
    warp_dim=-7,
    lane_dims=(-3, -2),
    vector_dim=-1,
)

# This tiled layout is similar to WGMMA_LAYOUT.
There, each warp stores a 8x8 # submatrix in the following way (we only show the first 4 rows for brevity): # # 0 0 1 1 2 2 3 3 # 4 4 5 5 6 6 7 7 # 8 8 9 9 10 10 11 11 # 12 12 13 13 14 14 15 15 # ... # # This tiled layout stores the same 8x8 submatrix in the following way: # # 0 4 1 5 2 6 3 7 # 0 4 1 5 2 6 3 7 # 8 12 9 13 10 14 11 15 # 8 12 9 13 10 14 11 15 # ... # # You can see that we have taken 2x2 submatrices from the above layout and # transposed them. The assigment of lanes to elements is such that in both # layouts the same two lanes map to a single 2x2 submatrix, making the transpose # very cheap (one shuffle and permute suffices to change between those layouts). WGMMA_TRANSPOSED_LAYOUT = TiledLayout( Tiling(((64, 8), (16, 8), (8, 8), (2, 2), (2, 1))), warp_dim=-10, lane_dims=(-6, -3, -5), vector_dim=-2, ) @jax.tree_util.register_pytree_node_class @dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) class FragmentedArray: # An array of ir.Value, see checks in init for shapes. registers: np.ndarray = dataclasses.field(repr=False) layout: FragmentedLayout is_signed: bool | None def __init__( self, *, _registers: np.ndarray, _layout: FragmentedLayout, _is_signed: bool | None, ): """Initializes a fragmented array. This is a low-level API. Prefer using classmethods to construct fragmented arrays instead. """ # We need to use ``object.__setattr__`` here because of ``frozen=True``. 
object.__setattr__(self, "registers", _registers) object.__setattr__(self, "layout", _layout) object.__setattr__(self, "is_signed", _is_signed) if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype): raise TypeError( "is_signed must be non-None if and only if the MLIR type is an" f" integer type, got {_is_signed=} for {self.mlir_dtype}" ) match self.layout: # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout # Each element is a dtype scalar case WGMMARowFragLayout(): if _registers.ndim != 2 or _registers.shape[-1] != 2: raise ValueError(f"Invalid register array shape: {_registers.shape}") # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape if ( math.prod(shape) != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size ): raise ValueError( "Invalid register array shape: math.prod({_registers.shape}) *" " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" ) # Just a single register case WGSplatFragLayout(): if _registers.size != 1: raise ValueError(f"Invalid register array shape: {_registers.shape}") case TiledLayout(): try: self.layout.shape_from_registers_shape(_registers.shape) except ValueError: raise ValueError( "Register array shape does not match the tiled layout" ) from None case _: raise NotImplementedError @classmethod def load_strided( cls, ref: ir.Value, *, is_signed: bool | None = None, vec_size: int | None = None, ): if not ir.MemRefType.isinstance(ref.type): raise TypeError(ref.type) ref_ty = ir.MemRefType(ref.type) shape = tuple(ref_ty.shape) if vec_size is None: layout = WGStridedFragLayout.from_shaped_type(ref_ty) else: layout = WGStridedFragLayout(shape=shape, vec_size=vec_size) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) try: # Flattening the reference potentially produces simpler PTX but # if the ref is not already 1D and has strided dimensions # flattening won't work. 
ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_idxs()] except NotImplementedError: vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: case WGMMARowFragLayout(): if len(shape) != 1: raise ValueError("WGMMARowFragLayout requires a 1D shape") if shape[0] % 64: raise ValueError( "WGMMARowFragLayout requires shape[0] to be a multiple of 64" ) reg_shape = (shape[0] // 64, 2) case WGStridedFragLayout(vec_size=vec_size): assert shape == layout.shape elems = np.prod(shape) reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) case WGSplatFragLayout(): assert shape == layout.shape reg_shape = () case TiledLayout(): value = vector.splat(ir.VectorType.get((layout.vector_length,), value.type), value) reg_shape = layout.registers_shape(shape) case _: raise NotImplementedError(layout) return cls( _registers=np.full(reg_shape, value, dtype=object), _layout=layout, _is_signed=is_signed, ) @property def shape(self): match self.layout: case WGMMARowFragLayout(): row_tiles = self.registers.shape[0] return (row_tiles * 64,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): return shape case TiledLayout(): return self.layout.shape_from_registers_shape(self.registers.shape) case _: raise NotImplementedError @property def mlir_dtype(self): reg_ty = self.registers.flat[0].type match self.layout: case WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type case WGMMARowFragLayout() | WGSplatFragLayout(): return reg_ty case _: raise NotImplementedError def to_layout(self, new_layout: FragmentedLayout): """Converts the fragmented 
array to the given layout. At the moment, only conversions from ``WGSplatFragLayout`` are supported. """ i32 = ir.IntegerType.get_signless(32) c = lambda x: arith.constant(i32, x) if self.layout == new_layout: return self shape = self.shape if ( self.layout == WGMMA_LAYOUT and new_layout == WGMMA_TRANSPOSED_LAYOUT and utils.bitwidth(self.mlir_dtype) == 16 ): is_even_row = arith.cmpi( arith.CmpIPredicate.eq, arith.remui(arith.divui(utils.thread_idx(), c(4)), c(2)), c(0), ) perm = arith.select(is_even_row, c(0x5410), c(0x3276)) new_regs = [] for reg in self.registers.flat: reg_ty = reg.type reg = utils.bitcast(reg, i32) reg_shfl = utils.shfl_bfly(reg, 4) new_reg = utils.prmt(reg, reg_shfl, perm) new_regs.append(utils.bitcast(new_reg, reg_ty)) return FragmentedArray( _registers=np.asarray(new_regs, dtype=object).reshape(new_layout.registers_shape(shape)), _layout=new_layout, _is_signed=self.is_signed, ) if ( self.layout == WGMMA_LAYOUT_UPCAST_2X and new_layout == WGMMA_LAYOUT and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16 ): assert shape[1] % 16 == 0 # Should be implied by the layout new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) is_even = arith.cmpi( arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0) ) registers = self.registers if dtype_bitwidth == 4: if registers.shape[1] % 2: raise NotImplementedError( "This relayout implementation requires an even number of column" " tiles (to pack pairs of them for efficiency)" ) # We pair up the consecutive column tiles, so each register is 32-bit. # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout, # LLVM will realize that the paired up vectors actually came from the # same 32-bit register and it will become a no-op. 
col_minor_registers = np.moveaxis(registers, 1, -1) flat_registers = [ utils.vector_concat((l, h)) for l, h in zip( col_minor_registers.flat[::2], col_minor_registers.flat[1::2] ) ] registers = np.asarray(flat_registers, dtype=object).reshape( *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2 ) registers = np.moveaxis(registers, -1, 1) for idx, reg in np.ndenumerate(registers): if dtype_bitwidth == 16: assert reg.type.shape == [4] # A single vector is 64-bits, but shuffles are only 32-bit wide. # We only shuffle the half that needs to go to other thread. low = utils.vector_slice(reg, slice(0, 2)) high = utils.vector_slice(reg, slice(2, 4)) to_exchange = arith.select(is_even, high, low) # Exchange values between even and odd threads. exchanged = utils.shfl_bfly(to_exchange, 1) low = arith.select(is_even, low, exchanged) high = arith.select(is_even, exchanged, high) new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high elif dtype_bitwidth == 8: assert reg.type.shape == [4] # The vector is 32-bits, so we just shuffle the whole thing and # use prmt to blend it with the local register. exchanged = utils.shfl_bfly(reg, 1) # Consider lanes 0 and 1, because the situation is symmetric for # each pair. If we feed reg[lane] and exchanged[lane] (which is # really the same as reg of the other lane) to prmt, we can index # the elements of the result using the following indices: # reg[0]: 0 1 2 3 reg[1]: 8 9 10 11 # prmt[0]: 0 1 2 3 4 5 6 7 # prmt[1]: 4 5 6 7 0 1 2 3 # The expected outputs and their respective permutations are: # out[0]: 0 1 8 9 out[1]: 2 3 10 11 # prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3 # Note that the patterns still need to be flipped, since we listed # bytes with LSB on the left, which is the opposite of how the # numeric constants are spelled in Python (LSB on the right). 
perm = arith.select(is_even, c(0x5410), c(0x3276)) blend = utils.prmt(reg, exchanged, perm) for i in range(2): reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2)) new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg else: assert dtype_bitwidth == 4 assert reg.type.shape == [8] # We paired up the registers above. exchanged = utils.shfl_bfly(reg, 1) # See comment above for a more complete explanation. # reg[0]: 0 1 2 3 16 17 18 19 reg[1]: 8 9 10 11 24 25 26 27 # prmt[0]: -0- -1- --2-- --3-- -4- --5-- --6-- --7-- # prmt[1]: -4- -5- --6-- --7-- -0- --1-- --2-- --3-- # The expected outputs and their respective permutations are: # out[0]: 0 1 8 9 16 17 24 25 out[1]: 2 3 10 11 18 19 26 27 # prmt[0]: -0- -4- --2-- --6-- prmt[1]: -5- --1-- --7-- --3-- perm = arith.select(is_even, c(0x6240), c(0x3715)) blend = utils.prmt(reg, exchanged, perm) for i in range(4): reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2)) new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg assert all(r is not None for r in new_registers) return FragmentedArray( _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, ) if ( self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT_UPCAST_2X and utils.bitwidth(self.mlir_dtype) == 4 ): assert shape[0] % 64 == 0 # Should be implied by the layout assert shape[1] % 32 == 0 # Should be implied by the layout new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) i32 = ir.IntegerType.get_signless(32) c = lambda x: arith.constant(i32, x) is_01 = arith.cmpi( arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2) ) for idx, reg in np.ndenumerate(self.registers): assert ir.VectorType(reg.type).shape == [8] # The vector is 32-bits, so we just shuffle the whole thing and # use prmt to blend it with the local register. exchanged = utils.shfl_bfly(reg, 2) # See comments above for conventions. Here we exchange data between # threads with lane index related by flipping 2nd bit (e.g. 
0 and 2). # reg[0]: 0 1 2 3 4 5 6 7 reg[2]: 16 17 18 19 20 21 22 23 # prmt[0]: -0- -1- -2- -3- --4-- --5-- --6-- --7-- # prmt[1]: -4- -5- -6- -7- --0-- --1-- --2-- --3-- # The expected outputs and their respective permutations are: # out[0]: 0 1 2 3 16 17 18 19 out[2]: 4 5 6 7 20 21 22 23 # prmt[0]: -0- -1- --4-- --5-- prmt[2]: -6- -7- --2-- --3-- perm = arith.select(is_01, c(0x5410), c(0x3276)) blend = utils.prmt(reg, exchanged, perm) for i in range(2): reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4)) new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg assert all(r is not None for r in new_registers) return FragmentedArray( _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, ) if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT: return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout) if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" ) [reg] = self.registers.flat return type(self).splat( reg, self.shape, new_layout, is_signed=self.is_signed ) def _pointwise(self, op, *other, output_is_signed: bool | None = None): # If our layout is a splat, then we should either dispatch to a non-splat # layout, or broadcast ourselves to the output shape first. if isinstance(self.layout, WGSplatFragLayout): output_shape = self.shape for i, o in enumerate(other): if not isinstance(o, FragmentedArray): continue elif not isinstance(o.layout, WGSplatFragLayout): return o._pointwise( lambda o, this, *args: op(this, *args[:i], o, *args[i:]), self, *other[:i], *other[i + 1 :], output_is_signed=output_is_signed, ) else: output_shape = np.broadcast_shapes(output_shape, o.shape) # If we get here then we haven't found any non-splat layout. 
if self.shape != output_shape: return self.broadcast(output_shape)._pointwise( op, *other, output_is_signed=output_is_signed ) other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): if isinstance(o, (float, int)): o = utils.c(o, self.mlir_dtype) elif not isinstance(o, ir.Value): raise NotImplementedError(o) o = FragmentedArray.splat( o, shape=self.shape, layout=self.layout, is_signed=self.is_signed ) if isinstance(o.layout, WGSplatFragLayout): if not o.layout.can_broadcast_to(self.shape): raise ValueError( f"Cannot broadcast shape {self.shape} to layout {o.layout}") o = FragmentedArray.splat( o.registers.flat[0], shape=self.shape, layout=self.layout, is_signed=o.is_signed, ) else: if self.layout != o.layout: raise ValueError("Incompatible FragmentedArray layouts") if self.registers.shape != o.registers.shape: raise ValueError("Incompatible FragmentedArray shapes") other_arrs.append(o) new_regs = np.empty_like(self.registers) for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) reg_ty = new_regs.flat[0].type if ir.VectorType.isinstance(reg_ty): reg_ty = ir.VectorType(reg_ty).element_type if output_is_signed is None and ir.IntegerType.isinstance(reg_ty): output_is_signed = self.is_signed return FragmentedArray( _registers=new_regs, _layout=self.layout, _is_signed=output_is_signed ) def __pos__(self): return self def __neg__(self): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(arith.negf) elif ir.IntegerType.isinstance(self.mlir_dtype): return 0 - self else: return NotImplemented def __add__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(addf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.addi, other) else: return NotImplemented def __radd__(self, other): return self + other def __mul__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(mulf, other) elif 
ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.muli, other) else: return NotImplemented def __rmul__(self, other): return self * other def __sub__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(subf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.subi, other) else: return NotImplemented def __rsub__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(lambda s, o: subf(o, s), other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(lambda s, o: arith.subi(o, s), other) else: return NotImplemented def __truediv__(self, other): if not ir.FloatType.isinstance(self.mlir_dtype): return NotImplemented return self._pointwise(arith.divf, other) def __rtruediv__(self, other): if not ir.FloatType.isinstance(self.mlir_dtype): return NotImplemented return self._pointwise(lambda s, o: arith.divf(o, s), other) def __floordiv__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise( lambda s, o: mlir_math.floor(arith.divf(s, o)), other ) elif ir.IntegerType.isinstance(self.mlir_dtype): if self.is_signed: return self._pointwise(arith.floordivsi, other) else: return self._pointwise(arith.divui, other) else: return NotImplemented def __rfloordiv__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise( lambda s, o: mlir_math.floor(arith.divf(o, s)), other ) elif ir.IntegerType.isinstance(self.mlir_dtype): if self.is_signed: return self._pointwise(lambda s, o: arith.floordivsi(o, s), other) else: return self._pointwise(lambda s, o: arith.divui(o, s), other) else: return NotImplemented def __mod__(self, other): if not ir.IntegerType.isinstance(self.mlir_dtype): return NotImplemented if self.is_signed: return self._pointwise(arith.remsi, other) else: return self._pointwise(arith.remui, other) def __rmod__(self, other): if not ir.IntegerType.isinstance(self.mlir_dtype): return 
NotImplemented if self.is_signed: return self._pointwise(lambda s, o: arith.remsi(o, s), other) else: return self._pointwise(lambda s, o: arith.remui(o, s), other) def __invert__(self): if not ir.IntegerType.isinstance(self.mlir_dtype): return NotImplemented return self ^ ~0 def __or__(self, other): if not ir.IntegerType.isinstance(self.mlir_dtype): return NotImplemented return self._pointwise(arith.ori, other) def __ror__(self, other): return self | other def __and__(self, other): if not ir.IntegerType.isinstance(self.mlir_dtype): return NotImplemented return self._pointwise(arith.andi, other) def __rand__(self, other): return self & other def __xor__(self, other): if not ir.IntegerType.isinstance(self.mlir_dtype): return NotImplemented return self._pointwise(arith.xori, other) def __rxor__(self, other): return self ^ other def __eq__(self, other): return self._compare( other, f_pred=arith.CmpFPredicate.OEQ, si_pred=arith.CmpIPredicate.eq, ui_pred=arith.CmpIPredicate.eq, ) def __ne__(self, other): return self._compare( other, f_pred=arith.CmpFPredicate.UNE, si_pred=arith.CmpIPredicate.ne, ui_pred=arith.CmpIPredicate.ne, ) def __lt__(self, other): return self._compare( other, f_pred=arith.CmpFPredicate.OLT, si_pred=arith.CmpIPredicate.slt, ui_pred=arith.CmpIPredicate.ult, ) def __le__(self, other): return self._compare( other, f_pred=arith.CmpFPredicate.OLE, si_pred=arith.CmpIPredicate.sle, ui_pred=arith.CmpIPredicate.ule, ) def __gt__(self, other): return self._compare( other, f_pred=arith.CmpFPredicate.OGT, si_pred=arith.CmpIPredicate.sgt, ui_pred=arith.CmpIPredicate.ugt, ) def __ge__(self, other): return self._compare( other, f_pred=arith.CmpFPredicate.OGE, si_pred=arith.CmpIPredicate.sge, ui_pred=arith.CmpIPredicate.uge, ) def _compare(self, other, *, f_pred, si_pred, ui_pred): if ir.FloatType.isinstance(self.mlir_dtype): pred = functools.partial(arith.cmpf, f_pred) elif ir.IntegerType.isinstance(self.mlir_dtype): if ir.IntegerType(self.mlir_dtype).is_signed: 
pred = functools.partial(arith.cmpi, si_pred) else: pred = functools.partial(arith.cmpi, ui_pred) else: raise NotImplementedError return self._pointwise(pred, other, output_is_signed=False) def max(self, other): if ir.FloatType.isinstance(self.mlir_dtype): maximumf = arith.maximumf if ir.F32Type.isinstance(self.mlir_dtype): maximumf = self._lift_fast_instr("max.NaN.f32") return self._pointwise(maximumf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise( arith.maxsi if self.is_signed else arith.maxui, other ) else: return NotImplementedError def min(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(arith.minimumf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise( arith.minsi if self.is_signed else arith.minui, other ) else: return NotImplementedError def exp(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx: dtype = self.mlir_dtype log2e = arith.constant(dtype, ir.FloatAttr.get(dtype, 1.4426950408889634)) return (self * log2e).exp2() return self._pointwise(mlir_math.exp) def exp2(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx: if not ir.F32Type.isinstance(self.mlir_dtype): raise NotImplementedError(self.mlir_dtype) return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32")) return self._pointwise(mlir_math.exp2) def log(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx: dtype = self.mlir_dtype ln2 = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.6931471805599453)) return self.log2(approx=True) * ln2 return self._pointwise(mlir_math.log) def log2(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError(self.mlir_dtype) if approx: if not ir.F32Type.isinstance(self.mlir_dtype): raise NotImplementedError(self.mlir_dtype) 
return self._pointwise(self._lift_fast_instr("lg2.approx.ftz.f32")) return self._pointwise(mlir_math.log2) def sin(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( self._lift_fast_instr("sin.approx.f32") if approx else mlir_math.sin ) def cos(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( self._lift_fast_instr("cos.approx.f32") if approx else mlir_math.cos ) def tanh(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( self._lift_fast_instr("tanh.approx.f32") if approx else mlir_math.tanh ) def rsqrt(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( self._lift_fast_instr("rsqrt.approx.f32") if approx else mlir_math.rsqrt ) @staticmethod def _lift_fast_instr( instr: str | Callable[[ir.Value], ir.Value], ) -> Callable[[ir.Value], ir.Value]: def fast_instr(*args): f32 = ir.F32Type.get() arg_ty = args[0].type assert all(a.type == arg_ty for a in args) if arg_ty == f32: if isinstance(instr, str): args_ptx = ", ".join(f"${i}" for i in range(len(args) + 1)) return llvm.inline_asm( f32, args, f"{instr} {args_ptx};", "=f" + ",f" * len(args) ) else: return instr(*args) elif ir.VectorType.isinstance(arg_ty): index = ir.IndexType.get() result = llvm.mlir_undef(arg_ty) [vec_len] = ir.VectorType(arg_ty).shape for i in range(vec_len): vs = [vector.extractelement(a, position=c(i, index)) for a in args] vr = fast_instr(*vs) result = vector.insertelement(vr, 
result, position=c(i, index)) return result else: raise NotImplementedError(arg_ty) return fast_instr def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): if (output_is_signed is not None) != ir.IntegerType.isinstance(elt): raise TypeError( "output_is_signed must be non-None if and only if the MLIR type is an" f" integer type, got {output_is_signed=} for {elt}" ) if elt == self.mlir_dtype: return self reg_type = self.registers.flat[0].type if ir.VectorType.isinstance(reg_type): reg_shape = ir.VectorType(reg_type).shape ty = ir.VectorType.get(reg_shape, elt) else: ty = elt return self._pointwise( lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed ) def __getitem__(self, idx): if self.layout != WGMMA_LAYOUT: raise NotImplementedError("Only WGMMA layouts support slicing") base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) if any(is_squeezed): raise NotImplementedError("Only slicing implemented") if ( base_idx[0] % 64 or slice_shape[0] % 64 or base_idx[1] % 8 or slice_shape[1] % 8 ): raise NotImplementedError("Only tile aligned slicing supported") base_idx[0] //= 64 slice_shape[0] //= 64 base_idx[1] //= 8 slice_shape[1] //= 8 new_regs = self.registers[ base_idx[0] : base_idx[0] + slice_shape[0], base_idx[1] : base_idx[1] + slice_shape[1], ] return FragmentedArray( _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed ) # TODO(apaszke): Support JAX dtypes here as well? 
def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): i4 = ir.IntegerType.get_signless(4) i8 = ir.IntegerType.get_signless(8) i16 = ir.IntegerType.get_signless(16) i32 = ir.IntegerType.get_signless(32) bf16 = ir.BF16Type.get() cur_dtype = self.mlir_dtype if cur_dtype == new_dtype: if self.is_signed == is_signed: return self return FragmentedArray( _registers=self.registers, _layout=self.layout, _is_signed=is_signed ) reg_type = self.registers.flat[0].type is_vector_reg = ir.VectorType.isinstance(reg_type) reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,) [vector_len] = reg_shape # This is meant to be a 1D assertion. if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8: raise ValueError( "Register bitwidth in target type must be divisible by 8, got" f" {new_reg_bitwidth}" ) if cur_dtype == i4 and self.is_signed and new_dtype == bf16: new_registers = np.empty_like(self.registers) out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) for idx, reg in np.ndenumerate(self.registers): # The algorithm here is largely the same as CUTLASS's # NumericArrayConverter specialization for int4 -> bf16 casts. # We modify it slightly, because we only extract 2 values. # We first shift the value by 4 bits, to put the high int4 in low bits. # The prmt then blends the two values together, by putting them into the # low bits of each 16-bit subword of our register. Then, we use the lop3 # to zero any bits that don't belong to our int4s, and finally use the # XOR to: (1) set the exponent bits to 0x43 (at which point the mantissa # represents integer increments) and (2) flip the sign bit. If we # interpret the 4 bits as uint4 after the flip, then we'll see that # positive int4s will end up larger than negative int4s, with a bias of # 8. Use use the sub to subtract the base (our initial exponent) and the # bias coming from flipping the sign bit which is 136 (0x4308 as bits). 
def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int): assert 0 <= part < 4 return llvm.inline_asm( i32, [reg, reg_shr], f""" {{ .reg .b32 s<4>; prmt.b32 s1, $1, $2, 0xF{part + 4}F{part}; lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa; mov.b32 s3, 0x43084308; sub.bf16x2 $0, s2, s3; }} """, "=r,r,r", ) offset = 0 out_int_regs = [] for group_size in (8, 4, 2): int_ty = ir.IntegerType.get_signless(group_size * 4) while vector_len - offset >= group_size: # If the vector originates from a slice (common after relayouts), we # can fuse the slicing into the conversion and prevent LLVM from # generating a bunch of shifts to align the vector data to the LSB. # This also lets us share the right shift among more vectors. if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp) and utils.bitwidth(slice_op.vector.type) == 32 and slice_op.strides[0].value == 1): slice_offset = slice_op.offsets[0].value + offset reg_int = utils.bitcast(slice_op.vector, i32) reg_int_shr = arith.shrui(reg_int, c(4, i32)) out_int_regs.extend( upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part)) for part in range(group_size // 2) ) else: reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size)) reg_slice_int = utils.bitcast(reg_slice, int_ty) if int_ty != i32: reg_slice_int = arith.extsi(i32, reg_slice_int) reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32)) out_int_regs.extend( upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part) for part in range(group_size // 2) ) offset += group_size assert offset == vector_len out_vec_int = utils.vector_concat([ vector.splat(ir.VectorType.get((1,), i32), reg) for reg in out_int_regs ]) new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty) return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=None ) if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}: new_registers = np.empty_like(self.registers) def 
upcast_to_bf16(reg, high): # We first embed the s8 into a bf16 with the exponent equal to # bias + mantissa bits. Then, we zero the msb that didn't fit into the # mantissa, zero out all bits other than msb, and subtract the last # two values from each other. This takes advantage of the fact that the # lsb of the exponent (msb of the second byte) is zero, which allows us # to losslesly pack the msb there. When 1, it doubles the value of s2, # making the result negative. return llvm.inline_asm( i32, [reg], f""" {{ .reg .b32 s<3>; prmt.b32 s0, $1, 0x43, {0x4342 if high else 0x4140}; and.b32 s1, s0, 0xff7fff7f; and.b32 s2, s0, 0xff80ff80; sub.bf16x2 $0, s1, s2; }} """, "=r,r", ) empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32)) for idx, reg in np.ndenumerate(self.registers): if vector_len == 2: reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) new_reg_32 = upcast_to_bf16(reg_16, high=False) new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32)) elif vector_len == 4: reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg) low = upcast_to_bf16(reg_32, high=False) high = upcast_to_bf16(reg_32, high=True) new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32)) new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32)) else: raise NotImplementedError(vector_len) new_registers[idx] = vector.bitcast( ir.VectorType.get((vector_len,), new_dtype), new_vec_32 ) return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) # Generic path. from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) from_integer = ir.IntegerType.isinstance(cur_dtype) to_integer = ir.IntegerType.isinstance(new_dtype) if from_float and to_float: cur_ty_width = ir.FloatType(cur_dtype).width new_ty_width = ir.FloatType(new_dtype).width if cur_ty_width == new_ty_width: # There is no instruction to perform conversions between two float types # of the same width. 
Go through the next-larger standard type. # TODO(bchetioui): support conversions between float types of width 8. # Which larger type to pick will depend on the number of bits in the # smallest exponent. if cur_ty_width != 16: raise NotImplementedError( "Conversion between float types of width other than 16 not" " supported" ) larger_ty = ir.F32Type.get() match self.layout: case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape upcast_ty = ir.VectorType.get(shape, larger_ty) case WGMMARowFragLayout() | WGSplatFragLayout(): upcast_ty = larger_ty case _: raise NotImplementedError(f"Unsupported layout {self.layout}") convert = lambda ty, x: arith.truncf(ty, arith.extf(upcast_ty, x)) elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width: convert = arith.truncf else: convert = arith.extf elif from_integer and to_integer: if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width: convert = arith.trunci else: convert = arith.extsi if self.is_signed else arith.extui elif from_integer and to_float: convert = arith.sitofp if self.is_signed else arith.uitofp elif from_float and to_integer: convert = arith.fptosi if is_signed else arith.fptoui else: raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") new_registers = np.empty_like(self.registers) match self.layout: case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape new_reg_ty = ir.VectorType.get(shape, new_dtype) case WGMMARowFragLayout() | WGSplatFragLayout(): new_reg_ty = new_dtype case _: raise NotImplementedError(f"Unsupported layout {self.layout}") for idx, reg in np.ndenumerate(self.registers): new_registers[idx] = convert(new_reg_ty, reg) return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) # NOTE: scratch can be reused immediately once this function returns. 
def reduce_sum(self, scratch: ir.Value | None = None): if isinstance(self.layout, WGSplatFragLayout): [reg] = self.registers.flat if ir.FloatType.isinstance(self.mlir_dtype): op = mulf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.muli else: raise NotImplementedError(self.mlir_dtype) return FragmentedArray.splat( op(reg, utils.c(math.prod(self.shape), self.mlir_dtype)), (), is_signed=self.is_signed, ) if not isinstance(self.layout, WGStridedFragLayout): raise NotImplementedError(f"Unsupported layout {self.layout}") if scratch is None: raise ValueError("scratch must be provided") if ir.FloatType.isinstance(self.mlir_dtype): op = addf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.addi else: raise NotImplementedError(self.mlir_dtype) result = c(0, self.mlir_dtype) for reg in self.registers: result = op( result, vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg), ) scratch_ty = ir.MemRefType(scratch.type) if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]: raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})") index = ir.IndexType.get() warp_result = utils.warp_tree_reduce(result, op, 32) warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index)) memref.store(warp_result, scratch, [warp_id]) utils.warpgroup_barrier() zero_index = c(0, index) with mgpu.single_thread(per_block=False): scratch_vec = vector.load( ir.VectorType.get((4,), self.mlir_dtype), scratch, [zero_index], ) scratch_sum = vector.reduction( self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec ) memref.store(scratch_sum, scratch, [zero_index]) utils.warpgroup_barrier() result = memref.load(scratch, [zero_index]) utils.warpgroup_barrier() # Make sure everyone is done using scratch. 
return FragmentedArray.splat(result, (), is_signed=self.is_signed) def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): if isinstance(op, str): match op: case "add": if ir.FloatType.isinstance(self.mlir_dtype): op = addf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.addi else: raise NotImplementedError(self.mlir_dtype) case "max": if ir.F32Type.isinstance(self.mlir_dtype): op = self._lift_fast_instr("max.NaN.f32") elif ir.FloatType.isinstance(self.mlir_dtype): op = arith.maximumf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.maxsi if self.is_signed else arith.maxui else: raise NotImplementedError(self.mlir_dtype) case _: raise ValueError(f"Unrecognized reduction operator: {op}") if self.layout != WGMMA_LAYOUT: raise NotImplementedError(self.layout) if axis != 1: raise NotImplementedError index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) row_tile_dim = self.registers.shape[0] row_subtile_dim = self.registers.shape[4] new_regs = np.empty((row_tile_dim, row_subtile_dim), dtype=object) assert self.registers.shape[-1] == 1 for row_tile, row_subtile in np.ndindex(new_regs.shape): # Reduce the registers owned by the current thread over n tiles reg_index = [0] * self.registers.ndim reg_index[0] = row_tile reg_index[4] = row_subtile thread_result_vec = self.registers[tuple(reg_index)] for n_tile in range(1, self.registers.shape[1]): reg_index[1] = n_tile thread_result_vec = op( thread_result_vec, self.registers[tuple(reg_index)] ) thread_result = vector.extractelement(thread_result_vec, position=c(0, index)) for i in range(1, self.layout.vector_length): thread_result = op( thread_result, vector.extractelement(thread_result_vec, position=c(i, index)), ) # Do a shuffle to reduce in groups of 4 consecutive threads. 
result = thread_result for i in (1, 2): other_result = nvvm.shfl_sync( result.type, c(0xFFFFFFFF, i32), result, c(i, i32), c(0x1F, i32), nvvm.ShflKind.bfly, ) result = op(result, other_result) new_regs[row_tile, row_subtile] = result return FragmentedArray( _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed ) def broadcast(self, shape): if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError(self.layout) if self.shape == shape: return self if not self.layout.can_broadcast_to(shape): raise ValueError(f"Can't broadcast {self.shape} to {shape}") return FragmentedArray( _registers=self.registers, _layout=WGSplatFragLayout(shape), _is_signed=self.is_signed, ) def reshape(self, shape): if self.shape == shape: return self if math.prod(shape) != math.prod(self.shape): raise ValueError(f"Can't reshape {self.shape} to {shape}") match self.layout: case WGSplatFragLayout() | WGStridedFragLayout(): new_layout = dataclasses.replace(self.layout, shape=shape) case _: raise NotImplementedError(self.layout) return FragmentedArray( _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed ) def broadcast_minor(self, n): if self.layout != WGMMA_ROW_LAYOUT: raise NotImplementedError if n % 8: raise ValueError("Number of columns must be divisible by 8") reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n)) new_regs = np.empty(reg_shape, dtype=object) dtype = self.mlir_dtype for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[0] = row_tile tile[4] = row_subtile new_regs[tuple(tile)] = vector.splat( ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), reg ) return FragmentedArray( _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed ) def select(self, on_true, on_false): if ( not ir.IntegerType.isinstance(self.mlir_dtype) or ir.IntegerType(self.mlir_dtype).width != 1 ): raise NotImplementedError # We change the receiver here, because the 
return type is defined by # `on_true` and `on_false` and not the predicate `self`. return on_true._pointwise( lambda t, p, f: arith.select(p, t, f), self, on_false, ) def foreach( self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None], *, create_array=False, is_signed=None, ): """Call a function for each value and index.""" index = ir.IndexType.get() new_regs = None if create_array: new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type)) for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): reg = self.registers[reg_idx] assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) [elems] = ir.VectorType(reg.type).shape for i in range(elems): i = c(i, index) val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) if create_array: new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) def debug_print(self, fmt: str): idx_fmt = ", ".join(["{}"] * len(self.shape)) @self.foreach def _(val, idx): fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") utils.debug_print(fmt_str, *idx, val, uniform=False) def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) def vs_unsupported(): if not vector_store: raise NotImplementedError( f"Can't use non-vector stores with layout {self.layout}" ) match self.layout: case WGSplatFragLayout(): vs_unsupported() self._store_untiled_splat(ref) case WGStridedFragLayout(): vs_unsupported() self._store_untiled_wg_strided(ref) case TiledLayout(): self._store_untiled_tiled(ref, vector_store=vector_store) case _: raise NotImplementedError(self.layout) def _store_untiled_splat(self, ref: ir.Value): vec_size = 64 // mgpu.bitwidth(self.mlir_dtype) if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: vec_size = 1 if 
np.prod(self.shape) % WARPGROUP_SIZE * vec_size: raise ValueError(self.shape, WARPGROUP_SIZE, vec_size) fa = FragmentedArray.splat( self.registers.flat[0], self.shape, layout=WGStridedFragLayout(shape=self.shape, vec_size=vec_size), is_signed=self.is_signed, ) fa.store_untiled(ref) def _store_untiled_wg_strided(self, ref: ir.Value): ref_ty = ir.MemRefType(ref.type) try: # Flattening the reference potentially produces simpler PTX but # if the ref is not already 1D and has strided dimensions # flattening won't work. We use a different variable for ref in # case `NotImplementedError` is thrown by # .linear_thread_idxs(). ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) idxs = ([i] for i in self.layout.linear_thread_idxs()) except NotImplementedError: ref_ = ref idxs = self.layout.thread_idxs() ref_shape = tuple(ref_ty.shape) if ref_shape != self.shape: raise ValueError((ref_shape, self.shape)) for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") i32 = ir.IntegerType.get_signless(32) layout = self.layout assert isinstance(layout, TiledLayout) ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() if vector_store and ref_strides[layout.vector_dim] != 1: raise NotImplementedError( "Can't use vector stores with non-unit minormost stride" ) strides = layout.tiling.tile_strides(ref_strides) smem_space = ir.Attribute.parse("#gpu.address_space") ref_space = ir.MemRefType(ref.type).memory_space memory_space = None if str(ref_space) == str(smem_space): memory_space = 3 elif ref_space: raise NotImplementedError(f"Unexpected ref space {ref_space}") ptr = utils.memref_ptr(ref, memory_space=memory_space) # Fold warp and lane offsets into the pointer once, since they are dynamic. 
dyn_strides = [ arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] ] warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) dyn_offset = arith.addi(warp_offset, lane_offset) ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) # All warp tile offsets are static and can be fused into the store. for tile_idx, reg in np.ndenumerate(self.registers): if vector_store: elems = [reg] else: index = ir.IndexType.get() elems = [ vector.extractelement(reg, position=c(i, index)) for i in range(ir.VectorType(reg.type).shape[0]) ] for i, e in enumerate(elems): tile_idx_local = list(tile_idx) tile_idx_local[layout.vector_dim] += i tile_idx_local = list(tile_idx_local) lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) llvm.store(e, reg_ptr) def store_tiled(self, ref, swizzle: int | None): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): llvm.store(get(self.registers), ptr) @classmethod def load_tiled( cls, ref, swizzle: int | None, *, is_signed: bool | None = None, layout: FragmentedLayout = WGMMA_LAYOUT, ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type match layout: case TiledLayout(): ref_ty = ir.MemRefType(ref.type) tiled_shape = ref_ty.shape if len(tiled_shape) % 2: raise ValueError("Tiled reference must have even rank") tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],)) shape = tiling.untile_shape(tiled_shape) zero = ( vector.splat( ir.VectorType.get((layout.vector_length,), dtype), c(0, dtype) ), ) registers = np.full(layout.registers_shape(shape), zero, dtype=object) reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): update(registers, 
llvm.load(reg_ty, ptr)) case _: raise NotImplementedError(layout) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) @staticmethod def transfer_tiled(shape, dtype, swizzle: int | None): # TODO(apaszke): We could use ldmatrix/stmatrix for 16-bit types. bw = mgpu.bitwidth(dtype) m, n = shape assert m % 64 == 0 and n % 8 == 0 # Implied by the layout. cols_per_tile = swizzle_elems = (swizzle * 8) // bw if n < swizzle_elems: cols_per_tile = n else: assert n % swizzle_elems == 0, (n, swizzle_elems) if swizzle not in {32, 64, 128}: raise NotImplementedError("Only swizzled stores supported") c = arith.ConstantOp.create_index tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE)) lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7} if bw > 16: # Stagger is only necessary for values larger than 16bit. # We split the rows into two groups (left/right) and change the order in # which they perform accesses to avoid bank conflicts. # It seems that the STS.64 is 2x faster (and the hardware reports no # conflicts) when the conflicts are split between half-warps, as # opposed to having them within the half-warp. This requires a # little more work for the selects, but is ultimately worth it. match swizzle: case 128: is_stagger_left = arith.cmpi( arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0) ) case 64: is_stagger_left = arith.cmpi( arith.CmpIPredicate.eq, arith.remui(arith.divui(sub_row_base, c(2)), c(2)), c(0), ) case 32: # 32-byte tiles of 4-byte types have only 8 columns so there is no way # to stagger the memory accesses within a single tile. We could do it # across tiles, but that would be a completely different scheme. 
raise NotImplementedError case _: raise AssertionError(swizzle) stagger_amount = swizzle // 64 if (cols_per_tile // 8) % (stagger_amount * 2): raise NotImplementedError else: # We rely on canonicalization to clean up the selects. i1 = ir.IntegerType.get_signless(1) is_stagger_left = arith.constant(i1, ir.BoolAttr.get(True)) stagger_amount = 0 row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16))) col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} # The swizzle pattern is constant for a given thread. col_swizzle_bits = arith.muli( arith.divui(sub_row_base, c(128 // swizzle)), c(128 // bw), ) for row_group in range(m // 64): for col_group in range(n // cols_per_tile): for row_subidx in range(2): row = arith.addi(row_base, c(row_subidx * 8)) for col_subidx in range(cols_per_tile // 8): col_subidx_left = col_subidx col_subidx_right = col_subidx ^ stagger_amount col_off = arith.select( is_stagger_left, c(col_subidx_left * 8), c(col_subidx_right * 8) ) col = arith.addi(col_base, col_off) col = arith.xori(col, col_swizzle_bits) reg_idx_left = col_subidx_left + col_group * (cols_per_tile // 8) reg_idx_right = col_subidx_right + col_group * (cols_per_tile // 8) left_idx = row_group, reg_idx_left, row_subidx, 0 right_idx = row_group, reg_idx_right, row_subidx, 0 idx = c(row_group), c(col_group), row, col def get_register(regs, left_idx=left_idx, right_idx=right_idx): value_left = regs[left_idx] value_right = regs[right_idx] return arith.select(is_stagger_left, value_left, value_right) def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx): regs[left_idx] = arith.select(is_stagger_left, new, regs[left_idx]) regs[right_idx] = arith.select(is_stagger_left, regs[right_idx], new) yield get_register, update_registers, idx @staticmethod def transfer_tiled2( ref: ir.Value, swizzle: int | None, layout: TiledLayout, shape: tuple[int, ...], ): """Generate a transfer schedule for a tiled layout. 
Given a ref with one level tiling applied to it (we assume all dimensions have been tiled), this function generates an iterable describing a good schedule for swizzled SMEM loads/stores. At each step, the iterable yields a tuple of three values: * a function that takes a register array and returns the register to be stored at the current address * a function that takes a register array and a register loaded from the current address, and updates the register array with that register * the current address for load/store instructions """ # TODO(apaszke): Use ldmatrix/stmatrix when possible. c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) tiling = layout.tiling ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type if ref_ty.rank % 2: raise ValueError("Tiled reference must have even rank") ref_logical_rank = ref_ty.rank // 2 ref_tiling_shape = tuple(ref_ty.shape[ref_logical_rank:]) ref_tiling = Tiling((ref_tiling_shape,)) ref_strides, _ = ref_ty.get_strides_and_offset() if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape: raise ValueError() nested_ref_shape = tuple( (ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank]) for i in range(ref_logical_rank) ) nested_ref_strides = tuple( (ref_strides[i], ref_strides[i + ref_logical_rank]) for i in range(ref_logical_rank) ) tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides( nested_ref_shape, nested_ref_strides ) # We could technically handle this case, but it would be quite complicated. # If tiling dimensions would have to be expanded into multiple, we'd have to # adjust the dimension indices in layouts, including expanding some of them # into multiple indices. Note that for non-tiling dims, we allow the shape # to be arbitrary, which is why we fix it up below in mem_idx_to_reg_idx. 
if any( len(dim_shape) != 1 for dim_shape in tiled_nested_shape[-layout.tiled_tiling_rank :] ): raise NotImplementedError("Memory and register tiling incompatible") tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] lane_shape = [tiled_shape[d] for d in layout.lane_dims] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): tiled_shape[d] = 1 element_bits = mgpu.bitwidth(dtype) if (layout.vector_length * element_bits) % 8 != 0: raise ValueError( f"Vector length ({layout.vector_length}) must be a multiple of bytes," f" but has {layout.vector_length * element_bits} bits" ) transfer_bytes = (layout.vector_length * element_bits) // 8 # Not sure if this is strictly required for all data types, but it certainly # is for sub-byte types (else we might not increment the pointer by whole bytes). if any( s % layout.vector_length and i != layout.vector_dim and d != 1 for i, (s, d) in enumerate_negative( list(zip(elem_tiled_strides, tiled_shape)) ) ): raise ValueError( "Tiled strides must be a multiple of the vector length, except for the" " vector dimension" ) if swizzle not in {16, 32, 64, 128}: raise ValueError("Only swizzled transfers supported") # We will be computing the offsets in units of vectors, not elements, # to better support sub-byte types. swizzle_tile_transfers = 16 // transfer_bytes swizzle_group_transfers = 128 // transfer_bytes swizzle_groups_per_block = swizzle // 16 swizzle_block_transfers = swizzle_groups_per_block * swizzle_group_transfers # Technically we should keep the vector_dim set to 1, but its shape is 1 # so it does not matter. 
transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides] transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) plan = plan_tiled_transfer( tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout, element_bits, swizzle ) # All offsets are in units of transfer_dtype. dyn_tiled_strides = [ c(s) for s in transfer_tiled_strides[-layout.tiled_tiling_rank :] ] lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): raise ValueError("Tiled stores can be performed into SMEM") ptr = utils.memref_ptr(ref, memory_space=3) _as_consts = lambda consts: [c(const) for const in consts.tolist()] # This has bits set only for the offset bits that influence swizzling. swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers for tile_idx in np.ndindex(*tiled_shape): indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms]) const_offset = np.dot(indices, transfer_tiled_strides) # We split the offset into a part that interacts with swizzling and a # part that doesn't. This lets us generate better code because constant # offsets can be fused into load and store instructions. 
const_offset_swizzle = const_offset & swizzle_mask const_offset_no_swizzle = const_offset - const_offset_swizzle offset_pre_swizzle = arith.addi( dyn_offset, plan.select(_as_consts(const_offset_swizzle)) ) swizzle_group = arith.remui( arith.divui(offset_pre_swizzle, c(swizzle_group_transfers)), c(swizzle_groups_per_block), ) swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_transfers)) offset = arith.xori(offset_pre_swizzle, swizzle_bits) reg_ptr = utils.getelementptr(ptr, [offset], transfer_dtype) offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle)) reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], transfer_dtype) # Here, registers are organized in an array with shape obtained by tiling # the logical data bounds. But, the reference was tiled and so each # logical tiled dimension can map to multiple dims in tiled_shape. # The transform below maps this potentially higher-rank representation # back to the lower-rank representation used by the register arrays. def mem_idx_to_reg_idx(idx): reg_tiled_idx = [] base_idx = 0 for dim_shape in tiled_nested_shape[:ref_logical_rank]: dim_strides = utils.get_contiguous_strides(dim_shape) dim_idxs = idx[base_idx:base_idx + len(dim_shape)] base_idx += len(dim_shape) reg_tiled_idx.append(sum(i * s for i, s in zip(dim_idxs, dim_strides))) # We should have fixed up all but the tiling dims. assert base_idx == len(idx) - layout.tiled_tiling_rank return (*reg_tiled_idx, *idx[base_idx:]) reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()] def get_register(regs, reg_idxs=reg_idxs): return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) def update_registers(regs, new, reg_idxs=reg_idxs): # TODO(apaszke): If the staggering forms a permutation with a small # cycle length, then instead of blending at each step we could construct # a small routing network (kind of like a sorting network) to fix up # each cycle separately after all the loads are performed. 
# This would be especially useful for dims that are powers of two and # staggered by another power of 2, since all cycles are of length 2 (and # we could save half the selects). for i, reg_idx in enumerate(reg_idxs): regs[reg_idx] = plan.select_if_group(i, regs[reg_idx], new) yield get_register, update_registers, reg_ptr def tree_flatten(self): aux = self.layout, self.registers.shape, self.is_signed return list(self.registers.flat), aux @classmethod def tree_unflatten(cls, aux, flat_registers): layout, reg_shape, is_signed = aux registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) class TransferPlan(Protocol): IndexTransform = Callable[[tuple[int, ...]], tuple[int, ...]] tile_index_transforms: tuple[IndexTransform, ...] def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: """Selects the value corresponding to the group of the current thread. The argument must be of the same length as tile_index_transforms. """ raise NotImplementedError def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: """Returns `new` if the current thread belongs to the given group and `old` otherwise. group_idx must be between 0 and len(tile_index_transforms) - 1. 
""" raise NotImplementedError @dataclasses.dataclass(frozen=True) class TrivialTransferPlan(TransferPlan): @property def tile_index_transforms(self): return (lambda x: x,) def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: assert len(group_elems) == 1 return group_elems[0] def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: assert group_idx == 0 return new @dataclasses.dataclass(frozen=True) class StaggeredTransferPlan(TransferPlan): stagger: int dim: int size: int group_pred: ir.Value @property def tile_index_transforms(self): dim = self.dim def rotate(idx: tuple[int, ...]) -> tuple[int, ...]: return ( *idx[:dim], (idx[dim] + self.stagger) % self.size, *idx[dim + 1 :], ) return (lambda x: x, rotate) def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: assert len(group_elems) == 2 return arith.select(self.group_pred, group_elems[1], group_elems[0]) def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: assert 0 <= group_idx <= 1 sides = [old, new] if group_idx == 0 else [new, old] return arith.select(self.group_pred, *sides) def plan_tiled_transfer( tiled_shape: Sequence[int], tiled_strides: Sequence[int], lane_shape: Sequence[int], lane_strides: Sequence[int], layout: TiledLayout, element_bits: int, swizzle: int, ) -> TransferPlan: i32 = ir.IntegerType.get_signless(32) c = lambda x: arith.constant(i32, x) # TODO(apaszke): Rewrite this function in terms of transfer_bytes (that we get # from the caller). swizzle_tile_elems = (16 * 8) // element_bits swizzle_group_elems = (128 * 8) // element_bits # Should be checked at the call site. assert layout.vector_length * element_bits % 8 == 0 transfer_bytes = (layout.vector_length * element_bits) // 8 # Below, all calculations are in elements, not in bytes, since it should # generalize better to sub-byte types. # Here, we verify two conditions: # 1. 
Each vector transfer only accesses addresses that fall within a single # swizzle tile (if not we'd need to split it and swizzle parts differently). transfer_alignment = math.gcd(*( s for i, (s, d) in enumerate_negative(list(zip(tiled_strides, tiled_shape))) if d > 1 or i in {layout.warp_dim, *layout.lane_dims} )) if ( swizzle_tile_elems % transfer_alignment and layout.vector_length <= transfer_alignment ): raise ValueError( "Failed to prove that vector transfers don't cross swizzle tile" " boundaries. This check is incomplete, and does not guarantee that" " this is a user error, but it might be." + str(transfer_alignment) ) # 2. The transfer pattern does not cause bank conflicts. # TODO(apaszke): For now, when performing transfers narrower than a bank, # we simply narrow each bank to the transfer width. The truth is more likely # that bank conflicts only don't occur if the addresses mapping to the same # bank are contiguous, but that's a more complicated check to perform. if transfer_bytes > SMEM_BANK_BYTES * 4: raise NotImplementedError if element_bits > SMEM_BANK_BYTES * 8: raise NotImplementedError smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes) num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes) elems_per_bank = (smem_bank_bytes * 8) // element_bits num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) wavefront_lanes = WARP_SIZE // num_wavefronts lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] assert lane_tile_idx.shape[1] in {1, WARP_SIZE} lane_tile_offsets = np.dot(lane_tile_idx, tiled_strides) offsets = lane_tile_offsets + lane_offsets_in_tile # [#tiles, #lanes] assert offsets.shape[-1] == WARP_SIZE swizzle_groups = (offsets // 
swizzle_group_elems) % (swizzle // 16) swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) # Order of threads within the wavefront is unimportant. wavefront_banks = np.sort(wavefront_banks, axis=-1) # There are no conflicts if each wavefront only contains unique banks. return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) # We don't need any special treatment if there are no conflicts when each lane # transfers the same tile at a time. if not has_bank_conflicts(lambda tile_idx: tile_idx): return TrivialTransferPlan() # Otherwise, we will try to partition the lanes into two groups and have # each group store to different tile. The only tile dimensions that can help # us with bank conflicts are those that have multiple elements and a stride # that's not a multiple of the number of banks. # # Note that the code is set up so that we could also consider partitioning # the lanes into more groups, but the selects will become more expensive if # we do that. It's a possibility we have if we need it. candidate_dims = ( i for i, (s, d) in enumerate(zip(tiled_strides, tiled_shape)) if d > 1 and s % (SMEM_BANKS * elems_per_bank) ) for dim in candidate_dims: for group_stride in (1, 2, 4, 8, 16): # We change the group assignment each group_stride lanes. lane_id = np.arange(WARP_SIZE)[:, None] lane_group = (lane_id // group_stride) % 2 # We only consider a transformation where the second group stores to a # tile that's a constant offset (modulo dim size) from the first one. for stagger in range(1, tiled_shape[dim]): offset = np.zeros(len(tiled_shape), np.int64) offset[dim] = stagger transform = lambda idx: (idx + offset * lane_group) % tiled_shape if not has_bank_conflicts(transform): # We've found a strategy that avoids bank conflicts! 
lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE)) group_idx = arith.remui(arith.divui(lane_idx, c(group_stride)), c(2)) group_pred = arith.cmpi(arith.CmpIPredicate.ne, group_idx, c(0)) return StaggeredTransferPlan( stagger, dim, tiled_shape[dim], group_pred ) raise ValueError( "Failed to synthesize a transfer pattern that avoids bank conflicts" ) # We allow contractions, to potentially take advantage of FMA instructions. # They can change the results, but the precision should only increase. def addf(a: ir.Value, b: ir.Value): return arith.addf(a, b, fastmath=arith.FastMathFlags.contract) def subf(a: ir.Value, b: ir.Value): return arith.subf(a, b, fastmath=arith.FastMathFlags.contract) def mulf(a: ir.Value, b: ir.Value): return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract) def optimization_barrier(*arrays: mgpu.FragmentedArray): """Acts as an optimization barrier for LLVM. Passing arrays through this function will make sure that they are computed before any side-effecting operations that follow this barrier. """ index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) regs = [] reg_dtypes = [] reg_constraints = [] repack_fns = [] # We unpack each array into a flat list of registers, and prepare the # functions that invert the transform in repack_fns. 
for array in arrays: reg_ty = array.registers.flat[0].type dtype = array.mlir_dtype if ir.F32Type.isinstance(dtype): if ir.VectorType.isinstance(reg_ty): [vec_len] = ir.VectorType(reg_ty).shape array_regs = [ # pylint: disable=g-complex-comprehension vector.extractelement(reg, position=c(pos, index)) for reg in array.registers.flat for pos in range(vec_len) ] def _repack(regs, reg_ty=reg_ty): reg = llvm.mlir_undef(reg_ty) [vec_len] = ir.VectorType(reg_ty).shape for i_elem in range(vec_len): reg = llvm.insertelement( reg, next(regs), arith.constant(i32, i_elem) ) return reg repack_fns.append(_repack) else: array_regs = list(array.registers.flat) repack_fns.append(lambda regs: next(regs)) reg_constraint = "f" elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): if not ir.VectorType.isinstance(reg_ty): raise NotImplementedError(array.mlir_dtype) [vec_len] = ir.VectorType(reg_ty).shape if vec_len != 2: raise NotImplementedError(vec_len) i32_reg_ty = ir.VectorType.get((1,), i32) array_regs = [ vector.extractelement( vector.bitcast(i32_reg_ty, reg), position=c(0, index) ) for reg in array.registers.flat ] reg_constraint = "r" def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) repack_fns.append(_repack) else: raise NotImplementedError(array.mlir_dtype) regs += array_regs reg_dtypes += [array_regs[0].type] * len(array_regs) reg_constraints += [reg_constraint] * len(array_regs) ptx_lines = [ f"mov.b32 ${i}, ${len(reg_constraints)+i}" for i in range(len(reg_constraints)) ] ptx = ";\n\t".join(ptx_lines) + ";" all_reg_constraints = ",".join( [*("=" + c for c in reg_constraints), *reg_constraints] ) struct_ty = ir.Type.parse( f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" ) result_struct = llvm.inline_asm( struct_ty, regs, ptx, all_reg_constraints, asm_dialect=0, has_side_effects=True, ) regs = [ llvm.extractvalue(dtype, result_struct, [i]) for i, dtype in enumerate(reg_dtypes) ] i32 = 
ir.IntegerType.get_signless(32) results = [] regs_it = iter(regs) for array, repack_fn in zip(arrays, repack_fns, strict=True): num_regs = array.registers.size reg_ty = array.registers.flat[0].type if ir.VectorType.isinstance(reg_ty): reg_ty = ir.VectorType(reg_ty) new_registers = np.empty((num_regs,), dtype=object) for i_vreg in range(num_regs): reg = repack_fn(regs_it) assert reg.type == reg_ty, (reg.type, reg_ty) new_registers[i_vreg] = reg results.append( FragmentedArray( _registers=new_registers.reshape(array.registers.shape), _layout=array.layout, _is_signed=array.is_signed, ) ) return results[0] if len(arrays) == 1 else results