[Mosaic GPU] Add support for cluster collective loads and barriers over multiple dimensions

This will be useful for an upcoming change to the matmul kernel that splits the N blocks
over two cluster dimensions.

PiperOrigin-RevId: 662825455
Authored by Adam Paszke on 2024-08-14 01:46:31 -07:00; committed by jax authors.
Parent: 4c4660a117
Commit: f384497f68
3 changed files with 112 additions and 42 deletions
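
This widens the collective arguments from a single gpu.Dimension to a sequence of dimensions (and, for barriers, groups of dimensions). A hedged sketch of the forms now accepted by async_copy's collective parameter; the assignments are illustrative and not part of the commit:

from jax._src.lib.mlir.dialects import gpu  # MLIR GPU dialect bindings

collective = None                                # no multicast (unchanged)
collective = gpu.Dimension.x                     # a single dimension (unchanged)
collective = (gpu.Dimension.x, gpu.Dimension.y)  # several dimensions (new)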


@@ -325,7 +325,7 @@ class LaunchContext:
       swizzle: int | None = None,
       arrive: bool | None = None,
       uniform: bool = True,
-      collective: gpu.Dimension | None = None,
+      collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
   ):
     index = ir.IndexType.get()
     i16 = ir.IntegerType.get_signless(16)
@@ -388,7 +388,11 @@ class LaunchContext:
     dyn_base_indices = list(dyn_base_indices)
     slice_shape = list(slice_shape)
-    collective_size = 1 if collective is None else self.cluster_size[collective]
+    collective_size = 1
+    if collective is not None:
+      if isinstance(collective, gpu.Dimension):
+        collective = (collective,)
+      collective_size = math.prod(self.cluster_size[d] for d in collective)
     if collective_size > 1:
       def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
         nonlocal smem_ref
@@ -399,18 +403,28 @@ class LaunchContext:
             smem_ref,
             (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),)
         )
-      idx = gpu.cluster_block_id(collective)
+      stride = 1
+      idx = c(0, index)
+      for d in sorted(collective):
+        if self.cluster_size[d] == 1:  # Optimize a multiply by 0.
+          continue
+        idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)))
+        stride *= self.cluster_size[d]
       rem_collective_size = collective_size
       for dim, slice_size in enumerate(slice_shape[:-1]):
         if slice_size % rem_collective_size == 0:
           partition_dim(dim, idx, rem_collective_size)
           rem_collective_size = 1
           break
-        elif collective_size % slice_size == 0:
+        elif rem_collective_size % slice_size == 0:
           dim_idx = arith.remui(idx, c(slice_size, index))
           partition_dim(dim, dim_idx, slice_size)
           idx = arith.divui(idx, c(slice_size, index))
-      else:
+          rem_collective_size //= slice_size
+        else:
+          break  # We failed to partition the leading dimensions.
+      del idx  # We overwrote the block index in the loop.
+      if rem_collective_size > 1:
         raise ValueError(
             "None of the leading dimensions in the transformed slice shape"
             f" {slice_shape} is divisible by the collective size"


@@ -622,21 +622,27 @@ class CollectiveBarrierRef:
   def initialize(
       address: ir.Value,
       num_barriers: int,
-      dims: Sequence[gpu.Dimension],
+      dims: Sequence[gpu.Dimension | Sequence[gpu.Dimension]],
       cluster_shape: tuple[int, int, int],
   ) -> "CollectiveBarrierRef":
     i32 = ir.IntegerType.get_signless(32)
     # With the exception of the current device, each pair of slices along
     # collective dims is disjoint. Since the current device is overcounted,
     # we must decrease the arrival count a little.
-    arrival_count = sum(cluster_shape[d] for d in dims) - len(dims) + 1
-    if math.prod(cluster_shape[d] for d in dims) == 1:
+    dims_shape = [
+        cluster_shape[d]
+        if isinstance(d, gpu.Dimension)
+        else math.prod(cluster_shape[dd] for dd in d)
+        for d in dims
+    ]
+    arrival_count = sum(dims_shape) - len(dims) + 1
+    if arrival_count == 1:
+      assert all(s == 1 for s in dims_shape)
       cluster_mask = None
-      assert arrival_count == 1
     else:
       cluster_mask = c(0, i32)
-      for d in dims:
-        if cluster_shape[d] == 1:
+      for d, size in zip(dims, dims_shape):
+        if size == 1:
           # Only the current device is in this mask, but it will also be
           # present in one of the non-trivial cluster dims.
           continue
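
The arrival count above admits a simple worked example: each collective entry contributes one arrival per block in its slice of the cluster, and since the current block belongs to every entry, its extra copies are subtracted. A sketch with plain integers standing in for gpu.Dimension values and an assumed cluster_shape=(2, 2, 2):

import math

cluster_shape = (2, 2, 2)

def arrival_count(dims):
  # A bare dim contributes its extent; a group of dims contributes the
  # product of its extents.
  dims_shape = [
      cluster_shape[d] if isinstance(d, int)
      else math.prod(cluster_shape[dd] for dd in d)
      for d in dims
  ]
  return sum(dims_shape) - len(dims) + 1

assert arrival_count((0, 1)) == 3     # x, y as separate dims: 2 + 2 - 2 + 1
assert arrival_count(((0, 1),)) == 4  # x and y grouped: 4 - 1 + 1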
@@ -887,8 +893,11 @@ def memref_ptr(memref_arg, memory_space=None):
 
 
 def cluster_collective_mask(
-    cluster_shape: tuple[int, int, int], collective: gpu.Dimension
+    cluster_shape: tuple[int, int, int],
+    collective: Sequence[gpu.Dimension] | gpu.Dimension,
 ):
+  if isinstance(collective, gpu.Dimension):
+    collective = (collective,)
   # We first compute the linearized index of the slice along the collective
   # dim that contains the current block. Then, the mask is a sequence of 1s
   # strided by the position of the collective dim, shifted left by the linear
@@ -896,20 +905,20 @@ def cluster_collective_mask(
   # TODO(apaszke): Make sure this gets hoisted outside of any loops.
   # If not, we might need to do it manually.
   i32 = ir.IntegerType.get_signless(32)
-  stride = 1
   mask_shift = c(0, i32)
-  collective_stride = None
-  for cluster_dim in gpu.Dimension:
-    if cluster_dim != collective:
-      if cluster_shape[cluster_dim] != 1:  # Constant-fold multiply by 0.
-        dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim))
-        mask_shift = arith.addi(
-            mask_shift, arith.muli(dim_idx, c(stride, i32)),
-        )
-    else:
-      collective_stride = stride
-    stride *= cluster_shape[cluster_dim]
+  # NOTE: GPU dimensions are minor-to-major.
+  cluster_strides = get_contiguous_strides(cluster_shape[::-1])[::-1]
+  for stride, cluster_dim in zip(cluster_strides, gpu.Dimension):
+    if cluster_dim in collective:
+      continue
+    if cluster_shape[cluster_dim] != 1:  # Constant-fold multiply by 0.
+      dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim))
+      mask_shift = arith.addi(
+          mask_shift, arith.muli(dim_idx, c(stride, i32)),
+      )
   mask_unshifted = 0
-  for i in range(cluster_shape[collective]):
-    mask_unshifted |= 1 << (i * collective_stride)
+  collective_strides = [cluster_strides[d] for d in collective]
+  collective_shape = tuple(cluster_shape[d] for d in collective)
+  for idx in np.ndindex(collective_shape):
+    mask_unshifted |= 1 << sum(i * s for i, s in zip(idx, collective_strides))
   return arith.shli(c(mask_unshifted, i32), mask_shift)
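
For intuition, a pure-Python model of the new mask computation, assuming cluster_shape=(2, 2, 1), collective=(x, y), and no non-trivial non-collective dims (so mask_shift is 0). Strides are minor-to-major as in the NOTE above, and plain integers stand in for gpu.Dimension:

import numpy as np

cluster_shape = (2, 2, 1)
collective = (0, 1)  # x and y
# Minor-to-major, i.e. get_contiguous_strides(cluster_shape[::-1])[::-1].
cluster_strides = (1, 2, 4)

mask_unshifted = 0
collective_strides = [cluster_strides[d] for d in collective]
collective_shape = tuple(cluster_shape[d] for d in collective)
for idx in np.ndindex(collective_shape):
  mask_unshifted |= 1 << sum(i * s for i, s in zip(idx, collective_strides))

# One bit per block of the 2x2 (x, y) sub-cluster containing this block.
assert mask_unshifted == 0b1111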


@@ -44,6 +44,7 @@ except ImportError:
 else:
   from jax.experimental.mosaic import gpu as mosaic_gpu
   from jax.experimental.mosaic.gpu import dsl as mgpu
+  from jax.experimental.mosaic.gpu import utils as utils
   from jax.experimental.mosaic.gpu import profiler
   from jax.experimental.mosaic.gpu.utils import *  # noqa: F403
   from jax._src.lib.mlir.dialects import gpu
@@ -749,10 +750,11 @@ class BarrierTest(TestCase):
   @parameterized.named_parameters(
       (
-          f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
+          f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}{'_group' if group_dims else ''}",
           collective_dims,
           noncollective_dims,
           collective_size,
+          group_dims,
       )
       for collective_dims in itertools.chain.from_iterable(
           itertools.combinations(Dimension, n) for n in range(1, 4)
       )
@@ -761,9 +763,10 @@ class BarrierTest(TestCase):
           itertools.combinations(Dimension, n) for n in range(3)
       )
       for collective_size in (1, 2, 4)
+      for group_dims in (False,) + ((True,) if len(collective_dims) > 1 else ())
       if all(d not in noncollective_dims for d in collective_dims)
   )
-  def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size):
+  def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size, group_dims):
     i32 = ir.IntegerType.get_signless(32)
     index = ir.IndexType.get()
     cluster = [1, 1, 1]
@@ -773,9 +776,21 @@ class BarrierTest(TestCase):
       cluster[d] = 2
     if math.prod(cluster) > 16:
       self.skipTest("Cluster too big")
-    def kernel(ctx, dst, collective_barrier):
+    is_trivial = math.prod(cluster[d] for d in collective_dims) == 1
+    def kernel(ctx, dst, mask, collective_barrier):
+      memref.store(arith.constant(i32, 1 << 17), mask, [c(0, index)])
+      gpu.barrier()
       collective_barrier.arrive()
       collective_barrier.wait()
+      if not is_trivial:
+        llvm.atomicrmw(
+            llvm.AtomicBinOp.min,
+            utils.memref_ptr(mask),
+            collective_barrier.cluster_mask,
+            llvm.AtomicOrdering.monotonic,
+        )
+      else:
+        assert collective_barrier.cluster_mask is None
       tid = thread_idx()
       linear_idx = arith.index_cast(index, tid)
       stride = c(128, index)
@@ -784,13 +799,30 @@ class BarrierTest(TestCase):
         stride = arith.muli(stride, gpu.grid_dim(d))
       memref.store(arith.index_cast(i32, linear_idx), dst, [linear_idx])
     out_shape = jax.ShapeDtypeStruct((math.prod(cluster) * 128,), jnp.int32)
-    scratch = mgpu.ClusterBarrier(collective_dims)
-    y = mosaic_gpu.as_gpu_kernel(
-        kernel, cluster, (128, 1, 1), (), out_shape, scratch, cluster=cluster,
+    mask_shape = jax.ShapeDtypeStruct((1,), jnp.int32)
+    barrier_dims = collective_dims
+    if group_dims:
+      barrier_dims = (collective_dims[:2], *collective_dims[2:])
+    scratch = mgpu.ClusterBarrier(barrier_dims)
+    y, mask = mosaic_gpu.as_gpu_kernel(
+        kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster,
     )()
     np.testing.assert_array_equal(
         y, np.arange(math.prod(cluster) * 128, dtype=np.int32)
     )
+    if not is_trivial:
+      # Verify that the mask is correct. Blocks are column-major, hence the transpose.
+      block_bits = 1 << np.arange(math.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T
+      expected_mask = 0
+      for bd in barrier_dims:
+        if isinstance(bd, gpu.Dimension):
+          bd = (bd,)
+        least_significant_slice = tuple(
+            slice(None) if d in bd else 0 for d in gpu.Dimension
+        )
+        mask_bits = block_bits[least_significant_slice]
+        expected_mask |= np.bitwise_or.reduce(mask_bits, axis=None)
+      self.assertEqual(mask, expected_mask)
 
 
 class TMATest(TestCase):
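
The expected-mask computation in the test can be replayed standalone. A sketch for cluster=[2, 2, 1] with x and y grouped into one barrier dim (integers stand in for gpu.Dimension); it reproduces the 0b1111 mask derived earlier:

import math
import numpy as np

cluster = [2, 2, 1]
barrier_dims = ((0, 1),)  # x and y as a single collective group
# Blocks are column-major, hence the reshape over the reversed cluster and
# the transpose: block_bits[x, y, z] == 1 << linear_block_id.
block_bits = 1 << np.arange(math.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T
expected_mask = 0
for bd in barrier_dims:
  # Take the full slice along the group's dims, index 0 elsewhere.
  sl = tuple(slice(None) if d in bd else 0 for d in range(3))
  expected_mask |= np.bitwise_or.reduce(block_bits[sl], axis=None)
assert expected_mask == 0b1111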
@@ -816,30 +848,36 @@ class TMATest(TestCase):
 
   @parameterized.named_parameters(
       (
-          f"_{collective_dim}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
-          collective_dim,
+          f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
+          collective_dims,
           noncollective_dims,
           collective_size,
       )
-      for collective_dim in Dimension
+      for collective_dims in itertools.chain.from_iterable(
+          itertools.combinations(Dimension, n) for n in range(1, 4)
+      )
       for noncollective_dims in itertools.chain.from_iterable(
           itertools.combinations(Dimension, n) for n in range(3)
       )
       for collective_size in (1, 2, 4)
-      if collective_dim not in noncollective_dims
+      if all(d not in noncollective_dims for d in collective_dims)
   )
-  def test_tma_load_multicast(self, collective_dim, noncollective_dims, collective_size):
+  def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size):
     index = ir.IndexType.get()
     swizzle = 128
     dtype = jnp.float16
     cluster = [1, 1, 1]
-    cluster[collective_dim] = collective_size
+    for d in collective_dims:
+      cluster[d] = collective_dim_size
     for d in noncollective_dims:
       cluster[d] = 2
-    noncollective_size = math.prod(cluster) // cluster[collective_dim]
     if math.prod(cluster) > 16:
       self.skipTest("Cluster too big")
-    shape = (noncollective_size, 2, 16 * cluster[collective_dim], 64)
+    collective_size = math.prod(cluster[d] for d in collective_dims)
+    noncollective_size = math.prod(cluster) // collective_size
+    # We use the 2 dimension to exercise splitting the collective over
+    # multiple dimensions when the cluster is large.
+    shape = (noncollective_size, 2, 16 * collective_size, 64)
     minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
     shape = (*shape[:-1], minor_size)
     # Note that this kernel does not use the non-collective dimensions in any
@@ -861,11 +899,20 @@ class TMATest(TestCase):
           gmem_slice=(noncollective_idx,),
           swizzle=swizzle,
           barrier=barrier,
-          collective=collective_dim,
+          collective=collective_dims,
       )
       barrier.wait()
+      # This is _not_ the real cluster block idx, because it does not consider
+      # the column-major ordering of the grid dimensions.
+      idx = c(0, index)
+      stride = 1
+      for d in collective_dims:
+        idx = arith.addi(
+            idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))
+        )
+        stride *= cluster[d]
       slc = ds(
-          arith.muli(gpu.cluster_block_id(collective_dim), c(16, index)), 16
+          arith.muli(idx, c(16, index)), 16
       )
       copy(
           memref_slice(tmp, (slice(None), slc)),