[Mosaic GPU] Add support for clusters in the matmul example

With the collective async_copy API, the changes are quite minimal!

PiperOrigin-RevId: 655937185
Adam Paszke, 2024-07-25 06:46:04 -07:00 (committed by jax authors)
parent e59303cf3e
commit 2e6da35e97
2 changed files with 56 additions and 6 deletions
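
For context: a collective copy lets the blocks of a cluster fetch a tile they all need cooperatively, with the TMA engine multicasting the data into each participating block's SMEM. The whole change on the load path is one extra argument per copy; a minimal sketch of the pattern, reusing the names from the diff below rather than forming a standalone program:

    ctx.async_copy(
        src_ref=a_device,
        dst_ref=memref_slice(lhs_smem, slot),
        gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)),
        gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
        # New in this commit: blocks that differ along the cluster's y
        # dimension issue this load together and all receive the tile.
        collective=gpu.Dimension.y,
        **common_copy_args,
    )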

File 1 of 2 (the matmul example):

@@ -17,6 +17,7 @@
 import dataclasses
 import functools
 from typing import Any
+import math
 
 import jax
 from jax import random
@@ -112,6 +113,7 @@ def build_kernel(
     stages: int = 2,
     tile_m: int = 128,
     tile_n: int = 128,
+    cluster: tuple[int, int] = (1, 1),
     rhs_transpose: bool = False,
     wgmma_impl=WGMMADefaultImpl,
     profiler_spec: profiler.ProfilerSpec | None = None,
@@ -167,7 +169,8 @@ def build_kernel(
   smem_shape = mosaic_gpu.Union([compute_scratch_shape, epilogue_scratch_shape])
 
   def _main(ctx, a_device, b_device, c_device, smem):
-    ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), barriers = smem
+    ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem
+    tma_barriers, cluster_barrier = barriers
     memref.assume_alignment(c_device, 16)
@@ -175,7 +178,7 @@ def build_kernel(
     n_start = arith.muli(c(block_tiling.n), gpu.block_id(gpu.Dimension.y))
 
     def fetch(slot, ki):
-      barrier = barriers[slot]
+      barrier = tma_barriers[slot]
      k_start = arith.muli(c(block_tiling.k), ki)
      lhs_tma_tile_bytes = int(np.prod(block_tiling.mk) * lhs_elem_bytes)
      rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes)
@@ -190,6 +193,7 @@ def build_kernel(
           dst_ref=memref_slice(lhs_smem, slot),
           gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)),
           gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
+          collective=gpu.Dimension.y,
           **common_copy_args,
       )
       rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n))
@@ -203,6 +207,7 @@ def build_kernel(
           dst_ref=memref_slice(rhs_smem, slot),
           gmem_slice=rhs_slice,
           gmem_transform=rhs_transform,
+          collective=gpu.Dimension.x,
           **common_copy_args,
       )
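
Note the asymmetry in the collective dimensions: an LHS tile depends only on the m and k coordinates, so all blocks that differ along the cluster's y (i.e. n) dimension need the same LHS bytes and can fetch them collectively; symmetrically, an RHS tile is shared across the cluster's x (m) dimension. As a rough illustration (numbers assumed here, not taken from this diff): with cluster = (1, 4), tile_m = 128, a k step of 64, and f16 inputs, each LHS tile is 128 * 64 * 2 = 16 KiB, and a collective load moves it from GMEM once and multicasts it to all four blocks instead of issuing four independent 16 KiB reads.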
@@ -217,7 +222,7 @@ def build_kernel(
       si = arith.remui(ki, c(stages))
 
       with ctx.named_region("TMA wait"):
-        barriers[si].wait()
+        tma_barriers[si].wait()
 
       with ctx.named_region("WGMMA"):
         a_slice = memref_slice(lhs_smem, si)
@@ -236,6 +241,10 @@ def build_kernel(
       )
       do_tma = arith.andi(not_first_step, tma_ki_in_bounds)
       with ir.InsertionPoint(scf.IfOp(do_tma).then_block):
+        if cluster_barrier is not None:
+          with ctx.named_region("Cluster barrier"):
+            cluster_barrier[tma_si].arrive()
+            cluster_barrier[tma_si].wait()  # Make sure everyone is done.
         fetch(tma_si, tma_ki)
         scf.yield_([])
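
Why the extra barrier: a collective copy deposits the tile into the SMEM of every block in the collective group, so a pipeline slot may only be refilled once every block in the cluster has finished reading it, not just the local one. The arrive/wait pair added above is a cluster-wide rendezvous per stage slot; without it, one block could run ahead and let the next TMA overwrite data a neighboring block is still consuming.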
@@ -269,8 +278,16 @@ def build_kernel(
           jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype),
       ),
       jax.ShapeDtypeStruct((m, n), jnp.float32),
-      (smem_shape, TMABarrier(stages)),
+      (
+          smem_shape,
+          TMABarrier(num_barriers=stages),
+          ClusterBarrier(
+              collective_dims=(gpu.Dimension.x, gpu.Dimension.y),
+              num_barriers=stages,
+          ) if math.prod(cluster) > 1 else None,
+      ),
       profiler_spec,
+      cluster=(*cluster, 1),
   )
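
Two details worth noting here: the ClusterBarrier is only allocated when the cluster actually spans multiple blocks (math.prod(cluster) > 1), so the single-block configuration keeps its old SMEM layout; and the kernel is launched with the 3D cluster shape (*cluster, 1), which lines cluster_m and cluster_n up with the x (m) and y (n) grid dimensions used by the collective copies above.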
@@ -281,6 +298,8 @@ def verify(
     stages=4,
     tile_m=128,
     tile_n=128,
+    cluster_m=1,
+    cluster_n=1,
     profile=False,
     lhs_dtype=jnp.float16,
     rhs_dtype=jnp.float16,
@@ -305,6 +324,7 @@ def verify(
       stages=stages,
       tile_m=tile_m,
       tile_n=tile_n,
+      cluster=(cluster_m, cluster_n),
       rhs_transpose=rhs_transpose,
       wgmma_impl=impl,
       profiler_spec=prof_spec,
@@ -334,8 +354,8 @@ def verify(
 
 if __name__ == "__main__":
-  m, k, n = 33 * 128, 2048, 4 * 128
-  runtime, ref_runtime = verify(m=m, k=k, n=n)
+  m, k, n = 4 * 33 * 128, 2048, 4 * 128
+  runtime, ref_runtime = verify(m=m, k=k, n=n, cluster_m=1, cluster_n=4)
   tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12
   ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
   print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
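
The benchmark invocation stays a one-liner. A sketch of the equivalent call, for reference (same shapes as in the hunk above; with the default tile_n = 128, the four blocks along n form exactly one 1x4 cluster, so they share each LHS tile):

    runtime, ref_runtime = verify(
        m=4 * 33 * 128, k=2048, n=4 * 128, cluster_m=1, cluster_n=4
    )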

File 2 of 2 (the matmul tests):

@@ -117,6 +117,36 @@ class MatmulTestCase(jtu.JaxTestCase):
         self.skipTest("Not enough shared memory for test, skipping.")
       raise e
 
+  @parameterized.product(
+      m=(512, 2048),
+      n=(512, 2048),
+      k=(512, 2048),
+      stages=(2, 4),
+      tile_m=(64, 128),
+      tile_n=(64, 128),
+      cluster_m=(1, 2, 4),
+      cluster_n=(1, 2, 4),
+  )
+  def test_matmul_clusters(self, m, k, n, stages, tile_m, tile_n, cluster_m, cluster_n):
+    try:
+      matmul.verify(
+          m,
+          k,
+          n,
+          stages,
+          tile_m=tile_m,
+          tile_n=tile_n,
+          cluster_m=cluster_m,
+          cluster_n=cluster_n,
+          lhs_dtype=jnp.float32,
+          rhs_dtype=jnp.float32,
+          rhs_transpose=True,
+      )
+    except ValueError as e:
+      if "Mosaic GPU kernel exceeds available shared memory" in str(e):
+        self.skipTest("Not enough shared memory for test, skipping.")
+      raise e
 
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())