[Mosaic GPU] Add support for tiled swizzle=16 (i.e. no swizzle) loads and stores

The tiling still makes it possible to do it without bank conflicts.

PiperOrigin-RevId: 721701635
This commit is contained in:
Adam Paszke 2025-01-31 02:49:27 -08:00 committed by jax authors
parent efacec4cfb
commit 10ac6b7e12
2 changed files with 8 additions and 3 deletions

View File

@ -1695,7 +1695,10 @@ class FragmentedArray:
raise ValueError("Memory tiling must be a multiple of the register tiling")
ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):]
if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)):
raise ValueError("Memory tiling must be a multiple of the register tiling")
raise ValueError(
f"Memory tiling ({ref_tiling_suffix}) must be a multiple of the"
f" register tiling ({layout.base_tile_shape})"
)
elem_tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
@ -1728,7 +1731,7 @@ class FragmentedArray:
" vector dimension"
)
if swizzle not in {32, 64, 128}:
if swizzle not in {16, 32, 64, 128}:
raise ValueError("Only swizzled transfers supported")
# We will be computing the offsets in units of vectors, not elements,
# to better support sub-byte types.

View File

@ -1884,13 +1884,15 @@ class LayoutTest(TestCase):
load_tiled=[False, True],
store_tiled=[False, True],
dtype=[jnp.int8, jnp.int16, jnp.int32],
swizzle=[32, 64, 128],
swizzle=[16, 32, 64, 128],
num_col_tiles=[1, 2, 3],
)
def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles):
mlir_dtype = utils.dtype_to_ir_type(dtype)
bw = bytewidth(mlir_dtype)
col_tiling = swizzle // bw
if col_tiling % 8:
self.skipTest("WGMMA layout requires col_tiling % 8 == 0")
m, n = 128, col_tiling * num_col_tiles
tiling = (64, col_tiling)
tiled_layout = fa._tiled_wgmma_layout((m, n))