[Mosaic GPU] Add a new tiled layout, optimized for upcasting before WGMMA

PiperOrigin-RevId: 707860467
Adam Paszke authored 2024-12-19 04:02:44 -08:00 (committed by jax authors)
parent 66ad2082ba
commit 006c65d8d4
2 changed files with 59 additions and 18 deletions

jax/experimental/mosaic/gpu/fragmented_array.py

@@ -65,15 +65,17 @@ class Tiling:
   tiles: tuple[tuple[int, ...], ...]
   def __post_init__(self):
+    max_rank = math.inf
     if not self.tiles:
       return
-    tiled_rank = len(self.tiles[0])
     for tile in self.tiles:
-      if len(tile) > tiled_rank:
-        raise ValueError("Only the first tile can refer to value dimensions")
+      if not tile:
+        raise ValueError("Tiles must not be empty")
+      if len(tile) > max_rank:
+        raise ValueError("Tile ranks must be non-increasing")
+      max_rank = len(tile)
       if any(d <= 0 for d in tile):
         raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}")
-      tiled_rank += len(tile)
   def __str__(self):
     return f"Tiling({''.join(map(str, self.tiles))})"
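For intuition, here is the new validation in isolation: a standalone sketch, with a hypothetical check_tiles helper (not part of this change) that mirrors the __post_init__ logic above. Tile ranks may shrink, since a lower-rank tile simply applies to a suffix of the dimensions, but they may not grow again.

import math

def check_tiles(tiles):
  # Mirrors Tiling.__post_init__ after this commit.
  max_rank = math.inf
  for tile in tiles:
    if not tile:
      raise ValueError("Tiles must not be empty")
    if len(tile) > max_rank:
      raise ValueError("Tile ranks must be non-increasing")
    max_rank = len(tile)
    if any(d <= 0 for d in tile):
      raise ValueError(f"Tile shape must only have positive sizes, got: {tiles}")

check_tiles(((64, 8), (16, 8), (8, 8), (1, 2)))  # ok: ranks 2, 2, 2, 2
check_tiles(((64, 8), (8,)))                     # ok: rank may shrink
# check_tiles(((64, 8), (4, 2, 1)))              # would raise: rank grows from 2 to 3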
@@ -156,7 +158,7 @@ class TiledLayout:
      (64, 8)(16, 8)(8, 8)(1, 2)
 
-  and warp_dim=-8, lane_dims={-4, -3}, vector_dim=-1.
+  and warp_dim=-8, lane_dims=(-4, -3), vector_dim=-1.
 
   We begin by applying the tiling (note that it always applies to a suffix):
@@ -171,7 +173,7 @@ class TiledLayout:
   The last expression is our final shape. At this stage, we're ready to
   interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the
   end is partitioned over 4 warps in a warpgroup (and so it must be of size 4).
-  lane_dims={-4, -3} indicate that those two dimensions are partitioned over
+  lane_dims=(-4, -3) indicate that those two dimensions are partitioned over
   the lanes within a warp (their product must be equal to 32, i.e. warp size).
   Finally, vector_dim=-1 indicates that each (logical) register is a vector
   containing 2 elements (there are no shape restrictions here).
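To make the docstring walk-through concrete, here is a standalone sketch that applies the example tiling (64, 8)(16, 8)(8, 8)(1, 2) to a single (64, 8) tile and checks the dimensions named above. The tile_suffix helper is hypothetical, written to match the unfold-then-append-tile-dims rule the docstring describes.

def tile_suffix(shape, tile):
  # One tiling level: unfold each of the len(tile) trailing dimensions d
  # into (d // t, t), collecting all the tile dimensions at the end.
  n = len(tile)
  outer = tuple(d // t for d, t in zip(shape[-n:], tile))
  return shape[:-n] + outer + tile

shape = (64, 8)
for tile in ((64, 8), (16, 8), (8, 8), (1, 2)):
  shape = tile_suffix(shape, tile)

print(shape)  # (1, 1, 4, 1, 2, 1, 8, 4, 1, 2)
assert shape[-8] == 4               # warp_dim=-8: 4 warps in a warpgroup
assert shape[-4] * shape[-3] == 32  # lane_dims=(-4, -3): 8 * 4 = 32 lanes
assert shape[-1] == 2               # vector_dim=-1: 2 elements per register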
@@ -184,7 +186,7 @@ class TiledLayout:
   """
   tiling: Tiling
   warp_dim: int
-  lane_dims: frozenset[int]
+  lane_dims: tuple[int, ...]  # major-to-minor
   vector_dim: int
 
   def __post_init__(self):
@@ -253,19 +255,19 @@ class TiledLayout:
   def lane_indices(self) -> tuple[ir.Value, ...]:
     i32 = ir.IntegerType.get_signless(32)
-    tiled_shape = tuple(
-        d if i in self.lane_dims else 1
-        for i, d in enumerate_negative(self.tiled_tiling_shape)
-    )
-    assert math.prod(tiled_shape) == WARP_SIZE
-    lane_strides = utils.get_contiguous_strides(tiled_shape)
+    tiled_shape = self.tiled_tiling_shape
+    lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims)
+    assert math.prod(lanes_shape) == WARP_SIZE
+    lane_strides = utils.get_contiguous_strides(lanes_shape)
     lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32))
-    # TODO(apaszke): Rewrite so that we can be sure that this never actually
-    # does arithmetic for any dimensions that are not in lane_dims.
-    return tuple(
+    lane_indices = tuple(
         arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32))
-        for stride, size in zip(lane_strides, tiled_shape)
+        for stride, size in zip(lane_strides, lanes_shape)
     )
+    full_indices = [arith.constant(i32, 0)] * len(tiled_shape)
+    for d, i in zip(self.lane_dims, lane_indices):
+      full_indices[d] = i
+    return tuple(full_indices)
 
   def warp_indices(self) -> tuple[ir.Value, ...]:
     i32 = ir.IntegerType.get_signless(32)
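Stripped of the MLIR plumbing, the rewritten lane_indices is plain index arithmetic: delinearize the lane id over just the lane dimensions, in the major-to-minor order now guaranteed by the tuple, then scatter the per-dimension indices into an all-zeros index vector for the full tiled shape. A standalone sketch with ints in place of ir.Values (helper names hypothetical):

import math

WARP_SIZE = 32

def contiguous_strides(shape):
  # Row-major strides, e.g. (8, 4) -> (4, 1).
  strides, acc = [], 1
  for d in reversed(shape):
    strides.append(acc)
    acc *= d
  return tuple(reversed(strides))

def lane_indices(lane_idx, tiled_shape, lane_dims):
  lanes_shape = tuple(tiled_shape[d] for d in lane_dims)
  assert math.prod(lanes_shape) == WARP_SIZE
  per_dim = [lane_idx // s % d
             for s, d in zip(contiguous_strides(lanes_shape), lanes_shape)]
  full = [0] * len(tiled_shape)
  for d, i in zip(lane_dims, per_dim):
    full[d] = i
  return tuple(full)

# WGMMA layout: tiled shape of one (64, 8) tile, with lane_dims=(-4, -3).
tiled = (1, 1, 4, 1, 2, 1, 8, 4, 1, 2)
assert lane_indices(0, tiled, (-4, -3)) == (0,) * 10
assert lane_indices(7, tiled, (-4, -3)) == (0, 0, 0, 0, 0, 0, 1, 3, 0, 0)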
@@ -298,9 +300,22 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]):
   return TiledLayout(
       Tiling(((64, 8), (16, 8), (8, 8), (1, 2))),
       warp_dim=-8,
-      lane_dims=frozenset((-4, -3)),
+      lane_dims=(-4, -3),
       vector_dim=-1,
   )
 
+def _tiled_wgmma_layout_for_upcast(shape: tuple[int, ...]):
+  """Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth."""
+  if len(shape) != 2:
+    raise ValueError(f"Shape {shape} is not 2D")
+  if shape[0] % 64 != 0 or shape[1] % 8 != 0:
+    raise ValueError(f"Shape {shape} is not a multiple of 64x8")
+  t = Tiling(((64, 16), (16, 16), (8, 16), (4,), (2, 1)))
+  return TiledLayout(
+      t,
+      warp_dim=-9,
+      lane_dims=(-5, -2, -4),
+      vector_dim=-3,
+  )
 
 @dataclasses.dataclass(frozen=True)
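Running the new tiling through the same shape calculation as before (reusing the hypothetical tile_suffix sketch from above) shows where the dimension indices come from:

shape = (64, 16)
for tile in ((64, 16), (16, 16), (8, 16), (4,), (2, 1)):
  shape = tile_suffix(shape, tile)

print(shape)  # (1, 1, 4, 1, 2, 1, 8, 2, 4, 2, 1)
assert shape[-9] == 4                           # warp_dim=-9: 4 warps
assert shape[-5] * shape[-2] * shape[-4] == 32  # lane_dims: 8 * 2 * 2 lanes
assert shape[-3] == 4                           # vector_dim=-3: 4 elements

Each register now holds 4 consecutive 8-bit elements; after upcasting to a 16-bit type those 4 elements become two 2-element registers, matching the 2-element vector_dim=-1 of the WGMMA layout above, which is presumably what makes the relayout promised by the docstring cheap.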

tests/mosaic/gpu_test.py

@@ -1865,6 +1865,32 @@ class LayoutTest(TestCase):
     used_regs = {get_reg(addr) for addr in addrs}
     self.assertLessEqual(len(used_regs), expected_regs)
 
+  def test_copy_for_upcast(self):
+    dtype = jnp.int8
+    swizzle = 128
+    col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype))
+    m, n = 128, col_tiling * 2
+    tiling = (64, col_tiling)
+    tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n))
+    def kernel(ctx, in_, out, smems):
+      smem_in, smem_out, barrier = smems
+      ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
+      barrier.wait()
+      t = mgpu.FragmentedArray.load_tiled(
+          smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout
+      )
+      t.store_tiled(smem_out, swizzle=swizzle)
+      mgpu.commit_shared()
+      ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
+      ctx.await_async_copy(0)
+    x = jax.random.randint(
+        jax.random.key(42), tile_shape((m, n), tiling), -128, 127, dtype=dtype
+    )
+    f = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), x, x, [x, x, mgpu.TMABarrier()],
+    )
+    np.testing.assert_array_equal(f(x), x)
+
 
 class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
   """Device tests with lowering from the MLIR dialect and layout inference."""