[Mosaic GPU] Use union to avoid excessive SMEM usage in the Blackwell matmul

PiperOrigin-RevId: 725667871
This commit is contained in:
Adam Paszke 2025-02-11 09:47:54 -08:00 committed by jax authors
parent 74e86bab26
commit 70007471c7

View File

@@ -71,7 +71,7 @@ def build_kernel(
tile_n *= 2
def kernel(ctx, a, b, d, smem):
a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem
(ab_full_barriers, ab_empty_barriers) = barriers
warp_idx = mgpu.warp_idx(sync=True)
@@ -161,14 +161,24 @@ def build_kernel(
gmem_transform=mgpu.TileTransform((128, 64)),
swizzle=swizzle,
)
# TODO(apaszke): Free up TMEM?
ctx.await_async_copy(0)
# TODO(apaszke): Use a union for output SMEM.
compute_buffers = (
jax.ShapeDtypeStruct(
mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k),
(tma_tile_m, tma_tile_kn)),
jnp.float16),
jax.ShapeDtypeStruct(
mgpu.tile_shape((max_concurrent_steps, tile_k, block_tile_n),
(tma_tile_kn, tma_tile_kn)),
jnp.float16),
)
epilogue_buffer = jax.ShapeDtypeStruct(
mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)),
jnp.float16)
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
smem = (
jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((block_tile_m, tile_k), (tma_tile_m, tma_tile_kn))), jnp.float16),
jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_k, block_tile_n), (tma_tile_kn, tma_tile_kn))), jnp.float16),
jax.ShapeDtypeStruct(mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
smem_buffers,
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
mgpu.Barrier(arrival_count=1),
mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D, collective=collective),