[mgpu] Non-vector untiled stores for tiled layouts.

Useful for storing into memrefs whose minormost stride is greater than 1.

PiperOrigin-RevId: 733551038
This commit is contained in:
Christos Perivolaropoulos 2025-03-04 19:40:14 -08:00 committed by jax authors
parent 766315f791
commit 51719a1afe
2 changed files with 38 additions and 10 deletions

View File

@ -1441,19 +1441,27 @@ class FragmentedArray:
if create_array:
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
def store_untiled(self, ref: ir.Value):
def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
if not ir.MemRefType.isinstance(ref.type):
raise ValueError(ref)
def vs_unsupported():
if not vector_store:
raise NotImplementedError(
f"Can't use non-vector stores with layout {self.layout}"
)
match self.layout:
case WGMMAFragLayout():
self._store_untiled_wgmma(ref)
case WGSplatFragLayout():
vs_unsupported()
self._store_untiled_splat(ref)
case WGStridedFragLayout():
vs_unsupported()
self._store_untiled_wg_strided(ref)
case TiledLayout():
self._store_untiled_tiled(ref)
self._store_untiled_tiled(ref, vector_store=vector_store)
case _:
raise NotImplementedError(self.layout)
@ -1520,7 +1528,7 @@ class FragmentedArray:
col = arith.addi(col_base, c(col_tile * 8 + col_idx))
memref.store(value, ref, [row, col])
def _store_untiled_tiled(self, ref: ir.Value):
def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
"""Stores an array with a tiled layout. Not optimized at the moment."""
if utils.bitwidth(self.mlir_dtype) < 8:
raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})")
@ -1528,7 +1536,7 @@ class FragmentedArray:
layout = self.layout
assert isinstance(layout, TiledLayout)
ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
if ref_strides[layout.vector_dim] != 1:
if vector_store and ref_strides[layout.vector_dim] != 1:
raise NotImplementedError(
"Can't use vector stores with non-unit minormost stride"
)
@ -1549,9 +1557,21 @@ class FragmentedArray:
ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
# All warp tile offsets are static and can be fused into the store.
for tile_idx, reg in np.ndenumerate(self.registers):
lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True))
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
llvm.store(reg, reg_ptr)
if vector_store:
elems = [reg]
else:
index = ir.IndexType.get()
elems = [
vector.extractelement(reg, position=c(i, index))
for i in range(ir.VectorType(reg.type).shape[0])
]
for i, e in enumerate(elems):
tile_idx_local = list(tile_idx)
tile_idx_local[layout.vector_dim] += i
tile_idx_local = list(tile_idx_local)
lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True))
reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
llvm.store(e, reg_ptr)
def store_tiled(self, ref, swizzle: int | None):
match self.layout:

View File

@ -485,12 +485,20 @@ def get_packed_shape(strides, shape):
class WGMMALayoutTest(TestCase):
@parameterized.named_parameters(("f32", jnp.float32), ("f16", jnp.float16))
def test_store_untiled(self, dtype):
@parameterized.product(dtype=[jnp.float16, jnp.float32],
tiled_layout=[False, True],
transposed_smem=[False, True])
def test_store_untiled(self, dtype, tiled_layout, transposed_smem):
def kernel(ctx, out, _):
del ctx
iota_tensor(64, 64, dtype).store_untiled(out)
if transposed_smem:
out = memref_transpose(out, (1, 0))
iota_tensor(64, 64, dtype, tiled_layout=tiled_layout).store_untiled(
out, vector_store=not transposed_smem
)
expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64)
if transposed_smem:
expected = expected.T
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
)()