# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for code generator."""
import dataclasses

import jax
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import math as mlir_math
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
import numpy as np

from . import utils
from . import dsl as mgpu
# mypy: ignore-errors

WARPGROUP_SIZE = utils.WARPGROUP_SIZE
c = utils.c


@dataclasses.dataclass(frozen=True)
class WGMMAFragLayout:
"""[m, n] matrix, where m % 64 == 0 == n % 8."""


@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
"""[m] matrix, where m % 64 == 0."""
@dataclasses.dataclass(frozen=True)
class WGStridedFragLayout:
"""Convert the array to 1D and then shard across threads."""
shape: tuple[int, ...]
vec_size: int
def __post_init__(self):
    if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0:
      raise ValueError(
          f"Shape {self.shape} cannot be evenly sharded into vectors of"
          f" size {self.vec_size} across {WARPGROUP_SIZE} threads"
      )
@classmethod
def from_memref_type(cls, memref_ty: ir.Type):
if not ir.MemRefType.isinstance(memref_ty):
raise TypeError(memref_ty)
memref_type = ir.MemRefType(memref_ty)
bw = mgpu.bytewidth(memref_type.element_type)
assert 8 % bw == 0 and 8 // bw != 0, bw
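    # Each thread then moves 8 bytes per vector transfer: e.g. f32
    # (bytewidth 4) gives vec_size = 2 and f16 (bytewidth 2) gives
    # vec_size = 4.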
return cls(shape=memref_type.shape, vec_size=8 // bw)
def thread_vec_idxs(self):
"""The indexes to be used for vector load/store WGStridedFragLayout.
Yields:
The indices of the vector that correspond to the current thread.
"""
cardinality = np.prod(self.shape)
assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0
reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size)
tidx = gpu.thread_id(gpu.Dimension.x)
off = arith.muli(tidx, c(self.vec_size, tidx.type))
for i in range(reg_num):
yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))]


FragmentedLayout = WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
WGMMA_LAYOUT = WGMMAFragLayout()
WGMMA_ROW_LAYOUT = WGMMARowFragLayout()


@jax.tree_util.register_pytree_node_class
class FragmentedArray:
registers: np.ndarray # of ir.Value, see checks in init for shapes.
layout: FragmentedLayout
def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout):
self.registers = _registers
self.layout = _layout
match self.layout:
# Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout
# Each element is a vector<2xdtype>
case WGMMAFragLayout():
if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1):
raise ValueError("Invalid register array shape")
# Registers are [m_tiles, 2 rows] in WGMMA_ROW layout
# Each element is a dtype scalar
case WGMMARowFragLayout():
if self.registers.ndim != 2 or self.registers.shape[-1] != 2:
raise ValueError("Invalid register array shape")
# Registers are flat
case WGStridedFragLayout(shape):
(reg_size,) = ir.VectorType(_registers.flat[0].type).shape
        if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size:
          raise ValueError(
              f"Shape {shape} is inconsistent with a register array of shape"
              f" {_registers.shape} holding vectors of {reg_size} elements"
              f" across {WARPGROUP_SIZE} threads"
          )
case _:
raise NotImplementedError
@classmethod
def load_strided(cls, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):
raise TypeError(ref.type)
ref_ty = ir.MemRefType(ref.type)
ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
layout = WGStridedFragLayout.from_memref_type(ref_ty)
vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()]
return cls(_registers=np.array(vecs), _layout=layout)
@classmethod
def splat(cls, value, shape, layout):
match layout:
      case WGMMARowFragLayout():
        if len(shape) != 1:
          raise ValueError(f"Expected a 1D shape, got {shape}")
        if shape[0] % 64:
          raise ValueError(f"Rows must be a multiple of 64, got {shape[0]}")
        reg_shape = (shape[0] // 64, 2)
      case WGMMAFragLayout():
        if len(shape) != 2:
          raise ValueError(f"Expected a 2D shape, got {shape}")
        if shape[0] % 64 or shape[1] % 8:
          raise ValueError(f"Shape must be a multiple of (64, 8), got {shape}")
        reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1)
value = vector.splat(ir.VectorType.get((2,), value.type), value)
case WGStridedFragLayout(shape=shape, vec_size=vec_size):
elems = np.prod(shape)
reg_shape = (elems // (WARPGROUP_SIZE * vec_size),)
value = vector.splat(ir.VectorType.get((vec_size,), value.type), value)
case _:
raise NotImplementedError(layout)
return cls(
_registers=np.full(reg_shape, value, dtype=object),
_layout=layout,
)
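  # For instance (a hypothetical zero-initialized f32 accumulator):
  #   f32 = ir.F32Type.get()
  #   acc = FragmentedArray.splat(c(0.0, f32), (64, 64), WGMMA_LAYOUT)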
@property
def shape(self):
row_tiles = self.registers.shape[0]
match self.layout:
case WGMMAFragLayout():
col_tiles = self.registers.shape[1]
return (row_tiles * 64, col_tiles * 8)
case WGMMARowFragLayout():
return (row_tiles * 64,)
case WGStridedFragLayout(shape):
return shape
@property
def mlir_dtype(self):
reg_ty = self.registers.flat[0].type
match self.layout:
case WGMMAFragLayout() | WGStridedFragLayout():
return ir.VectorType(reg_ty).element_type
case WGMMARowFragLayout():
return reg_ty
def _pointwise(self, op, *other):
for o in other:
if not isinstance(o, FragmentedArray):
return NotImplemented
if self.layout != o.layout:
raise ValueError("Incompatible FragmentedArray layouts")
if self.registers.shape != o.registers.shape:
raise ValueError("Incompatible FragmentedArray shapes")
new_regs = np.empty_like(self.registers)
for idx, reg in np.ndenumerate(self.registers):
new_regs[idx] = op(reg, *(o.registers[idx] for o in other))
return FragmentedArray(_registers=new_regs, _layout=self.layout)
def __add__(self, other):
return self._pointwise(arith.addf, other)
def __mul__(self, other):
return self._pointwise(arith.mulf, other)
def __sub__(self, other):
return self._pointwise(arith.subf, other)
def __truediv__(self, other):
return self._pointwise(arith.divf, other)
def max(self, other):
return self._pointwise(arith.maximumf, other)
def exp(self, approx: bool = False):
def fast_exp(x):
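      # exp(x) == 2**(x * log2(e)), and the PTX ex2.approx instruction
      # evaluates 2**x directly in hardware.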
f32 = ir.F32Type.get()
log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
if x.type == f32:
scaled = arith.mulf(x, log2e)
return llvm.inline_asm(
f32, [scaled], "ex2.approx.f32 $0,$1;", "=f,f", asm_dialect=0
)
      elif ir.VectorType.isinstance(x.type):
        index = ir.IndexType.get()
        result = llvm.mlir_undef(x.type)
        (vec_len,) = ir.VectorType(x.type).shape
        for i in range(vec_len):
          v = vector.extractelement(x, position=c(i, index))
          vr = fast_exp(v)
          result = vector.insertelement(vr, result, position=c(i, index))
        return result
else:
raise NotImplementedError(x.type)
return self._pointwise(fast_exp if approx else mlir_math.exp)
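  # Note: approx=True lowers to the ex2.approx PTX instruction via inline
  # assembly, trading a small amount of accuracy for speed; it is only
  # implemented for f32 elements here.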
def __and__(self, other):
if not ir.IntegerType.isinstance(self.mlir_dtype):
raise ValueError(
"Bitwise operations only defined for integer types, not"
f" {self.mlir_dtype}"
)
return self._pointwise(arith.andi, other)
def bitcast(self, elt: ir.Type):
reg_type = self.registers.flat[0].type
if ir.VectorType.isinstance(reg_type):
reg_shape = ir.VectorType(reg_type).shape
ty = ir.VectorType.get(reg_shape, elt)
else:
ty = elt
return self._pointwise(lambda x: arith.bitcast(ty, x))
def __getitem__(self, idx):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError("Only WGMMA layouts support slicing")
base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
if any(is_squeezed):
raise NotImplementedError("Only slicing implemented")
if (
base_idx[0] % 64
or slice_shape[0] % 64
or base_idx[1] % 8
or slice_shape[1] % 8
):
raise NotImplementedError("Only tile aligned slicing supported")
base_idx[0] //= 64
slice_shape[0] //= 64
base_idx[1] //= 8
slice_shape[1] //= 8
new_regs = self.registers[
base_idx[0] : base_idx[0] + slice_shape[0],
base_idx[1] : base_idx[1] + slice_shape[1],
]
return FragmentedArray(_registers=new_regs, _layout=self.layout)
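  # For example, with a (128, 64)-shaped array, arr[64:128, 0:32] selects the
  # second row tile and the first four column tiles; this only re-slices the
  # register array, no data movement is emitted.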
# TODO(apaszke): Support JAX dtypes here as well?
def astype(self, new_dtype: ir.Type):
cur_dtype = self.mlir_dtype
if cur_dtype == new_dtype:
return self
from_float = ir.FloatType.isinstance(cur_dtype)
to_float = ir.FloatType.isinstance(new_dtype)
from_integer = ir.IntegerType.isinstance(cur_dtype)
to_integer = ir.IntegerType.isinstance(new_dtype)
if from_float and to_float:
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
convert = arith.truncf
else:
convert = arith.extf
elif from_integer and to_integer:
if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
convert = arith.trunci
else:
convert = arith.extsi
    elif from_integer and to_float:
      convert = arith.sitofp
    elif from_float and to_integer:
      convert = arith.fptosi
    else:
      raise NotImplementedError(
          f"Unsupported conversion from {cur_dtype} to {new_dtype}"
      )
new_registers = np.empty_like(self.registers)
match self.layout:
case WGMMAFragLayout():
new_reg_ty = ir.VectorType.get((2,), new_dtype)
case WGStridedFragLayout(vec_size=vec_size):
new_reg_ty = ir.VectorType.get((vec_size,), new_dtype)
case WGMMARowFragLayout():
new_reg_ty = new_dtype
case _:
raise NotImplementedError(f"Unsupported layout {self.layout}")
for idx, reg in np.ndenumerate(self.registers):
new_registers[idx] = convert(new_reg_ty, reg)
return FragmentedArray(_registers=new_registers, _layout=self.layout)
def reduce(self, op, axis):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError(self.layout)
if axis != 1:
raise NotImplementedError
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
new_regs = np.empty(self.registers.shape[::2], dtype=object)
assert self.registers.shape[-1] == 1
for row_tile, row_subtile in np.ndindex(new_regs.shape):
# Reduce the registers owned by the current thread over n tiles
thread_result_vec = self.registers[row_tile, 0, row_subtile, 0]
for n_tile in range(1, self.registers.shape[1]):
thread_result_vec = op(
thread_result_vec, self.registers[row_tile, n_tile, row_subtile, 0]
)
thread_result = op(
vector.extractelement(thread_result_vec, position=c(0, index)),
vector.extractelement(thread_result_vec, position=c(1, index)),
)
# Do a shuffle to reduce in groups of 4 consecutive threads.
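      # Each row is co-owned by 4 consecutive lanes (each holding 2 of its
      # columns per tile), so two butterfly shuffles with XOR masks 1 and 2
      # combine the four partial results across the quad.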
result = thread_result
for i in (1, 2):
other_result = nvvm.shfl_sync(
result.type,
c(0xFFFFFFFF, i32),
result,
c(i, i32),
c(0x1F, i32),
nvvm.ShflKind.bfly,
)
result = op(result, other_result)
new_regs[row_tile, row_subtile] = result
return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT)
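  # Example (hypothetical): acc.reduce(arith.addf, axis=1) yields the row
  # sums of a WGMMA-layout array as a WGMMA_ROW_LAYOUT array.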
def broadcast_minor(self, n):
if self.layout != WGMMA_ROW_LAYOUT:
raise NotImplementedError
num_row_tiles = self.registers.shape[0]
num_col_tiles, rem = divmod(n, 8)
if rem:
raise ValueError("Number of columns must be divisible by 8")
new_regs = np.empty((num_row_tiles, num_col_tiles, 2, 1), dtype=object)
dtype = self.mlir_dtype
for (row_tile, row_subtile), reg in np.ndenumerate(self.registers):
new_regs[row_tile, :, row_subtile, :] = vector.splat(
ir.VectorType.get((2,), dtype), reg
)
return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT)
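  # Together with reduce, this supports row-wise patterns such as the
  # (hypothetical) row-max normalization used in softmax:
  #   x - x.reduce(arith.maximumf, axis=1).broadcast_minor(n)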
def store_untiled(self, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):
raise ValueError(ref)
match self.layout:
case WGMMAFragLayout():
self._store_untiled_wgmma(ref)
case WGStridedFragLayout():
self._store_untiled_wg_strided(ref)
case _:
raise NotImplementedError(self.layout)
def _store_untiled_wg_strided(self, ref: ir.Value):
ref_ty = ir.MemRefType(ref.type)
if ref_ty.shape != self.shape:
raise ValueError((ref_ty.shape, self.shape))
smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
assert isinstance(self.layout, WGStridedFragLayout)
for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat):
vector.store(reg, smem_1d, idx)
def _store_untiled_wgmma(self, ref: ir.Value):
"""Stores accumulator to a 2D memref. Not optimized at the moment."""
assert self.layout == WGMMA_LAYOUT
index = ir.IndexType.get()
m, n = self.shape # pytype: disable=bad-unpacking
ref_ty = ir.MemRefType(ref.type)
if ref_ty.shape != [m, n]:
raise ValueError(ref.type, (m, n))
def c(x):
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
tidx = gpu.thread_id(gpu.Dimension.x)
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
row_base = arith.addi(
arith.divui(lane_id, c(4)), arith.muli(warp_id, c(16))
)
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
it = np.ndenumerate(self.registers)
for (row_tile, col_tile, row_idx, col_zero), elem in it:
del col_zero
row = arith.addi(row_base, c(row_tile * 64 + row_idx * 8))
for col_idx in range(2):
value = vector.extractelement(elem, position=c(col_idx))
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
memref.store(value, ref, [row, col])
def store_tiled(self, ref, swizzle: int | None):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError
bw = mgpu.bytewidth(self.mlir_dtype)
m, n = self.shape # pytype: disable=bad-unpacking
assert m % 64 == 0 # This is implied by the layout.
if n % 32 != 0:
raise NotImplementedError
cols_per_tile = 128 // bw
expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
if ir.MemRefType(ref.type).shape != expected_shape:
raise ValueError(ref.type, (m, n))
if swizzle != 128:
raise NotImplementedError("Only 128B swizzle supported")
index = ir.IndexType.get()
def c(x):
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
tidx = gpu.thread_id(gpu.Dimension.x)
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7}
if bw > 2: # Stagger is only necessary for values larger than 16bit.
is_even_row = arith.cmpi(
arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0)
)
else:
# We rely on canonicalization to clean up the selects.
i1 = ir.IntegerType.get_signless(1)
is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1))
row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
# The swizzle pattern is constant for a given thread.
col_swizzle_bits = arith.muli(sub_row_base, c(16 // bw))
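    # For example, with bw=2 this XORs the column index with
    # sub_row_base * 8, so consecutive sub-rows place their 16-byte chunks
    # in different banks within the 128-byte swizzle period.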
for row_group in range(m // 64):
for col_group in range(n // cols_per_tile):
for row_subidx in range(2):
row = arith.addi(row_base, c(row_subidx * 8))
for col_subidx in range(cols_per_tile // 8):
# We stagger the even and odd rows a little to avoid bank conflicts.
# It seems that the STS.64 is 2x faster (and the hardware reports no
# conflicts) when the conflicts are split between half-warps, as
# opposed to having them within the half-warp. This requires a
# little more work for the selects, but is ultimately worth it.
col_subidx_even = col_subidx
col_subidx_odd = col_subidx ^ 2
col_off = arith.select(
is_even_row, c(col_subidx_even * 8), c(col_subidx_odd * 8)
)
col = arith.addi(col_base, col_off)
col = arith.xori(col, col_swizzle_bits)
reg_idx_even = col_subidx_even + col_group * (cols_per_tile // 8)
reg_idx_odd = col_subidx_odd + col_group * (cols_per_tile // 8)
value_even = self.registers[row_group, reg_idx_even, row_subidx, 0]
value_odd = self.registers[row_group, reg_idx_odd, row_subidx, 0]
value = arith.select(is_even_row, value_even, value_odd)
vector.store(value, ref, [c(row_group), c(col_group), row, col])
def tree_flatten(self):
return list(self.registers.flat), (self.layout, self.registers.shape)
@classmethod
def tree_unflatten(cls, aux, flat_registers):
layout, reg_shape = aux
registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape)
return cls(_registers=registers, _layout=layout)
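

# Illustrative usage sketch (x_ref/y_ref are hypothetical SMEM memrefs; the
# calls must run under an active MLIR insertion point inside a GPU kernel):
#   a = FragmentedArray.load_strided(x_ref)        # WGStridedFragLayout
#   (a * a).exp(approx=True).store_untiled(y_ref)  # elementwise exp(x**2)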