mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Add support for grid tiling to improve L2 cache utilization
While CUDA technically does not guarantee anything about the order in which blocks will be executed, in practice they are generally scheduled in column-major order within the grid. We can use this property to launch the blocks in a tiled way, which can lead to an improved rate of L2 hits and a significant performance boost. PiperOrigin-RevId: 662834982
This commit is contained in:
parent
f384497f68
commit
2ab7558425
@ -117,6 +117,7 @@ def build_kernel(
|
||||
swizzle: int = 128,
|
||||
cluster_m: int = 1,
|
||||
cluster_n: int = 1,
|
||||
grid_tile_n: int = 1,
|
||||
rhs_transpose: bool = False,
|
||||
wgmma_impl=WGMMADefaultImpl,
|
||||
profiler_spec: profiler.ProfilerSpec | None = None,
|
||||
@ -126,8 +127,8 @@ def build_kernel(
|
||||
raise ValueError(f"{tile_m=} must be divisible by 64")
|
||||
if m % tile_m != 0:
|
||||
raise ValueError(f"{m=} must be divisible by {tile_m=}")
|
||||
if n % 64 != 0:
|
||||
raise ValueError(f"n must be divisible by 64, but got {n=}")
|
||||
if n % tile_n != 0:
|
||||
raise ValueError(f"{n=} must be divisible by {tile_n=}")
|
||||
if stages < 2:
|
||||
raise ValueError(f"Need at least 2 stages, but got {stages=}")
|
||||
if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2:
|
||||
@ -174,7 +175,11 @@ def build_kernel(
|
||||
assert x % y == 0, (x, y)
|
||||
return x // y
|
||||
|
||||
grid = (safe_div(m, block_tiling.m), safe_div(n, block_tiling.n), 1)
|
||||
grid = (
|
||||
grid_tile_n,
|
||||
safe_div(m, block_tiling.m),
|
||||
safe_div(n, block_tiling.n * grid_tile_n),
|
||||
)
|
||||
block = (128, 1, 1)
|
||||
|
||||
c = arith.ConstantOp.create_index
|
||||
@ -191,10 +196,12 @@ def build_kernel(
|
||||
((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem
|
||||
tma_barriers, cluster_barrier = barriers
|
||||
|
||||
memref.assume_alignment(c_device, 16)
|
||||
|
||||
m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.x))
|
||||
n_start = arith.muli(c(block_tiling.n), gpu.block_id(gpu.Dimension.y))
|
||||
m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.y))
|
||||
n_block_idx = arith.addi(
|
||||
gpu.block_id(gpu.Dimension.x),
|
||||
arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_n)),
|
||||
)
|
||||
n_start = arith.muli(c(block_tiling.n), n_block_idx)
|
||||
|
||||
def fetch(slot, ki):
|
||||
barrier = tma_barriers[slot]
|
||||
@ -212,7 +219,7 @@ def build_kernel(
|
||||
dst_ref=memref_slice(lhs_smem, slot),
|
||||
gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)),
|
||||
gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
|
||||
collective=gpu.Dimension.y,
|
||||
collective=(gpu.Dimension.x, gpu.Dimension.z),
|
||||
**common_copy_args,
|
||||
)
|
||||
rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n))
|
||||
@ -226,7 +233,7 @@ def build_kernel(
|
||||
dst_ref=memref_slice(rhs_smem, slot),
|
||||
gmem_slice=rhs_slice,
|
||||
gmem_transform=rhs_transform,
|
||||
collective=gpu.Dimension.x,
|
||||
collective=gpu.Dimension.y,
|
||||
**common_copy_args,
|
||||
)
|
||||
|
||||
@ -290,6 +297,13 @@ def build_kernel(
|
||||
)
|
||||
ctx.await_async_copy(0)
|
||||
|
||||
cluster_tile_n = min(cluster_n, grid_tile_n)
|
||||
if cluster_n % cluster_tile_n:
|
||||
raise ValueError(
|
||||
f"{cluster_n=} must be divisible by {cluster_tile_n} (due to"
|
||||
f" {grid_tile_n=})"
|
||||
)
|
||||
cluster = (cluster_tile_n, cluster_m, cluster_n // cluster_tile_n)
|
||||
return mosaic_gpu.as_gpu_kernel(
|
||||
_main,
|
||||
grid,
|
||||
@ -303,12 +317,12 @@ def build_kernel(
|
||||
smem_shape,
|
||||
TMABarrier(num_barriers=stages),
|
||||
ClusterBarrier(
|
||||
collective_dims=(gpu.Dimension.x, gpu.Dimension.y),
|
||||
collective_dims=((gpu.Dimension.x, gpu.Dimension.z), gpu.Dimension.y),
|
||||
num_barriers=stages,
|
||||
) if cluster_m * cluster_n > 1 else None,
|
||||
),
|
||||
profiler_spec,
|
||||
cluster=(cluster_n, cluster_m, 1),
|
||||
cluster=cluster,
|
||||
)
|
||||
|
||||
|
||||
@ -321,6 +335,7 @@ def verify(
|
||||
tile_n=128,
|
||||
cluster_m=1,
|
||||
cluster_n=1,
|
||||
grid_tile_n=1,
|
||||
swizzle=128,
|
||||
profile=False,
|
||||
in_dtype=jnp.float16,
|
||||
@ -344,6 +359,7 @@ def verify(
|
||||
cluster_n=cluster_n,
|
||||
rhs_transpose=rhs_transpose,
|
||||
swizzle=swizzle,
|
||||
grid_tile_n=grid_tile_n,
|
||||
wgmma_impl=WGMMADefaultImpl,
|
||||
profiler_spec=prof_spec,
|
||||
)
|
||||
@ -384,12 +400,13 @@ if __name__ == "__main__":
|
||||
x = random.uniform(kx, (m, k), dtype=dtype)
|
||||
y = random.uniform(ky, (k, n), dtype=dtype)
|
||||
|
||||
tile_m = tile_n = (64, 128, 256)
|
||||
tile_m = tile_n = (64, 128)
|
||||
cluster_m = cluster_n = (1, 2)
|
||||
swizzle = (128,)
|
||||
swizzle = (128,) # 64 can be a good choice for some shapes too!
|
||||
stages = (2, 4, 5, 6)
|
||||
configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle)
|
||||
names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle")
|
||||
grid_tile_n = (1, 4, 16)
|
||||
configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle, grid_tile_n)
|
||||
names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle", "grid_tile_n")
|
||||
best_runtime = float("inf")
|
||||
best_kwargs = {}
|
||||
for config in configs:
|
||||
@ -398,9 +415,15 @@ if __name__ == "__main__":
|
||||
continue
|
||||
if m < kwargs["tile_m"] or n < kwargs["tile_n"]:
|
||||
continue
|
||||
if (m // kwargs["tile_m"]) % kwargs["cluster_n"]:
|
||||
if (m // kwargs["tile_m"]) % kwargs["cluster_m"]:
|
||||
continue
|
||||
if (n // kwargs["tile_n"]) % kwargs["cluster_m"]:
|
||||
if (n // kwargs["tile_n"]) % kwargs["cluster_n"]:
|
||||
continue
|
||||
if n % kwargs["grid_tile_n"]:
|
||||
continue
|
||||
# This is a heuristic, not a strict correctness check. You can relax it
|
||||
# for a more complete search space.
|
||||
if kwargs["tile_m"] == kwargs["tile_n"] == 64:
|
||||
continue
|
||||
try:
|
||||
f = build_kernel(
|
||||
|
@ -88,10 +88,13 @@ class MatmulTestCase(jtu.JaxTestCase):
|
||||
tile_n = data.draw(
|
||||
hps.sampled_from([t for t in [64, 128, 256] if t <= n]), label="tile_n"
|
||||
)
|
||||
grid_m, grid_n = m // tile_m, n // tile_n
|
||||
grid_tile_n = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_n")
|
||||
hp.assume(grid_n % grid_tile_n == 0)
|
||||
cluster_m = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_m")
|
||||
hp.assume((m // tile_m) % cluster_m == 0)
|
||||
hp.assume(grid_m % cluster_m == 0)
|
||||
cluster_n = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_n")
|
||||
hp.assume((n // tile_n) % cluster_n == 0)
|
||||
hp.assume(grid_n % cluster_n == 0)
|
||||
# TODO(apaszke): Non-portable clusters (16 blocks) sometimes deadlock.
|
||||
hp.assume(cluster_m * cluster_n <= 8)
|
||||
if bytewidth == 4:
|
||||
@ -111,6 +114,7 @@ class MatmulTestCase(jtu.JaxTestCase):
|
||||
out_dtype=out_dtype,
|
||||
cluster_m=cluster_m,
|
||||
cluster_n=cluster_n,
|
||||
grid_tile_n=grid_tile_n,
|
||||
swizzle=swizzle,
|
||||
rhs_transpose=rhs_transpose,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user