mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Make the small WGMMA tile independent of transpose flags
Now the small tiling is always `(8, swizzle // bytewidth(dtype))`, no matter whether the input is transposed or not. This should simplify the follow-up refactoring of the code and make it easier to enable small tiling for LHS too. PiperOrigin-RevId: 732933005
This commit is contained in:
parent
ed4a7bbab1
commit
e9f95cc3a7
@ -375,11 +375,10 @@ def _validate_mma(
|
||||
"Row major RHS (N-fastest) requires the N tile size to be equal to"
|
||||
f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}"
|
||||
)
|
||||
if b_k_tiling not in {32 // element_bytewidth, swizzle_elems}:
|
||||
if b_k_tiling not in {8, swizzle_elems}:
|
||||
raise ValueError(
|
||||
"Row major RHS (N-fastest) requires the K tile size to be either"
|
||||
f" the swizzle tile size ({swizzle_elems}) or 32 bytes"
|
||||
f" ({32 // element_bytewidth}), but got {b_k_tiling}"
|
||||
f" the swizzle tile size ({swizzle_elems}) or 8, but got {b_k_tiling}"
|
||||
)
|
||||
elif b_tile_byte_strides == [element_bytewidth, swizzle]: # K-fastest
|
||||
b_order = WGMMALayout.COL_MAJOR
|
||||
@ -479,7 +478,11 @@ def _validate_mma(
|
||||
# they would have uneven strides.
|
||||
b_desc_fields = dict(
|
||||
leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride,
|
||||
stride_byte_offset=swizzle_atom_bytes,
|
||||
# N tiles are contiguous, so the next N swizzle atom follows immediately.
|
||||
# K tiles are not contiguous, so we take the stride between them.
|
||||
stride_byte_offset=swizzle_atom_bytes
|
||||
if b_k_fastest or b_k_tiling == swizzle_elems
|
||||
else b_k_byte_stride,
|
||||
swizzle=swizzle,
|
||||
memory_space=3,
|
||||
)
|
||||
@ -498,9 +501,9 @@ def _validate_mma(
|
||||
b_k_wgmma_stride = swizzle * 16
|
||||
else:
|
||||
# If we use the small non-square tiling and N-fastest layout, each tile only
|
||||
# contains a single swizzle atom with the K coordinate, so we just look up
|
||||
# the next tile.
|
||||
b_k_wgmma_stride = b_k_byte_stride
|
||||
# contains a single swizzle atom with the K coordinate. But, each tile has
|
||||
# 8 rows, while the WGMMA K width is 16, so we need to jump over 2 tiles.
|
||||
b_k_wgmma_stride = b_k_byte_stride * 2
|
||||
wgmma_params = dict(
|
||||
a_transpose=not a_k_fastest,
|
||||
b_transpose=not b_k_fastest,
|
||||
|
@ -695,7 +695,7 @@ class WGMMATest(TestCase):
|
||||
k = nk_tile * k_steps
|
||||
assert m % 64 == 0 and n % nk_tile == 0
|
||||
|
||||
small_nk_tile = 8 if rhs_transpose else 16
|
||||
small_nk_tile = 8
|
||||
rhs_tiling = (
|
||||
(small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
|
||||
)
|
||||
@ -921,7 +921,7 @@ class TCGen05Test(TestCase):
|
||||
k = nk_tile * k_steps
|
||||
assert m % m_tile == 0 and n % nk_tile == 0
|
||||
|
||||
small_nk_tile = 8 if rhs_transpose else 16
|
||||
small_nk_tile = 8
|
||||
rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
|
||||
|
||||
def kernel(ctx, lhs, rhs, out, scratch):
|
||||
|
Loading…
x
Reference in New Issue
Block a user