mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Add support for tiled swizzle=16 (i.e. no swizzle) loads and stores
The tiling still makes it possible to do it without bank conflicts. PiperOrigin-RevId: 721701635
This commit is contained in:
parent
efacec4cfb
commit
10ac6b7e12
@ -1695,7 +1695,10 @@ class FragmentedArray:
|
||||
raise ValueError("Memory tiling must be a multiple of the register tiling")
|
||||
ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):]
|
||||
if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)):
|
||||
raise ValueError("Memory tiling must be a multiple of the register tiling")
|
||||
raise ValueError(
|
||||
f"Memory tiling ({ref_tiling_suffix}) must be a multiple of the"
|
||||
f" register tiling ({layout.base_tile_shape})"
|
||||
)
|
||||
|
||||
elem_tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
|
||||
tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
|
||||
@ -1728,7 +1731,7 @@ class FragmentedArray:
|
||||
" vector dimension"
|
||||
)
|
||||
|
||||
if swizzle not in {32, 64, 128}:
|
||||
if swizzle not in {16, 32, 64, 128}:
|
||||
raise ValueError("Only swizzled transfers supported")
|
||||
# We will be computing the offsets in units of vectors, not elements,
|
||||
# to better support sub-byte types.
|
||||
|
@ -1884,13 +1884,15 @@ class LayoutTest(TestCase):
|
||||
load_tiled=[False, True],
|
||||
store_tiled=[False, True],
|
||||
dtype=[jnp.int8, jnp.int16, jnp.int32],
|
||||
swizzle=[32, 64, 128],
|
||||
swizzle=[16, 32, 64, 128],
|
||||
num_col_tiles=[1, 2, 3],
|
||||
)
|
||||
def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles):
|
||||
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
||||
bw = bytewidth(mlir_dtype)
|
||||
col_tiling = swizzle // bw
|
||||
if col_tiling % 8:
|
||||
self.skipTest("WGMMA layout requires col_tiling % 8 == 0")
|
||||
m, n = 128, col_tiling * num_col_tiles
|
||||
tiling = (64, col_tiling)
|
||||
tiled_layout = fa._tiled_wgmma_layout((m, n))
|
||||
|
Loading…
x
Reference in New Issue
Block a user