[Mosaic GPU] Add support for small RHS tile sizes in WGMMA

This is useful for more fine-grained autotuning and can help avoid
wave quantization effects.
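For context, the tests in this commit pick the RHS tile shape roughly as in the sketch below; it is only an illustration of what "small RHS tile sizes" means here (the helper name is hypothetical, and the concrete values come from the test changes further down, not from a public API):

def rhs_tile_shape(swizzle_bytes, element_bytes, rhs_transpose, small_rhs_tile):
  # One swizzle period covers this many elements along the contiguous dim.
  swizzle_elems = swizzle_bytes // element_bytes
  # The updated tests use an 8-row small tile for transposed RHS, 16 otherwise.
  small_nk_tile = 8 if rhs_transpose else 16
  return (
      (small_nk_tile, swizzle_elems)
      if small_rhs_tile
      else (swizzle_elems, swizzle_elems)
  )

# Example: swizzle=128 bytes, f16 elements (2 bytes), transposed RHS:
#   small tiles -> (8, 64), regular tiles -> (64, 64)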

PiperOrigin-RevId: 732105219
Adam Paszke 2025-02-28 05:40:39 -08:00 committed by jax authors
parent 1bc36e623b
commit bb96226dd8
4 changed files with 188 additions and 77 deletions

@@ -106,6 +106,7 @@ def mma(
raise ValueError(f"B must be a memref, got: {b.type}")
if a_swizzle != b_swizzle:
raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
swizzle = a_swizzle
if isinstance(accumulate, bool):
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
num_cta = 2 if collective else 1
@@ -114,11 +115,11 @@ def mma(
a_desc_base,
b_desc_base,
(m, k, n),
(m_mem_tiling, kn_mem_tiling),
(m_mem_tiling, n_mem_tiling),
element_type,
mma_params,
a_k_byte_stride,
b_k_byte_stride,
a_k_group_stride,
b_k_group_stride,
) = _wgmma._validate_mma(
a,
b,
@@ -127,7 +128,7 @@ def mma(
)
# The sizes of the instructions we'll be using
k_instr_tiling = kn_mem_tiling
k_instr_tiling = swizzle // utils.bytewidth(element_type)
if (m_instr_tiling := d.layout.elements_in_tile[0]) != m_mem_tiling:
raise ValueError(
f"A's row tiling must be equal to {m_instr_tiling} (inferred from"
@@ -143,14 +144,14 @@ def mma(
raise NotImplementedError("The only supported N larger than 256 is 512")
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)
a_m_group_stride = a_strides[0] * utils.bytewidth(element_type)
b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
b_n_byte_stride = b_strides[1] * utils.bytewidth(element_type)
b_n_group_stride = b_strides[1] * utils.bytewidth(element_type)
groups_k = k // k_instr_tiling
groups_m = m // m_instr_tiling
groups_n = n // n_instr_tiling
n_mem_tiles_per_instr = n_instr_tiling // kn_mem_tiling
n_mem_tiles_per_instr = n_instr_tiling // n_mem_tiling
# TODO(apaszke): Verify that the cluster shape matches the expectation of
# collective MMA.
@@ -162,9 +163,9 @@ def mma(
true = arith.constant(ir.IntegerType.get_signless(1), 1)
for mi, ni, ki in np.ndindex(groups_m, groups_n, groups_k):
a_offset = mi * a_m_byte_stride + ki * a_k_byte_stride
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64))
b_offset = ni * n_mem_tiles_per_instr * b_n_byte_stride + ki * b_k_byte_stride
b_offset = ni * n_mem_tiles_per_instr * b_n_group_stride + ki * b_k_group_stride
b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64))
if groups_m != 1:
raise NotImplementedError("D needs to be sliced")
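For readers following the loop above, a small illustrative replay of the offset arithmetic; the strides here are made-up placeholders (real values come from the memref strides checked in _validate_mma), and groups_m is 1 as the code requires:

import numpy as np

# Hypothetical group strides, in bytes, for illustration only.
a_m_group_stride, a_k_group_stride = 16384, 2048
b_n_group_stride, b_k_group_stride = 8192, 2048
n_mem_tiles_per_instr = 1
groups_m, groups_n, groups_k = 1, 2, 2
for mi, ni, ki in np.ndindex(groups_m, groups_n, groups_k):
  a_offset = mi * a_m_group_stride + ki * a_k_group_stride
  b_offset = ni * n_mem_tiles_per_instr * b_n_group_stride + ki * b_k_group_stride
  # Each offset is then encoded (wgmma_encode) and added to the corresponding
  # base SMEM descriptor before the MMA is issued.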

@@ -345,90 +345,152 @@ def _validate_mma(
if element_type not in supported_types:
raise ValueError(a_element_type)
element_bytewidth = bytewidth(element_type)
kn_tiling = swizzle // element_bytewidth
swizzle_elems = swizzle // element_bytewidth
# Verify the shape and strides of B are as expected.
k_tiles, n_tiles, k_tiling, n_tiling = b_ty.shape
if k_tiling != kn_tiling:
raise ValueError(b_ty.shape)
# Note that while this technically allows n to be smaller than kn_tile,
# the stride checks above will still enforce that the memory region is padded.
# It might be possible to relax that requirement, but I haven't tested it.
if n_tiling > kn_tiling and n_tiling % kn_tiling:
raise ValueError(n_tiling, kn_tiling)
k = k_tiles * kn_tiling
b_k_tiles, n_tiles, b_k_tiling, n_tiling = b_ty.shape
k = b_k_tiles * b_k_tiling
n = n_tiles * n_tiling
b_strides, _ = b_ty.get_strides_and_offset()
b_byte_strides = [s * element_bytewidth for s in b_strides]
b_k_byte_stride, b_n_byte_stride, *b_tile_byte_strides = b_byte_strides
# TODO(apaszke): Relax tiling here! But make sure that the space between N tiles is the same.
if b_byte_strides[1] != swizzle * kn_tiling:
raise ValueError(b_byte_strides)
if b_tile_byte_strides == [swizzle, element_bytewidth]:
if (
b_byte_strides[1] != n_tiling * b_k_tiling * element_bytewidth
and n_tiles != 1 # When there's only one tile, we never jump between them
):
raise ValueError("B tiles must be contiguous along the N dimension")
if b_tile_byte_strides == [swizzle, element_bytewidth]: # N-fastest
b_order = WGMMALayout.ROW_MAJOR
elif b_tile_byte_strides == [element_bytewidth, swizzle]:
# This first case (n_tiles == 1) is to allow the somewhat weird case of
# loading a small amount of N-fastest data that needs to be padded to a
# larger tile due to swizzle. In this case we allow slicing the big tile
# before WGMMA to avoid unnecessary compute on padding.
if n_tiles == 1:
if n_tiling % 8:
raise ValueError("N tile size must be a multiple of 8")
elif n_tiling != swizzle_elems:
raise ValueError(
"Row major RHS (N-fastest) requires the N tile size to be equal to"
f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}"
)
if b_k_tiling not in {32 // element_bytewidth, swizzle_elems}:
raise ValueError(
"Row major RHS (N-fastest) requires the K tile size to be either"
f" the swizzle tile size ({swizzle_elems}) or 32 bytes"
f" ({32 // element_bytewidth}), but got {b_k_tiling}"
)
elif b_tile_byte_strides == [element_bytewidth, swizzle]: # K-fastest
b_order = WGMMALayout.COL_MAJOR
if b_k_tiling != swizzle_elems:
raise ValueError(
"Column major RHS (K-fastest) requires the K tile size to be equal"
f" to the swizzle tile size ({swizzle_elems}), but got {b_k_tiling}"
)
# See the explanation in the N-fastest case when n_tiles == 1.
if n_tiles == 1:
if n_tiling % 8:
raise ValueError("N tile size must be a multiple of 8")
elif n_tiling not in {8, swizzle_elems}:
raise ValueError(
"Column major RHS (K-fastest) requires the N tile size to be either"
f" the swizzle tile size ({swizzle_elems}) or 8, but got {n_tiling}"
)
else:
raise ValueError(b_byte_strides)
# Verify the shape and strides of A are as expected.
if not a_in_smem:
m = a_shape[0]
a_order = m_tiling = a_m_byte_stride = None
a_order = m_tiling = None
else:
a_ty = ir.MemRefType(a.type)
m_tiles, k_tiles, m_tiling, k_tiling = a_ty.shape
m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape
m = m_tiles * m_tiling
if k_tiling != kn_tiling or k_tiles * k_tiling != k:
if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k:
raise ValueError(a_ty.shape)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_byte_strides = [s * element_bytewidth for s in a_strides]
a_m_byte_stride = a_byte_strides[0]
if a_byte_strides[2:] == [swizzle, element_bytewidth]:
a_order = WGMMALayout.ROW_MAJOR
elif a_byte_strides[2:] == [element_bytewidth, swizzle]:
a_order = WGMMALayout.COL_MAJOR
else:
raise ValueError(a_byte_strides)
if a_order != WGMMALayout.ROW_MAJOR and m_tiling != kn_tiling:
if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems:
# Not sure what the layout is like, since the tiles aren't square.
raise NotImplementedError
b_k_fastest = b_order == WGMMALayout.COL_MAJOR
a_k_fastest = a_order == WGMMALayout.ROW_MAJOR
# Here "leading" refers to the fastest changing dimension.
# Leading byte offset (LBO)
# K-fastest: ignored
# MN-fastest: stride between consecutive swizzle atoms that share the same K coordinate.
# Stride byte offset (SBO)
# K-fastest: offset from one swizzle atom to the next
# MN-fastest: offset from one swizzle atom to the next
# This is the number of rows until consecutive repeats of the swizzle pattern.
swizzle_pattern_rows = swizzle // 16
# A swizzle atom is a 2D matrix with the dimensions below.
swizzle_atom_bytes = swizzle_pattern_rows * 128
# Here "leading" refers to the fastest changing dimension. There are two
# strides we have to define per value:
# Leading byte offset (LBO)
# K-fastest: ignored
# MN-fastest: stride between consecutive swizzle atoms that share the same
# K coordinate.
# Stride byte offset (SBO)
# As far as I can tell this is just the offset between two consecutive
# swizzle atoms along the non-leading dimension.
IGNORED = 0
a_desc_fields = dict(
# TODO(apaszke): a_m_byte_stride works, but is not convincing to me.
# After all MMA can only consume a fixed number of bytes from LHS.
leading_byte_offset=16 if a_k_fastest else a_m_byte_stride,
stride_byte_offset=128 * swizzle_pattern_rows,
# I can't fully explain why WGMMA ignores LBO for A. For a_k_fastest, it
# is documented in the PTX docs, and my best explanation for the other
# case is that the instruction has a fixed shape and so it does not care
# about strides. It's possible that it's an artifact of the fact that we
# use tiling of 64.
leading_byte_offset=IGNORED,
stride_byte_offset=swizzle_atom_bytes,
swizzle=swizzle,
memory_space=3,
)
# If B is N-fastest, all swizzle atoms within a tile share the same N
# coordinate, so we simply take the stride between consecutive N tiles.
# If B is K-fastest, all swizzle atoms within a tile share the same K
# coordinate, which forces us to lay out the tiles in N-fastest order or else
# they would have uneven strides.
b_desc_fields = dict(
leading_byte_offset=16 if b_k_fastest else b_n_byte_stride,
stride_byte_offset=128 * swizzle_pattern_rows,
leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride,
stride_byte_offset=swizzle_atom_bytes,
swizzle=swizzle,
memory_space=3,
)
# The K strides indicate the stride between the consecutive places where all
# coordinates are 0 except for K being incremented by the instruction width.
# If an input is K-fastest, we increment the descriptor by 32 bytes, since
# that is the K-width of all MMA instructions.
# TODO(apaszke): I don't have a good explanation for the MN-fastest case yet.
if b_k_fastest:
b_k_wgmma_stride = 32
b_k_group_stride = b_k_byte_stride # The tile has only one K swizzle atom.
elif b_k_tiling == swizzle_elems:
# When B is N-fastest and we use the large square tiling, the relevant
# slices all fall within the first tile. A single MMA instruction for 16-bit
# types reads a subtile of shape 16x(swizzle bytes), giving us the necessary
# expression.
assert n_tiling == swizzle_elems or n_tiles == 1
b_k_wgmma_stride = swizzle * 16
b_k_group_stride = b_k_byte_stride
else:
# If we use the small non-square tiling and N-fastest layout, each tile only
# contains a single swizzle atom along the K dimension, so stepping K just
# means looking up the next tile.
b_k_wgmma_stride = b_k_byte_stride
wgmma_in_group = swizzle // 32
b_k_group_stride = b_k_byte_stride * wgmma_in_group
wgmma_params = dict(
a_transpose=not a_k_fastest,
b_transpose=not b_k_fastest,
# TODO(apaszke): This explanation is quite bad. We should figure out a better
# way to handle LHS transposes.
# We only support swizzle=128 for M-fastest A. In this case the tile is
# swizzle x 64 (= swizzle elems) and so we just take a quarter of its size.
a_k_stride=32 if a_k_fastest else swizzle * 16,
b_k_stride=32 if b_k_fastest else swizzle * 16,
b_k_stride=b_k_wgmma_stride,
swizzle=swizzle,
element_type=ir.FloatTF32Type.get()
if ir.F32Type.isinstance(element_type)
@@ -436,13 +498,13 @@ def _validate_mma(
)
if not a_in_smem:
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
a_k_byte_stride = a_desc_base = None
a_k_group_stride = a_desc_base = None
else:
a_desc_base = create_descriptor(
a, **a_desc_fields, const_init=descriptor_const_init
)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_k_byte_stride = a_strides[1] * element_bytewidth
a_k_group_stride = a_strides[1] * element_bytewidth
b_desc_base = create_descriptor(
b, **b_desc_fields, const_init=descriptor_const_init
)
@@ -451,11 +513,11 @@ def _validate_mma(
a_desc_base,
b_desc_base,
(m, k, n),
(m_tiling, kn_tiling),
(m_tiling, n_tiling),
element_type,
wgmma_params,
a_k_byte_stride,
b_k_byte_stride,
a_k_group_stride,
b_k_group_stride,
)
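As a concrete check on the descriptor and stride arithmetic introduced above, here is an illustrative replay of the formulas for one configuration (f16 inputs, swizzle=128, N-fastest RHS with the new small K tiling); the numbers are an example, not additional behavior:

element_bytewidth = 2                             # f16
swizzle = 128
swizzle_elems = swizzle // element_bytewidth      # 64 elements per swizzle period
swizzle_pattern_rows = swizzle // 16              # 8 rows before the pattern repeats
swizzle_atom_bytes = swizzle_pattern_rows * 128   # 1024-byte swizzle atom (used as the SBO)

# Small N-fastest K tiling: each tile covers 32 bytes of K.
b_k_tiling = 32 // element_bytewidth              # 16 elements
wgmma_in_group = swizzle // 32                    # 4 WGMMA K-steps per K group
# In this branch a single WGMMA K-step advances by a whole K tile
# (b_k_wgmma_stride equals the stride between K tiles), and a K group spans
# wgmma_in_group tiles, so b_k_group_stride = b_k_byte_stride * wgmma_in_group.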
@@ -488,11 +550,11 @@ def wgmma(
a_desc_base,
b_desc_base,
(m, k, n),
(m_tiling, kn_tiling),
(m_tiling, _),
element_type,
wgmma_params,
a_k_byte_stride,
b_k_byte_stride,
a_k_group_stride,
b_k_group_stride,
) = _validate_mma(a, b, swizzle)
if n > 256:
@@ -505,14 +567,15 @@ def wgmma(
)
if a.shape[0] % 64:
raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
a_m_byte_stride = None
a_m_group_stride = None
else:
if m_tiling != 64:
raise ValueError(f"A must have rows tiled by 64, got: {m_tiling}")
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_byte_stride = a_strides[0] * bytewidth(element_type)
a_m_group_stride = a_strides[0] * bytewidth(element_type)
groups_k = k // kn_tiling
k_group_width = swizzle // bytewidth(element_type)
groups_k = k // k_group_width
groups_m = m // 64
expected_acc_shape = (groups_m * 64, n)
@@ -530,13 +593,16 @@
for mi in range(groups_m):
for ki in range(groups_k):
if a_in_regs:
a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tiling : (ki + 1) * kn_tiling]
a_mk = a[
mi * 64 : (mi + 1) * 64,
ki * k_group_width : (ki + 1) * k_group_width
]
else:
a_mk = llvm_add(
a_desc_base,
c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
c(wgmma_encode(mi * a_m_group_stride + ki * a_k_group_stride), i64),
)
b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64))
new_acc_regs[mi : mi + 1] = wgmma_m64(
new_acc_regs[mi : mi + 1], a_mk, b_k, n=n, **wgmma_params
)
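To make the K grouping in the loop above concrete, a small hedged example (f16 elements, swizzle=64, k=128) that mirrors the expressions in the hunk:

element_bytes = 2                          # f16
swizzle, k = 64, 128
k_group_width = swizzle // element_bytes   # 32 elements of K per group
groups_k = k // k_group_width              # 4 iterations of the ki loop
# With A in registers, iteration ki consumes the slice
#   a[mi * 64 : (mi + 1) * 64, ki * 32 : (ki + 1) * 32]
# while for A in SMEM the descriptor is advanced by ki * a_k_group_stride bytes.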

@@ -38,7 +38,7 @@ jax_multiplatform_test(
"gpu_h100x2",
],
env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
shard_count = 8,
shard_count = 16,
tags = ["multiaccelerator"],
deps = [
"//jax:mosaic_gpu",

@@ -652,6 +652,7 @@ class WGMMATest(TestCase):
k_steps=(1, 2),
swizzle=(32, 64, 128),
jax_out_dtype=(jnp.float16, jnp.float32),
small_rhs_tile=(False, True,),
)
def test_wgmma_basic(
self,
@@ -663,6 +664,7 @@ class WGMMATest(TestCase):
rhs_transpose,
swizzle,
jax_out_dtype,
small_rhs_tile,
):
if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type:
raise self.skipTest("Only f16 input is supported for f16 output.")
@@ -693,13 +695,18 @@ class WGMMATest(TestCase):
k = nk_tile * k_steps
assert m % 64 == 0 and n % nk_tile == 0
small_nk_tile = 8 if rhs_transpose else 16
rhs_tiling = (
(small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
)
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers = scratch
lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
if lhs_transpose:
assert nk_tile == 64 # Make sure we didn't have to transpose tiling.
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
rhs_transform = (mgpu.TileTransform(rhs_tiling),)
if rhs_transpose:
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
ctx.async_copy(
@@ -737,10 +744,14 @@ class WGMMATest(TestCase):
y_shape = (n, k) if rhs_transpose else (k, n)
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype)
rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
scratch_shape = [
jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype),
jax.ShapeDtypeStruct(
(k // nk_tile, n // nk_tile, nk_tile, nk_tile), in_jax_dtype
(m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype
),
jax.ShapeDtypeStruct(
(k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling),
in_jax_dtype,
),
mgpu.TMABarrier(2),
]
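Since the scratch-shape expression above is easy to misread, here is an illustrative evaluation with concrete numbers (swizzle=128, f16, k_steps=2, n=128, transposed RHS with small tiles); it simply replays the test's arithmetic:

swizzle, element_bytes = 128, 2
nk_tile = swizzle // element_bytes                 # 64
k_steps, n = 2, 128
k = nk_tile * k_steps                              # 128
rhs_transpose, small_rhs_tile = True, True
small_nk_tile = 8 if rhs_transpose else 16
rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
# With rhs_transpose the GMEM array is (n, k) and the tiled grid is transposed
# afterwards, so dividing the logical (k, n) shape uses the flipped tiling.
rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling)
# -> (128 // 64, 128 // 8, 8, 64) == (2, 16, 8, 64)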
@@ -812,18 +823,23 @@ class WGMMATest(TestCase):
@parameterized.product(
rhs_transpose=(False, True),
swizzle=(32, 64, 128),
n=(8, 16),
small_rhs_tile=(False, True),
)
def test_narrow_n(self, rhs_transpose, swizzle):
m, n, k_steps = 64, 8, 2
def test_narrow_n(self, rhs_transpose, swizzle, n, small_rhs_tile):
m, k_steps = 64, 2
bytewidth = 2
nk_tile = swizzle // bytewidth
k = nk_tile * k_steps
if small_rhs_tile and not rhs_transpose:
self.skipTest("Small tiles only supported for transposed RHS")
n_tile = 8 if small_rhs_tile else nk_tile
def kernel(ctx, rhs, out, smem):
rhs_smem, barrier = smem
gmem_slice = (ds(0, k), ds(0, nk_tile))
transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
gmem_slice = (ds(0, k), ds(0, max(n_tile, n)))
transform = (mgpu.TileTransform((n_tile, nk_tile)),)
if rhs_transpose:
gmem_slice = gmem_slice[::-1]
transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
@@ -840,8 +856,9 @@ class WGMMATest(TestCase):
lhs_regs = iota_tensor(m, k, jnp.float16)
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
rhs_smem = memref_slice(rhs_smem, smem_slice)
if not small_rhs_tile:
smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
rhs_smem = memref_slice(rhs_smem, smem_slice)
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
@@ -852,7 +869,7 @@ class WGMMATest(TestCase):
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
rhs_scratch_shape = jax.ShapeDtypeStruct(
(k_steps, 1, nk_tile, nk_tile), jax_dtype
(k_steps, (n + n_tile - 1) // n_tile, n_tile, nk_tile), jax_dtype
)
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()),
@@ -881,6 +898,7 @@ class TCGen05Test(TestCase):
n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2
k_steps=(1, 2),
swizzle=(32, 64, 128,),
small_rhs_tile=(False, True),
)
def test_mma_basic(
self,
@@ -892,6 +910,7 @@ class TCGen05Test(TestCase):
rhs_transpose,
in_jax_dtype,
out_jax_dtype,
small_rhs_tile,
):
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
raise self.skipTest("Only f16 input is supported for f16 output.")
@@ -902,13 +921,16 @@ class TCGen05Test(TestCase):
k = nk_tile * k_steps
assert m % m_tile == 0 and n % nk_tile == 0
small_nk_tile = 8 if rhs_transpose else 16
rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers, acc = scratch
lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),)
if lhs_transpose:
assert nk_tile == m_tile # Make sure we didn't have to transpose tiling
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
rhs_transform = (mgpu.TileTransform(rhs_tiling),)
if rhs_transpose:
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
ctx.async_copy(
@@ -950,9 +972,15 @@ class TCGen05Test(TestCase):
y_shape = (n, k) if rhs_transpose else (k, n)
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
scratch_shape = [
jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(
tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype
),
jax.ShapeDtypeStruct(
(k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling),
in_jax_dtype,
),
mgpu.TMABarrier(3),
mgpu.TMEM((128, n), out_jax_dtype),
]
@@ -973,6 +1001,7 @@ class TCGen05Test(TestCase):
n=(128, 256), # TODO(apaszke): 512, 192, other non-power-of-2
k_steps=(1, 2),
swizzle=(32, 64, 128,),
small_rhs_tile=(False, True),
)
def test_mma_collective(
self,
@@ -984,6 +1013,7 @@ class TCGen05Test(TestCase):
rhs_transpose,
in_jax_dtype,
out_jax_dtype,
small_rhs_tile,
):
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
raise self.skipTest("Only f16 input is supported for f16 output.")
@@ -997,13 +1027,20 @@ class TCGen05Test(TestCase):
assert m % m_tma_tile == 0 and n % nk_tma_tile == 0
index = ir.IndexType.get()
small_nk_tile = 8 if rhs_transpose else 16
rhs_tiling = (
(small_nk_tile, nk_tma_tile)
if small_rhs_tile
else (nk_tma_tile, nk_tma_tile)
)
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers, acc = scratch
lhs_transform = (mgpu.TileTransform((m_tma_tile, nk_tma_tile)),)
if lhs_transpose:
assert nk_tma_tile == m_tma_tile # Make sure we didn't have to transpose tiling
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
rhs_transform = (mgpu.TileTransform((nk_tma_tile, nk_tma_tile)),)
rhs_transform = (mgpu.TileTransform(rhs_tiling),)
if rhs_transpose:
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
block_id = gpu.cluster_block_id(gpu.Dimension.x)
@@ -1053,9 +1090,16 @@ class TCGen05Test(TestCase):
y_shape = (n, k) if rhs_transpose else (k, n)
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
scratch_shape = [
jax.ShapeDtypeStruct(tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape((k, n_block_tile), (nk_tma_tile, nk_tma_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(
tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)),
in_jax_dtype,
),
jax.ShapeDtypeStruct(
(k // rhs_tiling_t[0], n_block_tile // rhs_tiling_t[1], *rhs_tiling),
in_jax_dtype,
),
mgpu.TMABarrier(3),
mgpu.TMEM((128, n), out_jax_dtype, collective=True),
]