mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic GPU] Introduce a more flexible layout system
So far all of our layouts have been tailored to a limited set of use cases we've tried so far, but they're still not general enough to handle all of the register layouts needed for WGMMA or mixed precision matmuls (incl. intermediate steps during conversions). Instead of adding more special cases, I decided to adopt XLA tiled layouts and they do seem to work quite well! This change only lays the groundwork for the new layout system. Future changes will build upon them to add new features and eventually replace `WGMMA_LAYOUT` altogether. PiperOrigin-RevId: 694105514
This commit is contained in:
parent
f8dba3c8a4
commit
de06584d98
@ -14,10 +14,12 @@
|
||||
# ==============================================================================
|
||||
"""Utilities for code generator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
from typing import Callable
|
||||
from typing import Sequence, TypeVar, Iterable
|
||||
|
||||
import jax
|
||||
from jaxlib.mlir import ir
|
||||
@ -35,10 +37,276 @@ from . import utils
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
T = TypeVar("T")
|
||||
WARPGROUP_SIZE = utils.WARPGROUP_SIZE
|
||||
WARP_SIZE = 32
|
||||
WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE
|
||||
c = utils.c
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Tiling:
|
||||
"""A tiling expression describing a permutation of elements of an nd-array.
|
||||
|
||||
To apply one level of tiling to an array, each of the trailing dimensions (up
|
||||
to the rank of the tile) is unfolded into two dimensions: first equal to the
|
||||
ratio of the dimension size and the tile size, and second equal to the tile
|
||||
size. Then, all newly unfolded minor dimensions are transposed to appear at
|
||||
the end.
|
||||
|
||||
This expression describes multi-level tiling, by applying each element of
|
||||
`tiles` in sequence to the array.
|
||||
|
||||
See https://openxla.org/xla/tiled_layout for a more detailed explanation.
|
||||
"""
|
||||
tiles: tuple[tuple[int, ...], ...]
|
||||
|
||||
def __post_init__(self):
|
||||
max_rank = math.inf
|
||||
for tile in self.tiles:
|
||||
if not tile:
|
||||
raise ValueError("Tiles must not be empty")
|
||||
if len(tile) > max_rank:
|
||||
raise ValueError("Tile ranks must be non-increasing")
|
||||
max_rank = len(tile)
|
||||
if any(d <= 0 for d in tile):
|
||||
raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}")
|
||||
|
||||
def __str__(self):
|
||||
return f"Tiling({''.join(map(str, self.tiles))})"
|
||||
|
||||
def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""Computes the shape of an array after tiling."""
|
||||
def fail():
|
||||
raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}")
|
||||
for tile in self.tiles:
|
||||
if len(tile) > len(shape):
|
||||
fail()
|
||||
untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):]
|
||||
if any(s % t != 0 for s, t in zip(tiled_dims, tile)):
|
||||
fail()
|
||||
shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile)
|
||||
return shape
|
||||
|
||||
def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""Computes the shape of an array before tiling from its tiled shape."""
|
||||
def fail():
|
||||
raise ValueError("Shape does not look like it's been tiled?")
|
||||
for tile in reversed(self.tiles):
|
||||
if len(tile) > len(shape):
|
||||
fail()
|
||||
untiled_dims = shape[:-2 * len(tile)]
|
||||
tiled_dims = shape[-2 * len(tile):-len(tile)]
|
||||
tiling_dims = shape[-len(tile):]
|
||||
if tiling_dims != tile:
|
||||
fail()
|
||||
shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile)))
|
||||
return shape
|
||||
|
||||
def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""Computes the strides of an array after tiling."""
|
||||
for tile in self.tiles:
|
||||
untiled, tiled = strides[:-len(tile)], strides[-len(tile):]
|
||||
strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
|
||||
return strides
|
||||
|
||||
|
||||
def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]:
|
||||
"""Like built-in enumerate, but returns negative indices into the sequence."""
|
||||
offset = len(elems)
|
||||
for i, e in enumerate(elems):
|
||||
yield i - offset, e
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TiledLayout:
|
||||
"""A FragmentedArray layout derived from a tiling expression.
|
||||
|
||||
A logical array is transformed according to the tiling expression, and then
|
||||
split across warps (within a warpgroup), lanes, and vectorized according to
|
||||
the dimension indices. All dimension indices must be negative and should refer
|
||||
to the dimensions after tiling is applied.
|
||||
|
||||
Note that warp_dim and vector_dim could be sets as well, but we don't have a
|
||||
usecase for that yet.
|
||||
|
||||
To better understand this layout, consider the example of WGMMA-related tiling
|
||||
from https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as
|
||||
applied to a 128x128 array. The corresponding TiledLayout has a tiling of:
|
||||
|
||||
(64, 8)(16, 8)(8, 8)(1, 2)
|
||||
|
||||
and warp_dim=-8, lane_dims={-4, -3}, vector_dim=-1.
|
||||
|
||||
We begin by applying the tiling (note that it always applies to a suffix):
|
||||
|
||||
Tiled shape Remaining tiling actions
|
||||
===========================================================================
|
||||
128 128 (64, 8)(16, 8)(8, 8)(1, 2)
|
||||
2 16 64 8 (16, 8)(8, 8)(1, 2)
|
||||
2 16 4 1 16 8 (8, 8)(1, 2)
|
||||
2 16 4 1 2 1 8 8 (1, 2)
|
||||
2 16 4 1 2 1 8 4 1 2
|
||||
|
||||
The last expression is our final shape. At this stage, we're ready to
|
||||
interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the
|
||||
end is partitioned over 4 warps in a warpgroup (and so it must be of size 4).
|
||||
lane_dims={-4, -3} indicate that those two dimensions are partitioned over
|
||||
the lanes within a warp (their product must be equal to 32, i.e. warp size).
|
||||
Finally, vector_dim=-1 indicates that each (logical) register is a vector
|
||||
containing 2 elements (there are no shape restrictions here).
|
||||
|
||||
Given the above, the shape of the (logical) register array used to represent
|
||||
the array in each thread is: (2, 16, 1, 1, 2, 1, 1, 1, 1, 1). We have set all
|
||||
the dimensions above to 1, since each thread is a member of a single warp,
|
||||
a single lane, and the elements along the vectorized dimension are represented
|
||||
by a single (logical) register.
|
||||
"""
|
||||
tiling: Tiling
|
||||
warp_dim: int
|
||||
lane_dims: frozenset[int]
|
||||
vector_dim: int
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.tiling.tiles:
|
||||
raise ValueError("Tiling must have at least one tile")
|
||||
min_shape = self.tiling.tiles[0]
|
||||
min_tiled_shape = self.tiling.tile_shape(min_shape)
|
||||
dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim}
|
||||
if len(dims_set) != len(self.lane_dims) + 2:
|
||||
raise ValueError
|
||||
for d in dims_set:
|
||||
if d >= 0:
|
||||
raise ValueError("All dimensions must be negative")
|
||||
if d < -(len(min_tiled_shape) - len(min_shape)):
|
||||
raise ValueError("Dimension out of range")
|
||||
if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP:
|
||||
raise ValueError
|
||||
if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE:
|
||||
raise ValueError
|
||||
|
||||
@functools.cached_property
|
||||
def tiled_tiling_shape(self) -> tuple[int, ...]:
|
||||
"""The shape of the suffix of the array after tiling.
|
||||
|
||||
We only allow our repeated tiling actions to further subdivide the
|
||||
dimensions created by previous tiling actions (except for the first one),
|
||||
so the tiled shape always ends with this suffix, no matter what array shape
|
||||
it's applied to.
|
||||
"""
|
||||
return self.tiling.tile_shape(self.tiling.tiles[0])
|
||||
|
||||
@property
|
||||
def vector_length(self) -> int:
|
||||
return self.tiled_tiling_shape[self.vector_dim]
|
||||
|
||||
def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""Returns the shape of the register array needed to represent an array of the given logical shape."""
|
||||
tiled_shape = list(self.tiling.tile_shape(shape))
|
||||
tiled_shape[self.warp_dim] = 1
|
||||
for d in self.lane_dims:
|
||||
tiled_shape[d] = 1
|
||||
tiled_shape[self.vector_dim] = 1
|
||||
return tuple(tiled_shape)
|
||||
|
||||
def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""Returns the logical shape of an array given its register array shape.
|
||||
|
||||
Inverse to `registers_shape`.
|
||||
"""
|
||||
tiled_tiling = self.tiled_tiling_shape
|
||||
shape = list(shape)
|
||||
shape[self.warp_dim] = WARPS_IN_WARPGROUP
|
||||
for d in self.lane_dims:
|
||||
shape[d] = tiled_tiling[d]
|
||||
shape[self.vector_dim] = tiled_tiling[self.vector_dim]
|
||||
return self.tiling.untile_shape(tuple(shape))
|
||||
|
||||
def lane_indices(self) -> tuple[ir.Value, ...]:
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
tiled_shape = tuple(
|
||||
d if i in self.lane_dims else 1
|
||||
for i, d in enumerate_negative(self.tiled_tiling_shape)
|
||||
)
|
||||
assert math.prod(tiled_shape) == WARP_SIZE
|
||||
lane_strides = utils.get_contiguous_strides(tiled_shape)
|
||||
lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32))
|
||||
return tuple(
|
||||
arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32))
|
||||
for stride, size in zip(lane_strides, tiled_shape)
|
||||
)
|
||||
|
||||
def warp_indices(self) -> tuple[ir.Value, ...]:
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
tiled_shape = tuple(
|
||||
d if i == self.warp_dim else 1
|
||||
for i, d in enumerate_negative(self.tiled_tiling_shape)
|
||||
)
|
||||
assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP
|
||||
warp_idx = arith.remui(
|
||||
arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)),
|
||||
c(WARPS_IN_WARPGROUP, i32),
|
||||
)
|
||||
indices = [arith.constant(i32, 0)] * len(tiled_shape)
|
||||
indices[self.warp_dim] = warp_idx
|
||||
return tuple(indices)
|
||||
|
||||
|
||||
def _tiled_wgmma_layout(shape: tuple[int, ...]):
|
||||
"""Returns the tiled layout relevant for WGMMA operations.
|
||||
|
||||
The tiled layout is equivalent to one described here in PTX documentation:
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d
|
||||
|
||||
This tiled layout is equivalent to WGMMAFragLayout and will subsume it.
|
||||
"""
|
||||
if len(shape) != 2:
|
||||
raise ValueError(f"Shape {shape} is not 2D")
|
||||
if shape[0] % 64 != 0 or shape[1] % 8 != 0:
|
||||
raise ValueError(f"Shape {shape} is not a multiple of 64x8")
|
||||
return TiledLayout(
|
||||
Tiling(((64, 8), (16, 8), (8, 8), (1, 2))),
|
||||
warp_dim=-8,
|
||||
lane_dims=frozenset((-4, -3)),
|
||||
vector_dim=-1,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGMMAFragLayout:
|
||||
"""[m, n] matrix, where m % 64 == 0 == n % 8."""
|
||||
|
||||
def thread_idxs(self, shape):
|
||||
index = ir.IndexType.get()
|
||||
assert shape[0] % 64 == 0 and shape[1] % 8 == 0
|
||||
tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
|
||||
tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
|
||||
warp_idx = arith.divui(tid_wg, c(32, index))
|
||||
tid_warp = arith.remui(tid_wg, c(32, index))
|
||||
col_base = arith.muli(arith.remui(tid_warp, c(4, index)), c(2, index))
|
||||
row_base = arith.addi(
|
||||
arith.divui(tid_warp, c(4, index)), arith.muli(warp_idx, c(16, index))
|
||||
)
|
||||
for row_group in range(0, shape[0], 64):
|
||||
for col_group in range(0, shape[1], 8):
|
||||
for row_subgroup in range(0, 16, 8):
|
||||
row = arith.addi(row_base, c(row_group + row_subgroup, index))
|
||||
yield row, arith.addi(col_base, c(col_group, index))
|
||||
|
||||
def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
|
||||
assert len(shape) == 2
|
||||
assert shape[0] % 64 == 0 and shape[1] % 8 == 0
|
||||
return (shape[0] // 64, shape[1] // 8, 2, 1)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGMMARowFragLayout:
|
||||
"""[m] matrix, where m % 64 == 0."""
|
||||
|
||||
def thread_idxs(self, shape):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGSplatFragLayout:
|
||||
"""A fragmented array where all the values are equal represented as a register per thread.
|
||||
@ -75,36 +343,6 @@ class WGSplatFragLayout:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGMMAFragLayout:
|
||||
"""[m, n] matrix, where m % 64 == 0 == n % 8."""
|
||||
|
||||
def thread_idxs(self, shape):
|
||||
index = ir.IndexType.get()
|
||||
assert shape[0] % 64 == 0 and shape[1] % 8 == 0
|
||||
tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
|
||||
tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
|
||||
warp_idx = arith.divui(tid_wg, c(32, index))
|
||||
tid_warp = arith.remui(tid_wg, c(32, index))
|
||||
col_base = arith.muli(arith.remui(tid_warp, c(4, index)), c(2, index))
|
||||
row_base = arith.addi(
|
||||
arith.divui(tid_warp, c(4, index)), arith.muli(warp_idx, c(16, index))
|
||||
)
|
||||
for row_group in range(0, shape[0], 64):
|
||||
for col_group in range(0, shape[1], 8):
|
||||
for row_subgroup in range(0, 16, 8):
|
||||
row = arith.addi(row_base, c(row_group + row_subgroup, index))
|
||||
yield row, arith.addi(col_base, c(col_group, index))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGMMARowFragLayout:
|
||||
"""[m] matrix, where m % 64 == 0."""
|
||||
|
||||
def thread_idxs(self, shape):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WGStridedFragLayout:
|
||||
"""Convert the array to 1D and then shard across threads."""
|
||||
@ -162,7 +400,7 @@ class WGStridedFragLayout:
|
||||
yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))
|
||||
|
||||
|
||||
FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
|
||||
FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout | TiledLayout
|
||||
|
||||
|
||||
WGMMA_LAYOUT = WGMMAFragLayout()
|
||||
@ -230,6 +468,14 @@ class FragmentedArray:
|
||||
if _registers.size != 1:
|
||||
raise ValueError(f"Invalid register array shape: {_registers.shape}")
|
||||
|
||||
case TiledLayout():
|
||||
try:
|
||||
self.layout.shape_from_registers_shape(_registers.shape)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Register array shape does not match the tiled layout"
|
||||
) from None
|
||||
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -304,15 +550,21 @@ class FragmentedArray:
|
||||
return shape
|
||||
case WGSplatFragLayout(shape=shape):
|
||||
return shape
|
||||
case TiledLayout():
|
||||
return self.layout.shape_from_registers_shape(self.registers.shape)
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def mlir_dtype(self):
|
||||
reg_ty = self.registers.flat[0].type
|
||||
match self.layout:
|
||||
case WGMMAFragLayout() | WGStridedFragLayout():
|
||||
case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout():
|
||||
return ir.VectorType(reg_ty).element_type
|
||||
case WGMMARowFragLayout() | WGSplatFragLayout():
|
||||
return reg_ty
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_layout(self, new_layout: FragmentedLayout):
|
||||
"""Converts the fragmented array to the given layout.
|
||||
@ -321,6 +573,17 @@ class FragmentedArray:
|
||||
"""
|
||||
if self.layout == new_layout:
|
||||
return self
|
||||
shape = self.shape
|
||||
if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0:
|
||||
tiled_layout = _tiled_wgmma_layout(shape)
|
||||
if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or (
|
||||
self.layout == tiled_layout and new_layout == WGMMA_LAYOUT
|
||||
):
|
||||
return FragmentedArray(
|
||||
_registers=self.registers.reshape(new_layout.registers_shape(shape)),
|
||||
_layout=new_layout,
|
||||
_is_signed=self.is_signed,
|
||||
)
|
||||
if not isinstance(self.layout, WGSplatFragLayout):
|
||||
raise NotImplementedError(
|
||||
f"Cannot convert from {self.layout} to {new_layout}"
|
||||
@ -745,10 +1008,9 @@ class FragmentedArray:
|
||||
raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")
|
||||
new_registers = np.empty_like(self.registers)
|
||||
match self.layout:
|
||||
case WGMMAFragLayout():
|
||||
new_reg_ty = ir.VectorType.get((2,), new_dtype)
|
||||
case WGStridedFragLayout(vec_size=vec_size):
|
||||
new_reg_ty = ir.VectorType.get((vec_size,), new_dtype)
|
||||
case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout():
|
||||
shape = ir.VectorType(self.registers.flat[0].type).shape
|
||||
new_reg_ty = ir.VectorType.get(shape, new_dtype)
|
||||
case WGMMARowFragLayout() | WGSplatFragLayout():
|
||||
new_reg_ty = new_dtype
|
||||
case _:
|
||||
@ -916,6 +1178,8 @@ class FragmentedArray:
|
||||
self._store_untiled_splat(ref)
|
||||
case WGStridedFragLayout():
|
||||
self._store_untiled_wg_strided(ref)
|
||||
case TiledLayout():
|
||||
self._store_untiled_tiled(ref)
|
||||
case _:
|
||||
raise NotImplementedError(self.layout)
|
||||
|
||||
@ -982,6 +1246,32 @@ class FragmentedArray:
|
||||
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
|
||||
memref.store(value, ref, [row, col])
|
||||
|
||||
def _store_untiled_tiled(self, ref: ir.Value):
|
||||
"""Stores an array with a tiled layout. Not optimized at the moment."""
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
layout = self.layout
|
||||
assert isinstance(layout, TiledLayout)
|
||||
ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
|
||||
if ref_strides[layout.vector_dim] != 1:
|
||||
raise NotImplementedError(
|
||||
"Can't use vector stores with non-unit minormost stride"
|
||||
)
|
||||
strides = layout.tiling.tile_strides(ref_strides)
|
||||
ptr = utils.memref_ptr(ref)
|
||||
# Fold warp and lane offsets into the pointer once, since they are dynamic.
|
||||
dyn_strides = [arith.constant(i32, s) for s in strides]
|
||||
def dyn_dot(x, y):
|
||||
return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
|
||||
warp_offset = dyn_dot(layout.warp_indices(), dyn_strides)
|
||||
lane_offset = dyn_dot(layout.lane_indices(), dyn_strides)
|
||||
dyn_offset = arith.addi(warp_offset, lane_offset)
|
||||
ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
|
||||
# All warp tile offsets are static and can be fused into the store.
|
||||
for tile_idx, reg in np.ndenumerate(self.registers):
|
||||
lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True))
|
||||
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
|
||||
llvm.store(reg, reg_ptr)
|
||||
|
||||
def store_tiled(self, ref, swizzle: int | None):
|
||||
if self.layout != WGMMA_LAYOUT:
|
||||
raise NotImplementedError
|
||||
|
@ -40,6 +40,7 @@ import numpy as np
|
||||
|
||||
WARPGROUP_SIZE: int = 128
|
||||
DYNAMIC = -9223372036854775808
|
||||
DYNAMIC32 = -2147483648
|
||||
|
||||
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
|
||||
|
||||
@ -1036,3 +1037,11 @@ def is_signed(dtype: jax.typing.DTypeLike) -> bool | None:
|
||||
elif jnp.issubdtype(dtype, jnp.integer):
|
||||
return jnp.issubdtype(dtype, jnp.signedinteger)
|
||||
return None
|
||||
|
||||
|
||||
def getelementptr(
|
||||
ptr: ir.Value, indices: Sequence[ir.Value | int], dtype: ir.Type
|
||||
) -> ir.Value:
|
||||
static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices]
|
||||
dyn_indices = [i for i in indices if not isinstance(i, int)]
|
||||
return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype)
|
||||
|
@ -29,6 +29,7 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
from jax.experimental.mosaic.gpu import fragmented_array as fa
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
try:
|
||||
@ -1583,5 +1584,29 @@ class TorchTest(TestCase):
|
||||
del y # Make sure the destructor runs successfully.
|
||||
|
||||
|
||||
class LayoutTest(TestCase):
|
||||
|
||||
@parameterized.product(
|
||||
shape=((128, 128), (64, 8), (64, 256)),
|
||||
dtype=(jnp.int32, jnp.int16, jnp.int8),
|
||||
)
|
||||
def test_wgmma_tiled_layout(self, shape, dtype):
|
||||
def kernel(ctx, dst, _):
|
||||
iota = iota_tensor(*shape, dtype)
|
||||
tiled = iota.to_layout(fa._tiled_wgmma_layout(shape))
|
||||
# Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1)
|
||||
self.assertEqual(
|
||||
tiled.registers.shape,
|
||||
(shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1),
|
||||
)
|
||||
self.assertEqual(tiled.shape, shape)
|
||||
self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype)
|
||||
tiled.store_untiled(dst)
|
||||
ty = jax.ShapeDtypeStruct(shape, dtype)
|
||||
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ())
|
||||
expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
|
||||
np.testing.assert_array_equal(f(), expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user