Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Add a new tiled layout, optimized for upcasting before WGMMA
PiperOrigin-RevId: 707860467
This commit is contained in:
parent 66ad2082ba
commit 006c65d8d4
@@ -65,15 +65,17 @@ class Tiling:
   tiles: tuple[tuple[int, ...], ...]
 
   def __post_init__(self):
-    max_rank = math.inf
+    if not self.tiles:
+      return
+    tiled_rank = len(self.tiles[0])
     for tile in self.tiles:
+      if len(tile) > tiled_rank:
+        raise ValueError("Only the first tile can refer to value dimensions")
       if not tile:
         raise ValueError("Tiles must not be empty")
-      if len(tile) > max_rank:
-        raise ValueError("Tile ranks must be non-increasing")
-      max_rank = len(tile)
       if any(d <= 0 for d in tile):
         raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}")
+      tiled_rank += len(tile)
 
   def __str__(self):
     return f"Tiling({''.join(map(str, self.tiles))})"
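For illustration only, here is a small standalone Python sketch (not part of the commit) that mirrors the constructor checks as shown in the hunk above; the helper name validate_tiles is made up. It shows that the upcast tiling introduced later in this commit is accepted, while a tiling whose later tile refers to more dimensions than have been introduced so far is rejected.

# Hypothetical helper mirroring the __post_init__ checks above (assumption, not library code).
def validate_tiles(tiles):
  if not tiles:
    return
  tiled_rank = len(tiles[0])
  for tile in tiles:
    if len(tile) > tiled_rank:
      raise ValueError("Only the first tile can refer to value dimensions")
    if not tile:
      raise ValueError("Tiles must not be empty")
    if any(d <= 0 for d in tile):
      raise ValueError(f"Tile shape must only have positive sizes, got: {tiles}")
    tiled_rank += len(tile)

validate_tiles(((64, 16), (16, 16), (8, 16), (4,), (2, 1)))  # accepted: the new upcast tiling
try:
  validate_tiles(((8,), (4, 2, 2)))  # the second tile needs 3 dims, but only 2 exist so far
except ValueError as e:
  print(e)  # Only the first tile can refer to value dimensions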
@@ -156,7 +158,7 @@ class TiledLayout:
 
       (64, 8)(16, 8)(8, 8)(1, 2)
 
-  and warp_dim=-8, lane_dims={-4, -3}, vector_dim=-1.
+  and warp_dim=-8, lane_dims=(-4, -3), vector_dim=-1.
 
   We begin by applying the tiling (note that it always applies to a suffix):
 
@@ -171,7 +173,7 @@ class TiledLayout:
   The last expression is our final shape. At this stage, we're ready to
   interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the
   end is partitioned over 4 warps in a warpgroup (and so it must be of size 4).
-  lane_dims={-4, -3} indicate that those two dimensions are partitioned over
+  lane_dims=(-4, -3) indicate that those two dimensions are partitioned over
   the lanes within a warp (their product must be equal to 32, i.e. warp size).
   Finally, vector_dim=-1 indicates that each (logical) register is a vector
   containing 2 elements (there are no shape restrictions here).
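To make the docstring walkthrough above concrete, here is a short standalone Python sketch (not part of the commit) that reproduces the tiled shape of a (128, 128) array under the tiling (64, 8)(16, 8)(8, 8)(1, 2). apply_tile is a made-up helper implementing the unfolding rule described in the docstring.

def apply_tile(shape, tile):
  # Unfold each of the trailing len(tile) dims into (outer, tile) and move the
  # tile dims to the end of the shape.
  untiled, tiled = shape[:-len(tile)], shape[-len(tile):]
  outer = tuple(d // t for d, t in zip(tiled, tile))
  return untiled + outer + tile

shape = (128, 128)
for tile in ((64, 8), (16, 8), (8, 8), (1, 2)):
  shape = apply_tile(shape, tile)
  print(shape)
# (2, 16, 64, 8)
# (2, 16, 4, 1, 16, 8)
# (2, 16, 4, 1, 2, 1, 8, 8)
# (2, 16, 4, 1, 2, 1, 8, 4, 1, 2)

# In the final shape, dim -8 has size 4 (the warps), dims (-4, -3) have sizes
# (8, 4) whose product is 32 (the lanes), and dim -1 has size 2 (the vector).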
@@ -184,7 +186,7 @@ class TiledLayout:
   """
   tiling: Tiling
   warp_dim: int
-  lane_dims: frozenset[int]
+  lane_dims: tuple[int, ...]  # major-to-minor
   vector_dim: int
 
   def __post_init__(self):
@@ -253,19 +255,19 @@ class TiledLayout:
 
   def lane_indices(self) -> tuple[ir.Value, ...]:
     i32 = ir.IntegerType.get_signless(32)
-    tiled_shape = tuple(
-        d if i in self.lane_dims else 1
-        for i, d in enumerate_negative(self.tiled_tiling_shape)
-    )
-    assert math.prod(tiled_shape) == WARP_SIZE
-    lane_strides = utils.get_contiguous_strides(tiled_shape)
+    tiled_shape = self.tiled_tiling_shape
+    lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims)
+    assert math.prod(lanes_shape) == WARP_SIZE
+    lane_strides = utils.get_contiguous_strides(lanes_shape)
     lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32))
-    # TODO(apaszke): Rewrite so that we can be sure that this never actually
-    # does arithmetic for any dimensions that are not in lane_dims.
-    return tuple(
+    lane_indices = tuple(
         arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32))
-        for stride, size in zip(lane_strides, tiled_shape)
+        for stride, size in zip(lane_strides, lanes_shape)
     )
+    full_indices = [arith.constant(i32, 0)] * len(tiled_shape)
+    for d, i in zip(self.lane_dims, lane_indices):
+      full_indices[d] = i
+    return tuple(full_indices)
 
   def warp_indices(self) -> tuple[ir.Value, ...]:
     i32 = ir.IntegerType.get_signless(32)
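As a sanity check on the new index arithmetic, and on why lane_dims is now an ordered, major-to-minor tuple, here is a plain-Python sketch (not part of the commit, no MLIR) of what lane_indices computes for the WGMMA example above. get_row_major_strides is a made-up stand-in for what utils.get_contiguous_strides is assumed to return, and lane_idx stands in for thread_idx() % WARP_SIZE.

import math

def get_row_major_strides(shape):
  # Contiguous (row-major) strides, e.g. (8, 4) -> (4, 1).
  strides, acc = [], 1
  for d in reversed(shape):
    strides.append(acc)
    acc *= d
  return tuple(reversed(strides))

WARP_SIZE = 32
tiled_shape = (2, 16, 4, 1, 2, 1, 8, 4, 1, 2)  # the tiled shape from the docstring example
lane_dims = (-4, -3)                           # major-to-minor
lanes_shape = tuple(tiled_shape[d] for d in lane_dims)   # (8, 4)
lane_strides = get_row_major_strides(lanes_shape)        # (4, 1)
assert math.prod(lanes_shape) == WARP_SIZE

lane_idx = 13  # stand-in for thread_idx() % WARP_SIZE
lane_indices = tuple(
    (lane_idx // stride) % size
    for stride, size in zip(lane_strides, lanes_shape)
)
full_indices = [0] * len(tiled_shape)  # every non-lane dimension stays at index 0
for d, i in zip(lane_dims, lane_indices):
  full_indices[d] = i
print(full_indices)  # [0, 0, 0, 0, 0, 0, 3, 1, 0, 0]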
@@ -298,9 +300,22 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]):
   return TiledLayout(
       Tiling(((64, 8), (16, 8), (8, 8), (1, 2))),
       warp_dim=-8,
-      lane_dims=frozenset((-4, -3)),
+      lane_dims=(-4, -3),
       vector_dim=-1,
   )
+
+
+def _tiled_wgmma_layout_for_upcast(shape: tuple[int, ...]):
+  """Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth."""
+  if len(shape) != 2:
+    raise ValueError(f"Shape {shape} is not 2D")
+  if shape[0] % 64 != 0 or shape[1] % 8 != 0:
+    raise ValueError(f"Shape {shape} is not a multiple of 64x8")
+  t = Tiling(((64, 16), (16, 16), (8, 16), (4,), (2, 1)))
+  return TiledLayout(
+      t,
+      warp_dim=-9,
+      lane_dims=(-5, -2, -4),
+      vector_dim=-3,
+  )
 
 
 @dataclasses.dataclass(frozen=True)
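For intuition about the new layout, here is a short sketch (not part of the commit) that applies the upcast tiling to a hypothetical (128, 32) operand with the same unfolding rule as in the docstring example; apply_tile is the same made-up helper as in the earlier sketch.

def apply_tile(shape, tile):
  # Unfold the trailing len(tile) dims into (outer, tile) and move the tile dims to the end.
  untiled, tiled = shape[:-len(tile)], shape[-len(tile):]
  return untiled + tuple(d // t for d, t in zip(tiled, tile)) + tile

shape = (128, 32)
for tile in ((64, 16), (16, 16), (8, 16), (4,), (2, 1)):
  shape = apply_tile(shape, tile)
print(shape)  # (2, 2, 4, 1, 2, 1, 8, 2, 4, 2, 1)

# warp_dim=-9            -> size 4: the four warps of a warpgroup
# lane_dims=(-5, -2, -4) -> sizes (8, 2, 2), product 32: the lanes of a warp
# vector_dim=-3          -> size 4: each (logical) register holds a vector of 4
#                           consecutive narrow elements, convenient for upcasting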
@@ -1865,6 +1865,32 @@ class LayoutTest(TestCase):
     used_regs = {get_reg(addr) for addr in addrs}
     self.assertLessEqual(len(used_regs), expected_regs)
 
+  def test_copy_for_upcast(self):
+    dtype = jnp.int8
+    swizzle = 128
+    col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype))
+    m, n = 128, col_tiling * 2
+    tiling = (64, col_tiling)
+    tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n))
+    def kernel(ctx, in_, out, smems):
+      smem_in, smem_out, barrier = smems
+      ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
+      barrier.wait()
+      t = mgpu.FragmentedArray.load_tiled(
+          smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout
+      )
+      t.store_tiled(smem_out, swizzle=swizzle)
+      mgpu.commit_shared()
+      ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
+      ctx.await_async_copy(0)
+    x = jax.random.randint(
+        jax.random.key(42), tile_shape((m, n), tiling), -128, 127, dtype=dtype
+    )
+    f = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), x, x, [x, x, mgpu.TMABarrier()],
+    )
+    np.testing.assert_array_equal(f(x), x)
+
 
 class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
   """Device tests with lowering from the MLIR dialect and layout inference."""