Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[Mosaic GPU] Add support for cluster collective loads and barriers over multiple dimensions
This will be useful for an upcoming change to the matmul kernel that splits
the N blocks over two cluster dimensions.

PiperOrigin-RevId: 662825455
parent 4c4660a117
commit f384497f68
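The core of the change is that a collective argument may now name several cluster dimensions, whose block ids get linearized into a single index. A minimal host-side sketch of that linearization (plain Python integers standing in for the MLIR index values; the dictionary-based `cluster_size` and `block_id` below are illustrative stand-ins, not the real API):

import math

# Illustrative cluster shape, keyed by dimension name (x is minor-most).
cluster_size = {"x": 2, "y": 2, "z": 1}

def linearize_block_id(block_id: dict, collective: tuple) -> int:
  # Fold the block ids of all collective dimensions into one linear index
  # using mixed-radix strides, as the new loop in async_copy does.
  idx, stride = 0, 1
  for d in sorted(collective):
    if cluster_size[d] == 1:  # Optimize a multiply by 0, as in the diff.
      continue
    idx += block_id[d] * stride
    stride *= cluster_size[d]
  return idx

# The block at (x=1, y=1) in a (2, 2, 1) cluster gets collective index 3,
# and the collective spans 4 blocks in total.
print(linearize_block_id({"x": 1, "y": 1, "z": 0}, collective=("x", "y")))
print(math.prod(cluster_size[d] for d in ("x", "y")))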
@@ -325,7 +325,7 @@ class LaunchContext:
       swizzle: int | None = None,
       arrive: bool | None = None,
       uniform: bool = True,
-      collective: gpu.Dimension | None = None,
+      collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None,
   ):
     index = ir.IndexType.get()
     i16 = ir.IntegerType.get_signless(16)
@@ -388,7 +388,11 @@ class LaunchContext:

     dyn_base_indices = list(dyn_base_indices)
     slice_shape = list(slice_shape)
-    collective_size = 1 if collective is None else self.cluster_size[collective]
+    collective_size = 1
+    if collective is not None:
+      if isinstance(collective, gpu.Dimension):
+        collective = (collective,)
+      collective_size = math.prod(self.cluster_size[d] for d in collective)
     if collective_size > 1:
       def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
         nonlocal smem_ref
@@ -399,18 +403,28 @@ class LaunchContext:
             smem_ref,
             (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),)
         )
-      idx = gpu.cluster_block_id(collective)
+      stride = 1
+      idx = c(0, index)
+      for d in sorted(collective):
+        if self.cluster_size[d] == 1:  # Optimize a multiply by 0.
+          continue
+        idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)))
+        stride *= self.cluster_size[d]
       rem_collective_size = collective_size
       for dim, slice_size in enumerate(slice_shape[:-1]):
         if slice_size % rem_collective_size == 0:
           partition_dim(dim, idx, rem_collective_size)
           rem_collective_size = 1
           break
-        elif collective_size % slice_size == 0:
+        elif rem_collective_size % slice_size == 0:
           dim_idx = arith.remui(idx, c(slice_size, index))
           partition_dim(dim, dim_idx, slice_size)
           idx = arith.divui(idx, c(slice_size, index))
           rem_collective_size //= slice_size
-        else:
+        else:
           break  # We failed to partition the leading dimensions.
       del idx  # We overwrote the block index in the loop.
+      if rem_collective_size > 1:
+        raise ValueError(
+            "None of the leading dimensions in the transformed slice shape"
+            f" {slice_shape} is divisible by the collective size"
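The loop above splits the copied slice across the collective blocks by peeling the collective size off the leading slice dimensions. A rough pure-Python model of that control flow, inferred from the diff (the real code manipulates MLIR values and mutates `slice_shape` in place; the helper name here is illustrative):

def partition_slice(slice_shape, collective_size):
  # Greedily split the leading slice dimensions across the collective blocks,
  # mirroring the rem_collective_size loop in async_copy. Returns a list of
  # (dimension, number_of_chunks) pairs; raises if the split is impossible.
  plan = []
  rem = collective_size
  for dim, size in enumerate(slice_shape[:-1]):
    if size % rem == 0:
      plan.append((dim, rem))   # This dim absorbs all remaining blocks.
      rem = 1
      break
    elif rem % size == 0:
      plan.append((dim, size))  # This dim absorbs `size` blocks; keep going.
      rem //= size
    else:
      break  # We failed to partition the leading dimensions.
  if rem > 1:
    raise ValueError(
        f"None of the leading dimensions in {slice_shape} is divisible by"
        f" the collective size {collective_size}"
    )
  return plan

# A (4, 32, 64) slice split over 8 collective blocks: dim 0 is cut into 4
# chunks and dim 1 into the remaining 2.
print(partition_slice([4, 32, 64], 8))  # [(0, 4), (1, 2)]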
@@ -622,21 +622,27 @@ class CollectiveBarrierRef:
   def initialize(
       address: ir.Value,
       num_barriers: int,
-      dims: Sequence[gpu.Dimension],
+      dims: Sequence[gpu.Dimension | Sequence[gpu.Dimension]],
       cluster_shape: tuple[int, int, int],
   ) -> "CollectiveBarrierRef":
     i32 = ir.IntegerType.get_signless(32)
     # With the exception of the current device, each pair of slices along
     # collective dims is disjoint. Since the current device is overcounted,
     # we must decrease the arrival count a little.
-    arrival_count = sum(cluster_shape[d] for d in dims) - len(dims) + 1
-    if math.prod(cluster_shape[d] for d in dims) == 1:
+    dims_shape = [
+        cluster_shape[d]
+        if isinstance(d, gpu.Dimension)
+        else math.prod(cluster_shape[dd] for dd in d)
+        for d in dims
+    ]
+    arrival_count = sum(dims_shape) - len(dims) + 1
+    if arrival_count == 1:
+      assert all(s == 1 for s in dims_shape)
       cluster_mask = None
-      assert arrival_count == 1
     else:
       cluster_mask = c(0, i32)
-      for d in dims:
-        if cluster_shape[d] == 1:
+      for d, size in zip(dims, dims_shape):
+        if size == 1:
           # Only the current device is in this mask, but it will also be
           # present in one of the non-trivial cluster dims.
           continue
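With grouped dimensions, each entry of `dims` contributes the product of its cluster sizes to the arrival count. A standalone sketch of the arithmetic (integer axis indices stand in for `gpu.Dimension`; the helper name is hypothetical):

import math

def barrier_arrival_count(cluster_shape, dims):
  # Each entry of `dims` is either a single axis index or a group of axes.
  # Every slice along every entry contains the current device, so the naive
  # sum counts it len(dims) times; subtract the overcount and add it back once.
  dims_shape = [
      cluster_shape[d] if isinstance(d, int)
      else math.prod(cluster_shape[dd] for dd in d)
      for d in dims
  ]
  return sum(dims_shape) - len(dims) + 1

# Cluster (2, 2, 2): a barrier over the grouped (x, y) plane plus the z axis
# must see 4 + 2 - 2 + 1 = 5 arrivals.
print(barrier_arrival_count((2, 2, 2), [(0, 1), 2]))  # 5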
@@ -887,8 +893,11 @@ def memref_ptr(memref_arg, memory_space=None):


 def cluster_collective_mask(
-    cluster_shape: tuple[int, int, int], collective: gpu.Dimension
+    cluster_shape: tuple[int, int, int],
+    collective: Sequence[gpu.Dimension] | gpu.Dimension,
 ):
+  if isinstance(collective, gpu.Dimension):
+    collective = (collective,)
   # We first compute the linearized index of the slice along the collective
   # dim that contains the current block. Then, the mask is a sequence of 1s
   # strided by the position of the collective dim, shifted left by the linear
@@ -896,20 +905,20 @@ def cluster_collective_mask(
   # TODO(apaszke): Make sure this gets hoisted outside of any loops.
   # If not, we might need to do it manually.
   i32 = ir.IntegerType.get_signless(32)
-  stride = 1
   mask_shift = c(0, i32)
-  collective_stride = None
-  for cluster_dim in gpu.Dimension:
-    if cluster_dim != collective:
-      if cluster_shape[cluster_dim] != 1:  # Constant-fold multiply by 0.
-        dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim))
-        mask_shift = arith.addi(
-            mask_shift, arith.muli(dim_idx, c(stride, i32)),
-        )
-    else:
-      collective_stride = stride
-    stride *= cluster_shape[cluster_dim]
+  # NOTE: GPU dimensions are minor-to-major.
+  cluster_strides = get_contiguous_strides(cluster_shape[::-1])[::-1]
+  for stride, cluster_dim in zip(cluster_strides, gpu.Dimension):
+    if cluster_dim in collective:
+      continue
+    if cluster_shape[cluster_dim] != 1:  # Constant-fold multiply by 0.
+      dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim))
+      mask_shift = arith.addi(
+          mask_shift, arith.muli(dim_idx, c(stride, i32)),
+      )
   mask_unshifted = 0
-  for i in range(cluster_shape[collective]):
-    mask_unshifted |= 1 << (i * collective_stride)
+  collective_strides = [cluster_strides[d] for d in collective]
+  collective_shape = tuple(cluster_shape[d] for d in collective)
+  for idx in np.ndindex(collective_shape):
+    mask_unshifted |= 1 << sum(i * s for i, s in zip(idx, collective_strides))
   return arith.shli(c(mask_unshifted, i32), mask_shift)
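The rewritten `cluster_collective_mask` enumerates every block of the multi-dimensional collective sub-grid to build the multicast mask, rather than striding along a single dimension. A runnable NumPy model of the same computation on host-side integers (the function name and explicit stride tuple are illustrative; the real code derives strides via `get_contiguous_strides` and emits MLIR arithmetic):

import numpy as np

def collective_mask_host(cluster_shape, collective, block_id):
  # Compute the multicast mask selecting every block that differs from
  # `block_id` only in the collective dimensions. Cluster dimensions are
  # minor-to-major, so dimension 0 has stride 1 in the linear block id.
  strides = (1, cluster_shape[0], cluster_shape[0] * cluster_shape[1])
  # Linear offset of the collective slice containing the current block.
  mask_shift = sum(
      block_id[d] * strides[d] for d in range(3) if d not in collective
  )
  # One bit per block of the collective sub-grid, relative to that slice.
  collective_strides = [strides[d] for d in collective]
  collective_shape = tuple(cluster_shape[d] for d in collective)
  mask_unshifted = 0
  for idx in np.ndindex(collective_shape):
    mask_unshifted |= 1 << sum(i * s for i, s in zip(idx, collective_strides))
  return mask_unshifted << mask_shift

# In a (2, 2, 2) cluster, a block with y == 1 multicasting over (x, z)
# selects the four blocks whose y coordinate is 1: 0b11001100.
print(bin(collective_mask_host((2, 2, 2), (0, 2), (0, 1, 0))))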
@@ -44,6 +44,7 @@ except ImportError:
 else:
   from jax.experimental.mosaic import gpu as mosaic_gpu
   from jax.experimental.mosaic.gpu import dsl as mgpu
+  from jax.experimental.mosaic.gpu import utils as utils
   from jax.experimental.mosaic.gpu import profiler
   from jax.experimental.mosaic.gpu.utils import *  # noqa: F403
   from jax._src.lib.mlir.dialects import gpu
@@ -749,10 +750,11 @@ class BarrierTest(TestCase):

   @parameterized.named_parameters(
       (
-          f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
+          f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}{'_group' if group_dims else ''}",
           collective_dims,
           noncollective_dims,
           collective_size,
+          group_dims,
       )
       for collective_dims in itertools.chain.from_iterable(
           itertools.combinations(Dimension, n) for n in range(1, 4)
@@ -761,9 +763,10 @@ class BarrierTest(TestCase):
           itertools.combinations(Dimension, n) for n in range(3)
       )
       for collective_size in (1, 2, 4)
+      for group_dims in (False,) + ((True,) if len(collective_dims) > 1 else ())
       if all(d not in noncollective_dims for d in collective_dims)
   )
-  def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size):
+  def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size, group_dims):
     i32 = ir.IntegerType.get_signless(32)
     index = ir.IndexType.get()
     cluster = [1, 1, 1]
@@ -773,9 +776,21 @@ class BarrierTest(TestCase):
       cluster[d] = 2
     if math.prod(cluster) > 16:
       self.skipTest("Cluster too big")
-    def kernel(ctx, dst, collective_barrier):
+    is_trivial = math.prod(cluster[d] for d in collective_dims) == 1
+    def kernel(ctx, dst, mask, collective_barrier):
+      memref.store(arith.constant(i32, 1 << 17), mask, [c(0, index)])
+      gpu.barrier()
       collective_barrier.arrive()
       collective_barrier.wait()
+      if not is_trivial:
+        llvm.atomicrmw(
+            llvm.AtomicBinOp.min,
+            utils.memref_ptr(mask),
+            collective_barrier.cluster_mask,
+            llvm.AtomicOrdering.monotonic,
+        )
+      else:
+        assert collective_barrier.cluster_mask is None
       tid = thread_idx()
       linear_idx = arith.index_cast(index, tid)
       stride = c(128, index)
@@ -784,13 +799,30 @@ class BarrierTest(TestCase):
         stride = arith.muli(stride, gpu.grid_dim(d))
       memref.store(arith.index_cast(i32, linear_idx), dst, [linear_idx])
     out_shape = jax.ShapeDtypeStruct((math.prod(cluster) * 128,), jnp.int32)
-    scratch = mgpu.ClusterBarrier(collective_dims)
-    y = mosaic_gpu.as_gpu_kernel(
-        kernel, cluster, (128, 1, 1), (), out_shape, scratch, cluster=cluster,
+    mask_shape = jax.ShapeDtypeStruct((1,), jnp.int32)
+    barrier_dims = collective_dims
+    if group_dims:
+      barrier_dims = (collective_dims[:2], *collective_dims[2:])
+    scratch = mgpu.ClusterBarrier(barrier_dims)
+    y, mask = mosaic_gpu.as_gpu_kernel(
+        kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster,
     )()
     np.testing.assert_array_equal(
         y, np.arange(math.prod(cluster) * 128, dtype=np.int32)
     )
+    if not is_trivial:
+      # Verify that the mask is correct. Blocks are column-major, hence the transpose.
+      block_bits = 1 << np.arange(math.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T
+      expected_mask = 0
+      for bd in barrier_dims:
+        if isinstance(bd, gpu.Dimension):
+          bd = (bd,)
+        least_significant_slice = tuple(
+            slice(None) if d in bd else 0 for d in gpu.Dimension
+        )
+        mask_bits = block_bits[least_significant_slice]
+        expected_mask |= np.bitwise_or.reduce(mask_bits, axis=None)
+      self.assertEqual(mask, expected_mask)


 class TMATest(TestCase):
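The test's expected mask relies on cluster block ids being column-major with respect to the (x, y, z) cluster tuple, which is why `block_bits` is built by reshaping to `cluster[::-1]` and then transposing. A quick NumPy check of that layout (illustrative only):

import numpy as np

cluster = (2, 2, 1)  # (x, y, z); x is the fastest-varying dimension.
# Bit i corresponds to the block with linear id i = x + 2 * y + 4 * z, so the
# ids are column-major w.r.t. (x, y, z): reshape to cluster[::-1], then
# transpose so the array can be indexed as block_bits[x, y, z].
block_bits = (
    1 << np.arange(np.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T
)
assert block_bits[1, 0, 0] == 1 << 1  # block (x=1, y=0, z=0) has linear id 1
assert block_bits[0, 1, 0] == 1 << 2  # block (x=0, y=1, z=0) has linear id 2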
@@ -816,30 +848,36 @@ class TMATest(TestCase):

   @parameterized.named_parameters(
       (
-          f"_{collective_dim}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
-          collective_dim,
+          f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
+          collective_dims,
           noncollective_dims,
           collective_size,
       )
-      for collective_dim in Dimension
+      for collective_dims in itertools.chain.from_iterable(
+          itertools.combinations(Dimension, n) for n in range(1, 4)
+      )
       for noncollective_dims in itertools.chain.from_iterable(
           itertools.combinations(Dimension, n) for n in range(3)
       )
       for collective_size in (1, 2, 4)
-      if collective_dim not in noncollective_dims
+      if all(d not in noncollective_dims for d in collective_dims)
   )
-  def test_tma_load_multicast(self, collective_dim, noncollective_dims, collective_size):
+  def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size):
     index = ir.IndexType.get()
     swizzle = 128
     dtype = jnp.float16
     cluster = [1, 1, 1]
-    cluster[collective_dim] = collective_size
+    for d in collective_dims:
+      cluster[d] = collective_dim_size
     for d in noncollective_dims:
       cluster[d] = 2
-    noncollective_size = math.prod(cluster) // cluster[collective_dim]
     if math.prod(cluster) > 16:
       self.skipTest("Cluster too big")
+    collective_size = math.prod(cluster[d] for d in collective_dims)
+    noncollective_size = math.prod(cluster) // collective_size
+    # We use the 2 dimension to exercise splitting the collective over
+    # multiple dimensions when the cluster is large.
-    shape = (noncollective_size, 2, 16 * cluster[collective_dim], 64)
+    shape = (noncollective_size, 2, 16 * collective_size, 64)
     minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
     shape = (*shape[:-1], minor_size)
     # Note that this kernel does not use the non-collective dimensions in any
@@ -861,11 +899,20 @@ class TMATest(TestCase):
           gmem_slice=(noncollective_idx,),
           swizzle=swizzle,
           barrier=barrier,
-          collective=collective_dim,
+          collective=collective_dims,
       )
       barrier.wait()
+      # This is _not_ the real cluster block idx, because it does not consider
+      # the column-major ordering of the grid dimensions.
+      idx = c(0, index)
+      stride = 1
+      for d in collective_dims:
+        idx = arith.addi(
+            idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))
+        )
+        stride *= cluster[d]
       slc = ds(
-          arith.muli(gpu.cluster_block_id(collective_dim), c(16, index)), 16
+          arith.muli(idx, c(16, index)), 16
       )
       copy(
           memref_slice(tmp, (slice(None), slc)),