[Mosaic GPU] Add support for clusters in the matmul example

With the collective async_copy API, the changes are quite minimal!

PiperOrigin-RevId: 655937185
parent e59303cf3e
commit 2e6da35e97
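In essence, the change threads a `collective` dimension through the existing `ctx.async_copy` calls and guards SMEM slot reuse with a `ClusterBarrier`. A condensed sketch, assembled from the hunks below (not standalone code: `ctx`, `common_copy_args`, and the tiling helpers all come from the surrounding example):

```python
def fetch(slot, ki):
  barrier = tma_barriers[slot]
  # Collective copy: the blocks along gpu.Dimension.y cooperate on a single
  # TMA for the LHS tile instead of each issuing its own.
  ctx.async_copy(
      dst_ref=memref_slice(lhs_smem, slot),
      gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
      collective=gpu.Dimension.y,
      **common_copy_args,
  )
  # ...and symmetrically for the RHS tile with collective=gpu.Dimension.x.

# Before a slot is refilled, every block in the cluster must signal that it
# is done reading the previous contents:
if cluster_barrier is not None:
  cluster_barrier[tma_si].arrive()
  cluster_barrier[tma_si].wait()  # Make sure everyone is done.
fetch(tma_si, tma_ki)
```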
@@ -17,6 +17,7 @@
 import dataclasses
 import functools
 from typing import Any
+import math

 import jax
 from jax import random
@@ -112,6 +113,7 @@ def build_kernel(
     stages: int = 2,
     tile_m: int = 128,
     tile_n: int = 128,
+    cluster: tuple[int, int] = (1, 1),
     rhs_transpose: bool = False,
     wgmma_impl=WGMMADefaultImpl,
     profiler_spec: profiler.ProfilerSpec | None = None,
@@ -167,7 +169,8 @@ def build_kernel(
   smem_shape = mosaic_gpu.Union([compute_scratch_shape, epilogue_scratch_shape])

   def _main(ctx, a_device, b_device, c_device, smem):
-    ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), barriers = smem
+    ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem
+    tma_barriers, cluster_barrier = barriers

     memref.assume_alignment(c_device, 16)

@@ -175,7 +178,7 @@ def build_kernel(
     n_start = arith.muli(c(block_tiling.n), gpu.block_id(gpu.Dimension.y))

     def fetch(slot, ki):
-      barrier = barriers[slot]
+      barrier = tma_barriers[slot]
       k_start = arith.muli(c(block_tiling.k), ki)
       lhs_tma_tile_bytes = int(np.prod(block_tiling.mk) * lhs_elem_bytes)
       rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes)
@@ -190,6 +193,7 @@ def build_kernel(
           dst_ref=memref_slice(lhs_smem, slot),
           gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)),
           gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
+          collective=gpu.Dimension.y,
           **common_copy_args,
       )
       rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n))
@@ -203,6 +207,7 @@ def build_kernel(
           dst_ref=memref_slice(rhs_smem, slot),
           gmem_slice=rhs_slice,
           gmem_transform=rhs_transform,
+          collective=gpu.Dimension.x,
           **common_copy_args,
       )

@@ -217,7 +222,7 @@ def build_kernel(
       si = arith.remui(ki, c(stages))

       with ctx.named_region("TMA wait"):
-        barriers[si].wait()
+        tma_barriers[si].wait()

       with ctx.named_region("WGMMA"):
         a_slice = memref_slice(lhs_smem, si)
@@ -236,6 +241,10 @@ def build_kernel(
       )
       do_tma = arith.andi(not_first_step, tma_ki_in_bounds)
       with ir.InsertionPoint(scf.IfOp(do_tma).then_block):
+        if cluster_barrier is not None:
+          with ctx.named_region("Cluster barrier"):
+            cluster_barrier[tma_si].arrive()
+            cluster_barrier[tma_si].wait()  # Make sure everyone is done.
         fetch(tma_si, tma_ki)
         scf.yield_([])

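The arrive/wait pair above is the crux of correctness with clusters: a block must not re-issue a TMA into a SMEM slot until every block in its cluster has consumed the previous contents. A rough CPU analogy using a plain `threading.Barrier` (illustrative only; the thread count and names are made up):

```python
import threading

cluster_barrier = threading.Barrier(4)  # stands in for a (1, 4) cluster

def block(i: int):
  # ... compute using the current SMEM slot ...
  cluster_barrier.wait()  # arrive + wait: everyone is done with the slot
  # ... now safe to start the next collective fetch into the slot ...

threads = [threading.Thread(target=block, args=(i,)) for i in range(4)]
for t in threads:
  t.start()
for t in threads:
  t.join()
```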
@@ -269,8 +278,16 @@ def build_kernel(
           jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype),
       ),
       jax.ShapeDtypeStruct((m, n), jnp.float32),
-      (smem_shape, TMABarrier(stages)),
+      (
+          smem_shape,
+          TMABarrier(num_barriers=stages),
+          ClusterBarrier(
+              collective_dims=(gpu.Dimension.x, gpu.Dimension.y),
+              num_barriers=stages,
+          ) if math.prod(cluster) > 1 else None,
+      ),
       profiler_spec,
+      cluster=(*cluster, 1),
   )


@@ -281,6 +298,8 @@ def verify(
     stages=4,
     tile_m=128,
     tile_n=128,
+    cluster_m=1,
+    cluster_n=1,
     profile=False,
     lhs_dtype=jnp.float16,
     rhs_dtype=jnp.float16,
@@ -305,6 +324,7 @@ def verify(
       stages=stages,
       tile_m=tile_m,
       tile_n=tile_n,
+      cluster=(cluster_m, cluster_n),
       rhs_transpose=rhs_transpose,
       wgmma_impl=impl,
       profiler_spec=prof_spec,
@@ -334,8 +354,8 @@ def verify(


 if __name__ == "__main__":
-  m, k, n = 33 * 128, 2048, 4 * 128
-  runtime, ref_runtime = verify(m=m, k=k, n=n)
+  m, k, n = 4 * 33 * 128, 2048, 4 * 128
+  runtime, ref_runtime = verify(m=m, k=k, n=n, cluster_m=1, cluster_n=4)
   tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12
   ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
   print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
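For a quick sanity check of the throughput bookkeeping in `__main__` (plain Python, no GPU needed; the 0.5 ms runtime below is a made-up placeholder):

```python
# Same formula as __main__ above; runtime is in milliseconds, so dividing by
# 1e3 converts to seconds before computing FLOP/s.
m, k, n = 4 * 33 * 128, 2048, 4 * 128
runtime = 0.5  # ms, hypothetical
flops = 2 * k * m * n  # one multiply and one add per (m, n, k) triple
tflops = float(flops) / (runtime / 1e3) / 1e12
print(f"{flops:.3e} FLOPs -> {tflops:.1f} TFLOPS at {runtime} ms")
```

The remaining hunk below extends the test suite (`MatmulTestCase`) with a clustered variant of the matmul test.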
@@ -117,6 +117,36 @@ class MatmulTestCase(jtu.JaxTestCase):
         self.skipTest("Not enough shared memory for test, skipping.")
       raise e

+  @parameterized.product(
+      m=(512, 2048),
+      n=(512, 2048),
+      k=(512, 2048),
+      stages=(2, 4),
+      tile_m=(64, 128),
+      tile_n=(64, 128),
+      cluster_m=(1, 2, 4),
+      cluster_n=(1, 2, 4),
+  )
+  def test_matmul_clusters(self, m, k, n, stages, tile_m, tile_n, cluster_m, cluster_n):
+    try:
+      matmul.verify(
+          m,
+          k,
+          n,
+          stages,
+          tile_m=tile_m,
+          tile_n=tile_n,
+          cluster_m=cluster_m,
+          cluster_n=cluster_n,
+          lhs_dtype=jnp.float32,
+          rhs_dtype=jnp.float32,
+          rhs_transpose=True,
+      )
+    except ValueError as e:
+      if "Mosaic GPU kernel exceeds available shared memory" in str(e):
+        self.skipTest("Not enough shared memory for test, skipping.")
+      raise e
+

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
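To try the clustered configuration by hand, a hypothetical driver mirroring `__main__` above (the import path is an assumption based on where the example lives in the JAX tree, and a GPU with Mosaic GPU cluster support is required):

```python
# Hypothetical driver; module path is an assumption.
from jax.experimental.mosaic.gpu.examples import matmul

# 132 blocks along m and 4 along n, grouped into (1, 4) clusters.
runtime_ms, ref_runtime_ms = matmul.verify(
    m=4 * 33 * 128, k=2048, n=4 * 128, cluster_m=1, cluster_n=4
)
print(f"kernel: {runtime_ms:.3f} ms, reference: {ref_runtime_ms:.3f} ms")
```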