[Mosaic GPU] Make the small WGMMA tile independent of transpose flags

Now the small tiling is always `(8, swizzle // bytewidth(dtype))`, no matter whether the input
is transposed or not. This should simply the follow-up refactoring of the code and make it easier
to enable small tiling for LHS too.

PiperOrigin-RevId: 732933005
This commit is contained in:
Adam Paszke 2025-03-03 08:28:49 -08:00 committed by jax authors
parent ed4a7bbab1
commit e9f95cc3a7
2 changed files with 12 additions and 9 deletions

View File

@ -375,11 +375,10 @@ def _validate_mma(
"Row major RHS (N-fastest) requires the N tile size to be equal to"
f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}"
)
if b_k_tiling not in {32 // element_bytewidth, swizzle_elems}:
if b_k_tiling not in {8, swizzle_elems}:
raise ValueError(
"Row major RHS (N-fastest) requires the K tile size to be either"
f" the swizzle tile size ({swizzle_elems}) or 32 bytes"
f" ({32 // element_bytewidth}), but got {b_k_tiling}"
f" the swizzle tile size ({swizzle_elems}) or 8, but got {b_k_tiling}"
)
elif b_tile_byte_strides == [element_bytewidth, swizzle]: # K-fastest
b_order = WGMMALayout.COL_MAJOR
@ -479,7 +478,11 @@ def _validate_mma(
# they would have uneven strides.
b_desc_fields = dict(
leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride,
stride_byte_offset=swizzle_atom_bytes,
# N tiles are contiguous, so the next N swizzle atom follows immediately.
# K tiles are not contiguous, so we take the stride between them.
stride_byte_offset=swizzle_atom_bytes
if b_k_fastest or b_k_tiling == swizzle_elems
else b_k_byte_stride,
swizzle=swizzle,
memory_space=3,
)
@ -498,9 +501,9 @@ def _validate_mma(
b_k_wgmma_stride = swizzle * 16
else:
# If we use the small non-square tiling and N-fastest layout, each tile only
# contains a single swizzle atom with the K coordinate, so we just look up
# the next tile.
b_k_wgmma_stride = b_k_byte_stride
# contains a single swizzle atom with the K coordinate. But, each tile has
# 8 rows, while the WGMMA K width is 16, so we need to jump over 2 tiles.
b_k_wgmma_stride = b_k_byte_stride * 2
wgmma_params = dict(
a_transpose=not a_k_fastest,
b_transpose=not b_k_fastest,

View File

@ -695,7 +695,7 @@ class WGMMATest(TestCase):
k = nk_tile * k_steps
assert m % 64 == 0 and n % nk_tile == 0
small_nk_tile = 8 if rhs_transpose else 16
small_nk_tile = 8
rhs_tiling = (
(small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
)
@ -921,7 +921,7 @@ class TCGen05Test(TestCase):
k = nk_tile * k_steps
assert m % m_tile == 0 and n % nk_tile == 0
small_nk_tile = 8 if rhs_transpose else 16
small_nk_tile = 8
rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
def kernel(ctx, lhs, rhs, out, scratch):