[Mosaic GPU] Replace multicast_mask with a nicer collective async copy interface

Instead of asking the user to compute the transfer size, manually slice up the
transfer, and compute and specify the multicast mask, we fold all of that
functionality into the `async_copy` function. The copy should be issued by all
blocks in a given cluster slice along the specified dimension, and it will
collectively load all the requested data into every block in that slice.

PiperOrigin-RevId: 655077439
Authored by Adam Paszke on 2024-07-23 01:54:26 -07:00; committed by jax authors
parent a2b2fbf513
commit 51732c5caf
2 changed files with 123 additions and 26 deletions
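
For orientation, a minimal sketch of the new calling convention, modeled on the updated test further below. It is not part of the change itself: `x`, `smem_shape` and the kernel arguments are assumed to be set up by the usual Mosaic GPU test harness, and the cluster is a simple (2, 1, 1) shape with the copy made collective along gpu.Dimension.x.

def kernel(ctx, src, dst, tmp):
  barrier = BarrierArray(1)[0]
  # Make sure the mbarrier initialization is visible to the whole cluster.
  nvvm.fence_mbarrier_init()
  nvvm.cluster_arrive_relaxed()
  nvvm.cluster_wait()
  # Every block in the cluster slice along Dimension.x issues the same call;
  # the blocks split the GMEM load between themselves and the TMA multicast
  # delivers the full requested data into each block's `tmp`.
  ctx.async_copy(
      src_ref=src,
      dst_ref=tmp,
      swizzle=128,
      barrier=barrier,
      collective=gpu.Dimension.x,  # replaces the old multicast_mask argument
  )
  barrier.wait()
  copy(tmp, dst, swizzle=128)

y = mosaic_gpu.as_gpu_kernel(
    kernel, (2, 1, 1), (128, 1, 1), x, x, smem_shape, cluster=(2, 1, 1)
)(x)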


@@ -209,6 +209,7 @@ OnDeviceProfiler = profiler.OnDeviceProfiler
 class LaunchContext:
   launch_op: gpu.LaunchOp
   gmem_scratch_ptr: ir.Value
+  cluster_size: tuple[int, ...]
   profiler: OnDeviceProfiler | None = None
   next_scratch_offset: int = 0
   host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
@@ -322,9 +323,10 @@ class LaunchContext:
       swizzle: int | None = None,
       arrive: bool | None = None,
       uniform: bool = True,
-      multicast_mask: ir.Value | None = None,
+      collective: gpu.Dimension | None = None,
   ):
     index = ir.IndexType.get()
+    i16 = ir.IntegerType.get_signless(16)
     i32 = ir.IntegerType.get_signless(32)
     smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
     src_ref_ty = ir.MemRefType(src_ref.type)
@@ -378,11 +380,68 @@
     if slice_shape != tuple(smem_ref_ty.shape):
       raise ValueError(
-          "Expected the SMEM reference to have the same shape as the tiled"
-          f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
+          "Expected the SMEM reference to have the same shape as the"
+          f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
       )
+    dyn_base_indices = list(dyn_base_indices)
+    slice_shape = list(slice_shape)
+    if (
+        collective is not None
+        and (collective_size := self.cluster_size[collective]) != 1
+    ):
+      for collective_slice_dim, slice_size in enumerate(slice_shape[:-1]):
+        if slice_size % collective_size == 0:
+          break
+      else:
+        raise ValueError(
+            "None of the leading dimensions in the transformed slice shape"
+            f" {slice_shape} is divisible by the collective size"
+            f" {collective_size}"
+        )
+      # Make each block load a smaller slice, adjust the GMEM indices and slice
+      # the SMEM reference accordingly.
+      slice_shape[collective_slice_dim] //= collective_size
+      block_idx = gpu.cluster_block_id(collective)
+      block_offset = arith.muli(block_idx, c(slice_shape[collective_slice_dim], index))
+      dyn_base_indices[collective_slice_dim] = arith.addi(
+          dyn_base_indices[collective_slice_dim], block_offset,
+      )
+      smem_ref = mgpu.memref_slice(
+          smem_ref,
+          (slice(None),) * collective_slice_dim
+          + (mgpu.ds(block_offset, slice_shape[collective_slice_dim]),),
+      )
+      # Compute the multicast mask. We first compute the linearized index of the
+      # slice along the collective dim that contains the current block. Then,
+      # the mask is a sequence of 1s strided by the position of the collective
+      # dim, shifted left by the linear slice index.
+      # TODO(apaszke): Make sure this gets hoisted outside of any loops.
+      # If not, we might need to do it manually.
+      stride = 1
+      mask_shift = c(0, i32)
+      collective_stride = None
+      for cluster_dim in gpu.Dimension:
+        if self.cluster_size[cluster_dim] == 1:
+          continue
+        if cluster_dim != collective:
+          dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim))
+          mask_shift = arith.addi(
+              mask_shift, arith.muli(dim_idx, c(stride, i32)),
+          )
+        else:
+          collective_stride = stride
+        stride *= self.cluster_size[cluster_dim]
+      multicast_mask_unshifted = 0
+      for i in range(collective_size):
+        multicast_mask_unshifted |= 1 << (i * collective_stride)
+      multicast_mask = arith.shli(c(multicast_mask_unshifted, i32), mask_shift)
+      multicast_mask = arith.trunci(i16, multicast_mask)
+    else:
+      multicast_mask = None
     tma_desc = self._get_tma_desc(
-        gmem_ref, gmem_transform, slice_shape, swizzle,
+        gmem_ref, gmem_transform, tuple(slice_shape), swizzle,
     )
     # We constuct TMA descriptors in column-major order.
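
To make the mask construction above easier to follow, here is a plain-Python model of the same computation (illustration only, not part of the change; `cluster_size` and `block_id` are ordinary tuples standing in for the MLIR values used in the code, and the collective dimension is assumed to have more than one block):

def multicast_mask(cluster_size, block_id, collective):
  # Linearize the block index over the non-collective cluster dimensions to
  # locate the collective slice containing this block, and record the stride
  # (in linearized block ids) between the members of that slice.
  stride = 1
  mask_shift = 0
  collective_stride = None
  for dim in range(3):
    if cluster_size[dim] == 1:
      continue
    if dim != collective:
      mask_shift += block_id[dim] * stride
    else:
      collective_stride = stride
    stride *= cluster_size[dim]
  # One bit per block of the collective slice, strided by collective_stride,
  # shifted to the slice that contains the current block.
  mask = 0
  for i in range(cluster_size[collective]):
    mask |= 1 << (i * collective_stride)
  return mask << mask_shift

# Example: a (2, 2, 1) cluster, collective along dimension 1 (y). Block
# (x=1, y=0) multicasts to linearized block ids 1 and 3, i.e. both blocks
# with x == 1.
assert multicast_mask((2, 2, 1), (1, 0, 0), collective=1) == 0b1010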
@@ -555,7 +614,7 @@ def _launch(
     ptr_ty = ir.Type.parse("!llvm.ptr")
     scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
-    yield LaunchContext(launch_op, scratch_ptr, prof), smem_ref_tree
+    yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree
     if prof is not None:
       prof.finalize(grid=grid, block=block)
     gpu.terminator()


@@ -14,7 +14,9 @@
 # ==============================================================================
 """Tests for Mosaic GPU DSL functions and utilities."""
 
+import enum
 from functools import partial
+import itertools
 import math
 import operator
@@ -34,6 +36,11 @@ try:
   HAS_MOSAIC_GPU = True
 except ImportError:
   HAS_MOSAIC_GPU = False
+
+  class Dimension(enum.IntEnum):  # Just to make parameterized tests expand ok
+    x = 0
+    y = 1
+    z = 2
 else:
   from jax.experimental.mosaic import gpu as mosaic_gpu
   from jax.experimental.mosaic.gpu import dsl as mgpu
@@ -41,9 +48,11 @@ else:
   from jax.experimental.mosaic.gpu.utils import *  # noqa: F403
   from jax._src.lib.mlir.dialects import gpu
   from jax._src.lib.mlir.dialects import llvm
+  Dimension = gpu.Dimension
 
 # ruff: noqa: F405
+# pylint: disable=g-complex-comprehension
 config.parse_flags_with_absl()
 
 def nd_loop(bounds, body, *, _idxs = ()):
@@ -741,38 +750,67 @@ class TMATest(TestCase):
     y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
     np.testing.assert_array_equal(y, x)
 
-  def test_tma_load_multicast(self):
-    dtype = jnp.float16
-    shape = (64, 64)
+  @parameterized.named_parameters(
+      (
+          f"_{collective_dim}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
+          collective_dim,
+          noncollective_dims,
+      )
+      for collective_dim in Dimension
+      for noncollective_dims in itertools.chain.from_iterable(
+          itertools.combinations(Dimension, n) for n in range(3)
+      )
+      if collective_dim not in noncollective_dims
+  )
+  def test_tma_load_multicast(self, collective_dim, noncollective_dims):
+    index = ir.IndexType.get()
     swizzle = 128
+    dtype = jnp.float16
+    cluster = [1, 1, 1]
+    for d in (collective_dim, *noncollective_dims):
+      cluster[d] = 2
+    noncollective_size = math.prod(cluster) // cluster[collective_dim]
+    shape = (noncollective_size, 64, 64)
     minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
     shape = (*shape[:-1], minor_size)
-    i16 = ir.IntegerType.get_signless(16)
-    index = ir.IndexType.get()
+    # Note that this kernel does not use the non-collective dimensions in any
+    # interesting way and so they don't really have to be part of the cluster.
+    # We use them to test that the multicast mask is generated correctly.
     def kernel(ctx, src, dst, tmp):
+      stride = 1
+      noncollective_idx = c(0, index)
+      for d in noncollective_dims:
+        noncollective_idx = arith.addi(
+            noncollective_idx,
+            arith.muli(gpu.cluster_block_id(d), c(stride, index))
+        )
+        stride *= 2
       barrier = BarrierArray(1)[0]
       nvvm.fence_mbarrier_init()
       nvvm.cluster_arrive_relaxed()
       nvvm.cluster_wait()
-      slc = ds(
-          arith.muli(gpu.cluster_block_id(gpu.Dimension.x), c(32, index)), 32
+      ctx.async_copy(
+          src_ref=src,
+          dst_ref=tmp,
+          gmem_slice=(noncollective_idx,),
+          swizzle=swizzle,
+          barrier=barrier,
+          collective=collective_dim,
       )
-      with single_thread():
-        barrier.arrive_expect_tx(math.prod(shape) * np.dtype(dtype).itemsize)
-      ctx.async_copy(
-          src_ref=src,
-          dst_ref=memref_slice(tmp, slc),
-          swizzle=swizzle,
-          gmem_slice=slc,
-          barrier=barrier,
-          arrive=False,
-          uniform=False,
-          multicast_mask=c(0b11, i16),
-      )
       barrier.wait()
-      copy(memref_slice(tmp, slc), memref_slice(dst, slc), swizzle=swizzle)
+      slc = ds(
+          arith.muli(gpu.cluster_block_id(collective_dim), c(32, index)), 32
+      )
+      copy(
+          memref_slice(tmp, slc),
+          memref_slice(dst, (noncollective_idx, slc)),
+          swizzle=swizzle,
+      )
     x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
-    y = mosaic_gpu.as_gpu_kernel(kernel, (2, 1, 1), (128, 1, 1), x, x, x, cluster=(2, 1, 1))(x)
+    smem_shape = jax.ShapeDtypeStruct(shape[1:], dtype)
+    y = mosaic_gpu.as_gpu_kernel(
+        kernel, cluster, (128, 1, 1), x, x, smem_shape, cluster=cluster
+    )(x)
     np.testing.assert_array_equal(y, x)
 
   @parameterized.product(