[Mosaic GPU] Fix the ordering of transforms in async_copy

Previously we didn't fully discharge the squeezing of the indexed
dims before applying the other GMEM transforms, so those transforms
could fail because they did not anticipate the increased rank.

PiperOrigin-RevId: 694098739
Adam Paszke, 2024-11-07 06:41:04 -08:00 (committed by jax authors)
parent 4cc80889b6 · commit 506671291a
3 changed files with 68 additions and 17 deletions
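In shape terms, the fix works as follows: indexing is slicing plus squeezing, and user transforms are written against the squeezed shape, but the TMA fuses the indexing in last. async_copy therefore moves the squeezed dims to the front with a TransposeTransform and batches every user transform so it ignores them. A minimal plain-Python sketch of this shape arithmetic (not the Mosaic GPU API; shapes taken from the new test below):

# Plain-Python sketch of the commutation; shapes from test_tma_load_indexed_tiled.
def tile(shape, tiling):
  # TileTransform on shapes: (..., a*t0, b*t1) -> (..., a, b, t0, t1).
  n = len(tiling)
  return (*shape[:-n], *(d // t for d, t in zip(shape[-n:], tiling)), *tiling)

slice_shape = (128, 1, 128)  # gmem_slice=(slice(None), 1, slice(None)) of (128, 2, 128)
squeezed_dims, sliced_dims = (1,), (0, 2)

# Step 1: move the squeezed dims to the front (TransposeTransform((1, 0, 2))).
fronted = tuple(slice_shape[p] for p in (*squeezed_dims, *sliced_dims))  # (1, 128, 128)
# Step 2: apply the batched user transforms. TileTransform.batch is the
# identity because tiling only ever touches the trailing dims.
tiled = tile(fronted, (32, 32))  # (1, 4, 4, 32, 32)

# The leading dims stay 1 throughout, and the suffix must match the SMEM ref,
# which is exactly what async_copy now checks.
assert all(d == 1 for d in tiled[:len(squeezed_dims)])
assert tiled[len(squeezed_dims):] == (4, 4, 32, 32)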


@@ -133,6 +133,14 @@ class MemRefTransform:
   def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
     raise NotImplementedError("Subclasses should override this method")
 
+  def batch(self, leading_rank: int) -> 'MemRefTransform':
+    """Returns a transform that accepts a ref with the extra `leading_rank` dims.
+
+    The returned transform should leave the leading dimensions unchanged and
+    only apply to the suffix of the shape.
+    """
+    raise NotImplementedError("Subclasses should override this method")
+
 
 @dataclasses.dataclass(frozen=True)
 class TileTransform(MemRefTransform):
@@ -198,6 +206,9 @@ class TileTransform(MemRefTransform):
         *self.tiling,
     )
 
+  def batch(self, leading_rank: int) -> MemRefTransform:
+    return self
+
 
 @dataclasses.dataclass(frozen=True)
 class TransposeTransform(MemRefTransform):
@@ -217,6 +228,11 @@ class TransposeTransform(MemRefTransform):
   def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
     return tuple(shape[p] for p in self.permutation)
 
+  def batch(self, leading_rank: int) -> MemRefTransform:
+    return TransposeTransform(
+        (*range(leading_rank), *(d + leading_rank for d in self.permutation))
+    )
+
 
 OnDeviceProfiler = profiler.OnDeviceProfiler
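For concreteness, batching a permutation fixes the new leading dims in place and shifts the original entries right. A quick standalone check (not library code):

# Standalone check of the permutation batching above.
def batch_permutation(permutation, leading_rank):
  # New leading dims map to themselves; original entries shift right.
  return (*range(leading_rank), *(d + leading_rank for d in permutation))

assert batch_permutation((1, 0), 1) == (0, 2, 1)
assert batch_permutation((2, 0, 1), 2) == (0, 1, 4, 2, 3)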
@@ -388,16 +404,26 @@ class LaunchContext:
     dyn_base_indices = tuple(
         c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
     )
+    squeezed_dims = [i for i, squeezed in enumerate(is_squeezed) if squeezed]
+    sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed]
+    # Indexing is really slicing + squeezing, and user transforms are meant to
+    # apply after that. However, we actually have to apply the indexing last
+    # (it's fused into the TMA) and so we need to commute it with all the user
+    # transforms. For slicing this is done using transform_index and
+    # transform_shape. For squeezing we actually move all the squeezed dims to
+    # the front, and then batch each transform, making it ignore the extra dims.
+    if squeezed_dims:
+      gmem_transform = (TransposeTransform((*squeezed_dims, *sliced_dims)),
+                        *(t.batch(len(squeezed_dims)) for t in gmem_transform))
 
     slice_shape = tuple(slice_shape)
     for t in gmem_transform:
       dyn_base_indices = t.transform_index(dyn_base_indices)
       slice_shape = t.transform_shape(slice_shape)
-    for dim, squeezed in enumerate(is_squeezed):
-      if squeezed:
-        smem_ref = utils.memref_unsqueeze(smem_ref, dim)
-    smem_ref_ty = ir.MemRefType(smem_ref.type)
-    if slice_shape != tuple(smem_ref_ty.shape):
+
+    smem_ref_ty = ir.MemRefType(smem_ref.type)
+    # We moved all squeezed dims to the front.
+    if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape):
       raise ValueError(
           "Expected the SMEM reference to have the same shape as the"
           f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
@@ -411,6 +437,7 @@ class LaunchContext:
     dyn_base_indices = list(dyn_base_indices)
     slice_shape = list(slice_shape)
+    assert all(d == 1 for d in slice_shape[:len(squeezed_dims)])
     collective_size = 1
     if collective is not None:
       if isinstance(collective, gpu.Dimension):
@@ -418,13 +445,16 @@ class LaunchContext:
         collective_size = math.prod(self.cluster_size[d] for d in collective)
       if collective_size > 1:
         def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
+          # No need to partition squeezed dims. They don't even exist in smem_ref.
+          assert dim >= len(squeezed_dims)
           nonlocal smem_ref
           slice_shape[dim] //= num_chunks
           block_offset = arith.muli(idx, c(slice_shape[dim], index))
           dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset)
           smem_ref = utils.memref_slice(
               smem_ref,
-              (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),)
+              (slice(None),) * (dim - len(squeezed_dims))
+              + (utils.ds(block_offset, slice_shape[dim]),),
           )
         stride = 1
         idx = c(0, index)
@@ -440,10 +470,12 @@ class LaunchContext:
             rem_collective_size = 1
             break
           elif rem_collective_size % slice_size == 0:
-            dim_idx = arith.remui(idx, c(slice_size, index))
-            partition_dim(dim, dim_idx, slice_size)
-            idx = arith.divui(idx, c(slice_size, index))
-            rem_collective_size //= slice_size
+            # This is an optimization and it lets us skip squeezed dims.
+            if slice_size > 1:
+              dim_idx = arith.remui(idx, c(slice_size, index))
+              partition_dim(dim, dim_idx, slice_size)
+              idx = arith.divui(idx, c(slice_size, index))
+              rem_collective_size //= slice_size
           else:
             break  # We failed to partition the leading dimensions.
         del idx  # We overwrote the block index in the loop.
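The partition_dim changes come down to an index shift: the transformed slice_shape carries the squeezed dims as leading 1s, while the SMEM ref never had them, so a GMEM slice dim maps to a SMEM dim offset by the number of squeezed dims. A hypothetical example with made-up sizes:

# Hypothetical sizes, for illustration only.
squeezed_dims = [0]          # one squeezed GMEM dim, already moved to the front
slice_shape = [1, 128, 128]  # transformed GMEM slice; the leading 1 is squeezed
smem_shape = (128, 128)      # the SMEM ref never had the squeezed dim
dim = 1                      # the GMEM slice dim being partitioned
assert dim >= len(squeezed_dims)  # squeezed dims are never partitioned
# GMEM slice dim `dim` lines up with SMEM dim `dim - len(squeezed_dims)`.
assert slice_shape[dim] == smem_shape[dim - len(squeezed_dims)]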


@@ -300,9 +300,7 @@ def build_kernel(
       with ir.InsertionPoint(if_compute.else_block):
         nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease)
         with single_thread(per_block=False):
-          k_tr = (
-              TileTransform(tiling), TransposeTransform((0, 2, 1, 3, 4)),
-          )
+          k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3)))
           v_tr = TileTransform(tiling)
           kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head))
           def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform):
@@ -396,10 +394,7 @@ def build_kernel(
         with single_thread(per_block=False):
           txcount = 2 * blocks.kv * head_dim * bytewidth(f16)
           barriers[slot].arrive_expect_tx(txcount)
-          k_tr = (
-              TileTransform(tiling),
-              TransposeTransform((0, 2, 1, 3, 4)),
-          )
+          k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3)))
           v_tr = TileTransform(tiling)
           for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)):
             ctx.async_copy(
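The flash-attention update is a direct consequence: k_tr previously applied to the K ref with the indexed kv_head dim still present, so its permutation had to carry that leading dim along; now transforms see the already-squeezed ref, and the permutation drops a dimension. A shape-level sketch (all concrete sizes below are assumed for illustration, not taken from the kernel):

# Shape-level sketch; the tiling and K shapes are assumptions.
def tile(s, t):
  return (*s[:-2], s[-2] // t[0], s[-1] // t[1], *t)

def transpose(s, p):
  return tuple(s[i] for i in p)

t = (64, 64)         # assumed tiling
old = (8, 256, 128)  # (kv_heads, kv_seq, head_dim): what k_tr used to see
new = (256, 128)     # what k_tr sees now, with kv_head already squeezed

old_out = transpose(tile(old, t), (0, 2, 1, 3, 4))  # (8, 2, 4, 64, 64)
new_out = transpose(tile(new, t), (1, 0, 2, 3))     # (2, 4, 64, 64)
assert old_out[1:] == new_out  # same layout, minus the squeezed head dim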


@@ -1060,6 +1060,30 @@ class TMATest(TestCase):
     y = f(x)
     np.testing.assert_array_equal(y, x)
 
+  def test_tma_load_indexed_tiled(self):
+    shape = (128, 2, 128)
+    tiling = mgpu.TileTransform((32, 32))
+    def kernel(ctx, src, dst, scratch):
+      tmp, barrier = scratch
+      ctx.async_copy(
+          src_ref=src,
+          dst_ref=tmp,
+          barrier=barrier,
+          gmem_transform=tiling,
+          gmem_slice=(slice(None), 1, slice(None)),
+      )
+      barrier.wait()
+      ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_transform=tiling)
+      ctx.await_async_copy(0)
+    x = np.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
+    smem = (
+        jax.ShapeDtypeStruct((4, 4, 32, 32), jnp.float32),
+        mgpu.TMABarrier(),
+    )
+    out_shape = jax.ShapeDtypeStruct((128, 128), jnp.float32)
+    f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem)
+    np.testing.assert_array_equal(f(x), x[:, 1, :])
+
   @parameterized.product(
       swizzle=(None, 128),
       dtype=(jnp.float16, jnp.float32),