mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Use union to avoid excessive SMEM usage in the Blackwell matmul
PiperOrigin-RevId: 725667871
This commit is contained in:
parent
74e86bab26
commit
70007471c7
@ -71,7 +71,7 @@ def build_kernel(
|
||||
tile_n *= 2
|
||||
|
||||
def kernel(ctx, a, b, d, smem):
|
||||
a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
|
||||
((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem
|
||||
(ab_full_barriers, ab_empty_barriers) = barriers
|
||||
|
||||
warp_idx = mgpu.warp_idx(sync=True)
|
||||
@ -161,14 +161,24 @@ def build_kernel(
|
||||
gmem_transform=mgpu.TileTransform((128, 64)),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
# TODO(apaszke): Free up TMEM?
|
||||
ctx.await_async_copy(0)
|
||||
|
||||
# TODO(apaszke): Use a union for output SMEM.
|
||||
compute_buffers = (
|
||||
jax.ShapeDtypeStruct(
|
||||
mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k),
|
||||
(tma_tile_m, tma_tile_kn)),
|
||||
jnp.float16),
|
||||
jax.ShapeDtypeStruct(
|
||||
mgpu.tile_shape((max_concurrent_steps, tile_k, block_tile_n),
|
||||
(tma_tile_kn, tma_tile_kn)),
|
||||
jnp.float16),
|
||||
)
|
||||
epilogue_buffer = jax.ShapeDtypeStruct(
|
||||
mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)),
|
||||
jnp.float16)
|
||||
smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer])
|
||||
smem = (
|
||||
jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((block_tile_m, tile_k), (tma_tile_m, tma_tile_kn))), jnp.float16),
|
||||
jax.ShapeDtypeStruct((max_concurrent_steps, *mgpu.tile_shape((tile_k, block_tile_n), (tma_tile_kn, tma_tile_kn))), jnp.float16),
|
||||
jax.ShapeDtypeStruct(mgpu.tile_shape((block_tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
|
||||
smem_buffers,
|
||||
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
|
||||
mgpu.Barrier(arrival_count=1),
|
||||
mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D, collective=collective),
|
||||
|
Loading…
x
Reference in New Issue
Block a user