[Mosaic GPU] Fix the ordering of transforms in async_copy
Previously, we did not fully discharge the squeezing of indexed dims before applying the other GMEM transforms, which could lead to failures because those transforms were not anticipating the increased rank.

PiperOrigin-RevId: 694098739
commit 506671291a
parent 4cc80889b6
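
In short: indexing a GMEM dimension in gmem_slice amounts to taking a length-1 slice and then squeezing that dimension away, but the indexing itself has to be applied last because it is fused into the TMA descriptor. The change below therefore keeps the squeezed dims around through the transform pipeline, moves them to the front with a TransposeTransform, and "batches" every user transform so that it skips those leading dims. What follows is a minimal standalone sketch of that shape bookkeeping; the helpers (tile_shape, batch, move_squeezed_to_front) are invented for illustration and are not the Mosaic GPU implementation.

# Standalone sketch of the shape bookkeeping; helper names are invented.
def tile_shape(shape, tiling):
  # Tiling transform on shapes: split the last len(tiling) dims into
  # (outer, tile) pairs, e.g. (128, 128) with (32, 32) -> (4, 4, 32, 32).
  n = len(tiling)
  outer = tuple(d // t for d, t in zip(shape[-n:], tiling))
  return (*shape[:-n], *outer, *tiling)

def batch(transform, leading_rank):
  # Make a shape transform ignore `leading_rank` extra leading dims.
  def batched(shape):
    return (*shape[:leading_rank], *transform(shape[leading_rank:]))
  return batched

def move_squeezed_to_front(shape, squeezed_dims):
  # The TransposeTransform inserted by the fix: squeezed dims go first.
  sliced_dims = [i for i in range(len(shape)) if i not in squeezed_dims]
  return tuple(shape[p] for p in (*squeezed_dims, *sliced_dims))

# A (128, 2, 128) GMEM ref indexed as [:, 1, :] with a (32, 32) tiling:
# the slice shape before squeezing is (128, 1, 128) and dim 1 is squeezed.
full_slice = (128, 1, 128)
squeezed = [1]

# What the user means: squeeze first, then apply the tiling.
expected = tile_shape((128, 128), (32, 32))                   # (4, 4, 32, 32)

# What async_copy does after the fix: keep the squeezed dim (indexing is
# fused into the TMA), move it to the front, and batch the tiling past it.
moved = move_squeezed_to_front(full_slice, squeezed)          # (1, 128, 128)
got = batch(lambda s: tile_shape(s, (32, 32)), len(squeezed))(moved)

assert got == (1, *expected)                                  # (1, 4, 4, 32, 32)
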
@@ -133,6 +133,14 @@ class MemRefTransform:
   def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
     raise NotImplementedError("Subclasses should override this method")
 
+  def batch(self, leading_rank: int) -> 'MemRefTransform':
+    """Returns a transform that accepts a ref with the extra `leading_rank` dims.
+
+    The returned transform should leave the leading dimensions unchanged and
+    only apply to the suffix of the shape.
+    """
+    raise NotImplementedError("Subclasses should override this method")
+
 
 @dataclasses.dataclass(frozen=True)
 class TileTransform(MemRefTransform):
@@ -198,6 +206,9 @@ class TileTransform(MemRefTransform):
         *self.tiling,
     )
 
+  def batch(self, leading_rank: int) -> MemRefTransform:
+    return self
+
 
 @dataclasses.dataclass(frozen=True)
 class TransposeTransform(MemRefTransform):
@@ -217,6 +228,11 @@ class TransposeTransform(MemRefTransform):
   def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
     return tuple(shape[p] for p in self.permutation)
 
+  def batch(self, leading_rank: int) -> MemRefTransform:
+    return TransposeTransform(
+        (*range(leading_rank), *(d + leading_rank for d in self.permutation))
+    )
+
 
 OnDeviceProfiler = profiler.OnDeviceProfiler
 
@@ -388,16 +404,26 @@ class LaunchContext:
     dyn_base_indices = tuple(
        c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
     )
+    squeezed_dims = [i for i, squeezed in enumerate(is_squeezed) if squeezed]
+    sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed]
+    # Indexing is really slicing + squeezing, and user transforms are meant to
+    # apply after that. However, we actually have to apply the indexing last
+    # (it's fused into the TMA) and so we need to commute it with all the user
+    # transforms. For slicing this is done using transform_index and
+    # transform_shape. For squeezing we actually move all the squeezed dims to
+    # the front, and then batch each transform, making it ignore the extra dims.
+    if squeezed_dims:
+      gmem_transform = (TransposeTransform((*squeezed_dims, *sliced_dims)),
+                        *(t.batch(len(squeezed_dims)) for t in gmem_transform))
+
     slice_shape = tuple(slice_shape)
     for t in gmem_transform:
       dyn_base_indices = t.transform_index(dyn_base_indices)
       slice_shape = t.transform_shape(slice_shape)
-    for dim, squeezed in enumerate(is_squeezed):
-      if squeezed:
-        smem_ref = utils.memref_unsqueeze(smem_ref, dim)
-    smem_ref_ty = ir.MemRefType(smem_ref.type)
 
-    if slice_shape != tuple(smem_ref_ty.shape):
+    smem_ref_ty = ir.MemRefType(smem_ref.type)
+    # We moved all squeezed dims to the front.
+    if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape):
       raise ValueError(
           "Expected the SMEM reference to have the same shape as the"
           f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
@@ -411,6 +437,7 @@ class LaunchContext:
 
     dyn_base_indices = list(dyn_base_indices)
     slice_shape = list(slice_shape)
+    assert all(d == 1 for d in slice_shape[:len(squeezed_dims)])
     collective_size = 1
     if collective is not None:
       if isinstance(collective, gpu.Dimension):
@@ -418,13 +445,16 @@ class LaunchContext:
       collective_size = math.prod(self.cluster_size[d] for d in collective)
     if collective_size > 1:
       def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
+        # No need to partition squeezed dims. They don't even exist in smem_ref.
+        assert dim >= len(squeezed_dims)
         nonlocal smem_ref
         slice_shape[dim] //= num_chunks
         block_offset = arith.muli(idx, c(slice_shape[dim], index))
         dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset)
         smem_ref = utils.memref_slice(
             smem_ref,
-            (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),)
+            (slice(None),) * (dim - len(squeezed_dims))
+            + (utils.ds(block_offset, slice_shape[dim]),),
         )
       stride = 1
       idx = c(0, index)
@@ -440,10 +470,12 @@ class LaunchContext:
          rem_collective_size = 1
          break
        elif rem_collective_size % slice_size == 0:
-          dim_idx = arith.remui(idx, c(slice_size, index))
-          partition_dim(dim, dim_idx, slice_size)
-          idx = arith.divui(idx, c(slice_size, index))
-          rem_collective_size //= slice_size
+          # This is an optimization and it lets us skip squeezed dims.
+          if slice_size > 1:
+            dim_idx = arith.remui(idx, c(slice_size, index))
+            partition_dim(dim, dim_idx, slice_size)
+            idx = arith.divui(idx, c(slice_size, index))
+            rem_collective_size //= slice_size
        else:
          break  # We failed to partition the leading dimensions.
      del idx  # We overwrote the block index in the loop.
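
To make the contract of the new batch methods concrete, here is a quick standalone check of the permutation shift that TransposeTransform.batch performs. This re-implements the one-line expression from the hunk above as a plain function; it is a sketch, not a call into the real class.

# Mirrors the expression in TransposeTransform.batch above; standalone sketch.
def batch_permutation(permutation, leading_rank):
  return (*range(leading_rank), *(d + leading_rank for d in permutation))

# Batching past two leading (squeezed) dims leaves them fixed and shifts the rest.
assert batch_permutation((1, 0, 2), 2) == (0, 1, 3, 2, 4)
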

@@ -300,9 +300,7 @@ def build_kernel(
     with ir.InsertionPoint(if_compute.else_block):
       nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease)
       with single_thread(per_block=False):
-        k_tr = (
-            TileTransform(tiling), TransposeTransform((0, 2, 1, 3, 4)),
-        )
+        k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3)))
         v_tr = TileTransform(tiling)
         kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head))
         def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform):
@@ -396,10 +394,7 @@ def build_kernel(
       with single_thread(per_block=False):
         txcount = 2 * blocks.kv * head_dim * bytewidth(f16)
         barriers[slot].arrive_expect_tx(txcount)
-        k_tr = (
-            TileTransform(tiling),
-            TransposeTransform((0, 2, 1, 3, 4)),
-        )
+        k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3)))
         v_tr = TileTransform(tiling)
         for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)):
           ctx.async_copy(

@@ -1060,6 +1060,30 @@ class TMATest(TestCase):
     y = f(x)
     np.testing.assert_array_equal(y, x)
 
+  def test_tma_load_indexed_tiled(self):
+    shape = (128, 2, 128)
+    tiling = mgpu.TileTransform((32, 32))
+    def kernel(ctx, src, dst, scratch):
+      tmp, barrier = scratch
+      ctx.async_copy(
+          src_ref=src,
+          dst_ref=tmp,
+          barrier=barrier,
+          gmem_transform=tiling,
+          gmem_slice=(slice(None), 1, slice(None)),
+      )
+      barrier.wait()
+      ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_transform=tiling)
+      ctx.await_async_copy(0)
+    x = np.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
+    smem = (
+        jax.ShapeDtypeStruct((4, 4, 32, 32), jnp.float32),
+        mgpu.TMABarrier(),
+    )
+    out_shape = jax.ShapeDtypeStruct((128, 128), jnp.float32)
+    f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem)
+    np.testing.assert_array_equal(f(x), x[:, 1, :])
+
   @parameterized.product(
       swizzle=(None, 128),
       dtype=(jnp.float16, jnp.float32),