[Mosaic GPU] Add support for grid tiling to improve L2 cache utilization

While CUDA technically does not guarantee anything about the order in
which blocks are executed, in practice they are generally scheduled in
column-major order within the grid. We can exploit this property to
launch the blocks in a tiled fashion, which can improve the L2 hit rate
and yield a significant performance boost.
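
To make the effect concrete, here is a minimal Python sketch of the
block-to-tile mapping this change introduces. It is illustrative only
and not part of the commit: the helper name output_tile, the toy grid
sizes, and the assumption that blocks are scheduled with the x
dimension varying fastest all come from the explanation above, not
from the code itself.

# Minimal sketch (illustrative only) of how the new 3D launch grid maps
# CUDA block IDs back to output tiles.
def output_tile(bx, by, bz, grid_tile_n):
  m_idx = by                       # row of the output tile
  n_idx = bx + bz * grid_tile_n    # column of the output tile
  return m_idx, n_idx

# Toy grid: 4 row tiles, 8 column tiles, grid_tile_n = 2. Blocks are assumed
# to be scheduled with x varying fastest, then y, then z.
m_tiles, n_tiles, grid_tile_n = 4, 8, 2
schedule = [
    output_tile(bx, by, bz, grid_tile_n)
    for bz in range(n_tiles // grid_tile_n)  # slowest
    for by in range(m_tiles)
    for bx in range(grid_tile_n)             # fastest
]
# The first m_tiles * grid_tile_n blocks only touch output columns {0, 1},
# so the corresponding RHS tiles can be reused from L2 instead of HBM.
print(schedule[: m_tiles * grid_tile_n])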

PiperOrigin-RevId: 662834982
Authored by Adam Paszke on 2024-08-14 02:17:18 -07:00; committed by jax authors
parent f384497f68
commit 2ab7558425
2 changed files with 46 additions and 19 deletions

View File

@@ -117,6 +117,7 @@ def build_kernel(
swizzle: int = 128,
cluster_m: int = 1,
cluster_n: int = 1,
grid_tile_n: int = 1,
rhs_transpose: bool = False,
wgmma_impl=WGMMADefaultImpl,
profiler_spec: profiler.ProfilerSpec | None = None,
@@ -126,8 +127,8 @@ def build_kernel(
raise ValueError(f"{tile_m=} must be divisible by 64")
if m % tile_m != 0:
raise ValueError(f"{m=} must be divisible by {tile_m=}")
if n % 64 != 0:
raise ValueError(f"n must be divisible by 64, but got {n=}")
if n % tile_n != 0:
raise ValueError(f"{n=} must be divisible by {tile_n=}")
if stages < 2:
raise ValueError(f"Need at least 2 stages, but got {stages=}")
if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2:
@@ -174,7 +175,11 @@ def build_kernel(
assert x % y == 0, (x, y)
return x // y
grid = (safe_div(m, block_tiling.m), safe_div(n, block_tiling.n), 1)
grid = (
grid_tile_n,
safe_div(m, block_tiling.m),
safe_div(n, block_tiling.n * grid_tile_n),
)
block = (128, 1, 1)
c = arith.ConstantOp.create_index
@@ -191,10 +196,12 @@ def build_kernel(
((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem
tma_barriers, cluster_barrier = barriers
memref.assume_alignment(c_device, 16)
m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.x))
n_start = arith.muli(c(block_tiling.n), gpu.block_id(gpu.Dimension.y))
m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.y))
n_block_idx = arith.addi(
gpu.block_id(gpu.Dimension.x),
arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_n)),
)
n_start = arith.muli(c(block_tiling.n), n_block_idx)
def fetch(slot, ki):
barrier = tma_barriers[slot]
@@ -212,7 +219,7 @@ def build_kernel(
dst_ref=memref_slice(lhs_smem, slot),
gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)),
gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
collective=gpu.Dimension.y,
collective=(gpu.Dimension.x, gpu.Dimension.z),
**common_copy_args,
)
rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n))
@@ -226,7 +233,7 @@ def build_kernel(
dst_ref=memref_slice(rhs_smem, slot),
gmem_slice=rhs_slice,
gmem_transform=rhs_transform,
collective=gpu.Dimension.x,
collective=gpu.Dimension.y,
**common_copy_args,
)
@@ -290,6 +297,13 @@ def build_kernel(
)
ctx.await_async_copy(0)
cluster_tile_n = min(cluster_n, grid_tile_n)
if cluster_n % cluster_tile_n:
raise ValueError(
f"{cluster_n=} must be divisible by {cluster_tile_n} (due to"
f" {grid_tile_n=})"
)
cluster = (cluster_tile_n, cluster_m, cluster_n // cluster_tile_n)
return mosaic_gpu.as_gpu_kernel(
_main,
grid,
@@ -303,12 +317,12 @@ def build_kernel(
smem_shape,
TMABarrier(num_barriers=stages),
ClusterBarrier(
collective_dims=(gpu.Dimension.x, gpu.Dimension.y),
collective_dims=((gpu.Dimension.x, gpu.Dimension.z), gpu.Dimension.y),
num_barriers=stages,
) if cluster_m * cluster_n > 1 else None,
),
profiler_spec,
cluster=(cluster_n, cluster_m, 1),
cluster=cluster,
)
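
For intuition, a minimal sketch with made-up values (not from this
commit) of how the cluster shape above follows the tiled grid: the N
extent of the cluster is split between the x axis (columns within a
band) and the z axis (across bands), mirroring the checks in the hunks
above.

# Illustrative values only.
cluster_m, cluster_n, grid_tile_n = 2, 4, 2
cluster_tile_n = min(cluster_n, grid_tile_n)   # cluster blocks along x
assert cluster_n % cluster_tile_n == 0         # mirrors the ValueError check
cluster = (cluster_tile_n, cluster_m, cluster_n // cluster_tile_n)
# cluster == (2, 2, 2): x covers columns within a band, y covers rows (M),
# z covers the remaining N extent across bands, matching the grid layout
# (grid_tile_n, m_tiles, n_tiles // grid_tile_n).
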
@@ -321,6 +335,7 @@ def verify(
tile_n=128,
cluster_m=1,
cluster_n=1,
grid_tile_n=1,
swizzle=128,
profile=False,
in_dtype=jnp.float16,
@@ -344,6 +359,7 @@ def verify(
cluster_n=cluster_n,
rhs_transpose=rhs_transpose,
swizzle=swizzle,
grid_tile_n=grid_tile_n,
wgmma_impl=WGMMADefaultImpl,
profiler_spec=prof_spec,
)
@@ -384,12 +400,13 @@ if __name__ == "__main__":
x = random.uniform(kx, (m, k), dtype=dtype)
y = random.uniform(ky, (k, n), dtype=dtype)
tile_m = tile_n = (64, 128, 256)
tile_m = tile_n = (64, 128)
cluster_m = cluster_n = (1, 2)
swizzle = (128,)
swizzle = (128,) # 64 can be a good choice for some shapes too!
stages = (2, 4, 5, 6)
configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle)
names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle")
grid_tile_n = (1, 4, 16)
configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle, grid_tile_n)
names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle", "grid_tile_n")
best_runtime = float("inf")
best_kwargs = {}
for config in configs:
@@ -398,9 +415,15 @@ if __name__ == "__main__":
continue
if m < kwargs["tile_m"] or n < kwargs["tile_n"]:
continue
if (m // kwargs["tile_m"]) % kwargs["cluster_n"]:
if (m // kwargs["tile_m"]) % kwargs["cluster_m"]:
continue
if (n // kwargs["tile_n"]) % kwargs["cluster_m"]:
if (n // kwargs["tile_n"]) % kwargs["cluster_n"]:
continue
if n % kwargs["grid_tile_n"]:
continue
# This is a heuristic, not a strict correctness check. You can relax it
# for a more complete search space.
if kwargs["tile_m"] == kwargs["tile_n"] == 64:
continue
try:
f = build_kernel(

View File

@@ -88,10 +88,13 @@ class MatmulTestCase(jtu.JaxTestCase):
tile_n = data.draw(
hps.sampled_from([t for t in [64, 128, 256] if t <= n]), label="tile_n"
)
grid_m, grid_n = m // tile_m, n // tile_n
grid_tile_n = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_n")
hp.assume(grid_n % grid_tile_n == 0)
cluster_m = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_m")
hp.assume((m // tile_m) % cluster_m == 0)
hp.assume(grid_m % cluster_m == 0)
cluster_n = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_n")
hp.assume((n // tile_n) % cluster_n == 0)
hp.assume(grid_n % cluster_n == 0)
# TODO(apaszke): Non-portable clusters (16 blocks) sometimes deadlock.
hp.assume(cluster_m * cluster_n <= 8)
if bytewidth == 4:
@@ -111,6 +114,7 @@ class MatmulTestCase(jtu.JaxTestCase):
out_dtype=out_dtype,
cluster_m=cluster_m,
cluster_n=cluster_n,
grid_tile_n=grid_tile_n,
swizzle=swizzle,
rhs_transpose=rhs_transpose,
)