mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Replace multicast_mask by a nicer collective async copy interface
Instead of asking the user to compute the transfer size, manually slice up the transfer, and compute and specify the multicast mask, we fold all of that functionality into the `async_copy` function. The copy should be called by all blocks in a given cluster slice along the specified dimension, and will collectively load all the requested data into all blocks in that slice. PiperOrigin-RevId: 655077439
This commit is contained in:
parent
a2b2fbf513
commit
51732c5caf
@ -209,6 +209,7 @@ OnDeviceProfiler = profiler.OnDeviceProfiler
|
||||
class LaunchContext:
|
||||
launch_op: gpu.LaunchOp
|
||||
gmem_scratch_ptr: ir.Value
|
||||
cluster_size: tuple[int, ...]
|
||||
profiler: OnDeviceProfiler | None = None
|
||||
next_scratch_offset: int = 0
|
||||
host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
|
||||
@ -322,9 +323,10 @@ class LaunchContext:
|
||||
swizzle: int | None = None,
|
||||
arrive: bool | None = None,
|
||||
uniform: bool = True,
|
||||
multicast_mask: ir.Value | None = None,
|
||||
collective: gpu.Dimension | None = None,
|
||||
):
|
||||
index = ir.IndexType.get()
|
||||
i16 = ir.IntegerType.get_signless(16)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
src_ref_ty = ir.MemRefType(src_ref.type)
|
||||
@ -378,11 +380,68 @@ class LaunchContext:
|
||||
|
||||
if slice_shape != tuple(smem_ref_ty.shape):
|
||||
raise ValueError(
|
||||
"Expected the SMEM reference to have the same shape as the tiled"
|
||||
f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
|
||||
"Expected the SMEM reference to have the same shape as the"
|
||||
f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
|
||||
)
|
||||
|
||||
dyn_base_indices = list(dyn_base_indices)
|
||||
slice_shape = list(slice_shape)
|
||||
if (
|
||||
collective is not None
|
||||
and (collective_size := self.cluster_size[collective]) != 1
|
||||
):
|
||||
for collective_slice_dim, slice_size in enumerate(slice_shape[:-1]):
|
||||
if slice_size % collective_size == 0:
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"None of the leading dimensions in the transformed slice shape"
|
||||
f" {slice_shape} is divisible by the collective size"
|
||||
f" {collective_size}"
|
||||
)
|
||||
# Make each block load a smaller slice, adjust the GMEM indices and slice
|
||||
# the SMEM reference accordingly.
|
||||
slice_shape[collective_slice_dim] //= collective_size
|
||||
block_idx = gpu.cluster_block_id(collective)
|
||||
block_offset = arith.muli(block_idx, c(slice_shape[collective_slice_dim], index))
|
||||
dyn_base_indices[collective_slice_dim] = arith.addi(
|
||||
dyn_base_indices[collective_slice_dim], block_offset,
|
||||
)
|
||||
smem_ref = mgpu.memref_slice(
|
||||
smem_ref,
|
||||
(slice(None),) * collective_slice_dim
|
||||
+ (mgpu.ds(block_offset, slice_shape[collective_slice_dim]),),
|
||||
)
|
||||
# Compute the multicast mask. We first compute the linearized index of the
|
||||
# slice along the collective dim that contains the current block. Then,
|
||||
# the mask is a sequence of 1s strided by the position of the collective
|
||||
# dim, shifted left by the linear slice index.
|
||||
# TODO(apaszke): Make sure this gets hoisted outside of any loops.
|
||||
# If not, we might need to do it manually.
|
||||
stride = 1
|
||||
mask_shift = c(0, i32)
|
||||
collective_stride = None
|
||||
for cluster_dim in gpu.Dimension:
|
||||
if self.cluster_size[cluster_dim] == 1:
|
||||
continue
|
||||
if cluster_dim != collective:
|
||||
dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim))
|
||||
mask_shift = arith.addi(
|
||||
mask_shift, arith.muli(dim_idx, c(stride, i32)),
|
||||
)
|
||||
else:
|
||||
collective_stride = stride
|
||||
stride *= self.cluster_size[cluster_dim]
|
||||
multicast_mask_unshifted = 0
|
||||
for i in range(collective_size):
|
||||
multicast_mask_unshifted |= 1 << (i * collective_stride)
|
||||
multicast_mask = arith.shli(c(multicast_mask_unshifted, i32), mask_shift)
|
||||
multicast_mask = arith.trunci(i16, multicast_mask)
|
||||
else:
|
||||
multicast_mask = None
|
||||
|
||||
tma_desc = self._get_tma_desc(
|
||||
gmem_ref, gmem_transform, slice_shape, swizzle,
|
||||
gmem_ref, gmem_transform, tuple(slice_shape), swizzle,
|
||||
)
|
||||
|
||||
# We construct TMA descriptors in column-major order.
|
||||
@ -555,7 +614,7 @@ def _launch(
|
||||
|
||||
ptr_ty = ir.Type.parse("!llvm.ptr")
|
||||
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
|
||||
yield LaunchContext(launch_op, scratch_ptr, prof), smem_ref_tree
|
||||
yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree
|
||||
if prof is not None:
|
||||
prof.finalize(grid=grid, block=block)
|
||||
gpu.terminator()
|
||||
|
@ -14,7 +14,9 @@
|
||||
# ==============================================================================
|
||||
"""Tests for Mosaic GPU DSL functions and utilities."""
|
||||
|
||||
import enum
|
||||
from functools import partial
|
||||
import itertools
|
||||
import math
|
||||
import operator
|
||||
|
||||
@ -34,6 +36,11 @@ try:
|
||||
HAS_MOSAIC_GPU = True
|
||||
except ImportError:
|
||||
HAS_MOSAIC_GPU = False
|
||||
|
||||
class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok
|
||||
x = 0
|
||||
y = 1
|
||||
z = 2
|
||||
else:
|
||||
from jax.experimental.mosaic import gpu as mosaic_gpu
|
||||
from jax.experimental.mosaic.gpu import dsl as mgpu
|
||||
@ -41,9 +48,11 @@ else:
|
||||
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import llvm
|
||||
Dimension = gpu.Dimension
|
||||
|
||||
|
||||
# ruff: noqa: F405
|
||||
# pylint: disable=g-complex-comprehension
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def nd_loop(bounds, body, *, _idxs = ()):
|
||||
@ -741,38 +750,67 @@ class TMATest(TestCase):
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_tma_load_multicast(self):
|
||||
dtype = jnp.float16
|
||||
shape = (64, 64)
|
||||
@parameterized.named_parameters(
|
||||
(
|
||||
f"_{collective_dim}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
|
||||
collective_dim,
|
||||
noncollective_dims,
|
||||
)
|
||||
for collective_dim in Dimension
|
||||
for noncollective_dims in itertools.chain.from_iterable(
|
||||
itertools.combinations(Dimension, n) for n in range(3)
|
||||
)
|
||||
if collective_dim not in noncollective_dims
|
||||
)
|
||||
def test_tma_load_multicast(self, collective_dim, noncollective_dims):
|
||||
index = ir.IndexType.get()
|
||||
swizzle = 128
|
||||
dtype = jnp.float16
|
||||
cluster = [1, 1, 1]
|
||||
for d in (collective_dim, *noncollective_dims):
|
||||
cluster[d] = 2
|
||||
noncollective_size = math.prod(cluster) // cluster[collective_dim]
|
||||
shape = (noncollective_size, 64, 64)
|
||||
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
|
||||
shape = (*shape[:-1], minor_size)
|
||||
i16 = ir.IntegerType.get_signless(16)
|
||||
index = ir.IndexType.get()
|
||||
# Note that this kernel does not use the non-collective dimensions in any
|
||||
# interesting way and so they don't really have to be part of the cluster.
|
||||
# We use them to test that the multicast mask is generated correctly.
|
||||
def kernel(ctx, src, dst, tmp):
|
||||
stride = 1
|
||||
noncollective_idx = c(0, index)
|
||||
for d in noncollective_dims:
|
||||
noncollective_idx = arith.addi(
|
||||
noncollective_idx,
|
||||
arith.muli(gpu.cluster_block_id(d), c(stride, index))
|
||||
)
|
||||
stride *= 2
|
||||
barrier = BarrierArray(1)[0]
|
||||
nvvm.fence_mbarrier_init()
|
||||
nvvm.cluster_arrive_relaxed()
|
||||
nvvm.cluster_wait()
|
||||
slc = ds(
|
||||
arith.muli(gpu.cluster_block_id(gpu.Dimension.x), c(32, index)), 32
|
||||
ctx.async_copy(
|
||||
src_ref=src,
|
||||
dst_ref=tmp,
|
||||
gmem_slice=(noncollective_idx,),
|
||||
swizzle=swizzle,
|
||||
barrier=barrier,
|
||||
collective=collective_dim,
|
||||
)
|
||||
with single_thread():
|
||||
barrier.arrive_expect_tx(math.prod(shape) * np.dtype(dtype).itemsize)
|
||||
ctx.async_copy(
|
||||
src_ref=src,
|
||||
dst_ref=memref_slice(tmp, slc),
|
||||
swizzle=swizzle,
|
||||
gmem_slice=slc,
|
||||
barrier=barrier,
|
||||
arrive=False,
|
||||
uniform=False,
|
||||
multicast_mask=c(0b11, i16),
|
||||
)
|
||||
barrier.wait()
|
||||
copy(memref_slice(tmp, slc), memref_slice(dst, slc), swizzle=swizzle)
|
||||
slc = ds(
|
||||
arith.muli(gpu.cluster_block_id(collective_dim), c(32, index)), 32
|
||||
)
|
||||
copy(
|
||||
memref_slice(tmp, slc),
|
||||
memref_slice(dst, (noncollective_idx, slc)),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (2, 1, 1), (128, 1, 1), x, x, x, cluster=(2, 1, 1))(x)
|
||||
smem_shape = jax.ShapeDtypeStruct(shape[1:], dtype)
|
||||
y = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, cluster, (128, 1, 1), x, x, smem_shape, cluster=cluster
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
@parameterized.product(
|
||||
|
Loading…
x
Reference in New Issue
Block a user