[mgpu] Non-vector untiled stores for tiling layouts.
Useful for storing in memrefs where the minormost stride is >1.

PiperOrigin-RevId: 733551038
parent 766315f791
commit 51719a1afe
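For illustration, here is a minimal sketch of how a caller would use the new vector_store flag, modeled on the test updated in this commit. It is not part of the commit itself, and it assumes the iota_tensor and memref_transpose helpers from the Mosaic GPU test suite are in scope.

import numpy as np
import jax.numpy as jnp
import jax.experimental.mosaic.gpu as mgpu

# Sketch only (assumed test helpers): store a tiled-layout array through a
# transposed view of the output ref. The transpose makes the minormost stride
# non-unit, so the vectorized store path cannot be used and we request
# element-wise stores instead.
def kernel(ctx, out, _):
  del ctx
  out = memref_transpose(out, (1, 0))  # minormost stride is no longer 1
  iota_tensor(64, 64, jnp.float32, tiled_layout=True).store_untiled(
      out, vector_store=False  # lowers to per-element llvm.store ops
  )

expected = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64).T
result = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), expected, ())()
np.testing.assert_array_equal(result, expected)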
@@ -1441,19 +1441,27 @@ class FragmentedArray:
     if create_array:
       return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
 
-  def store_untiled(self, ref: ir.Value):
+  def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
     if not ir.MemRefType.isinstance(ref.type):
       raise ValueError(ref)
 
+    def vs_unsupported():
+      if not vector_store:
+        raise NotImplementedError(
+            f"Can't use non-vector stores with layout {self.layout}"
+        )
+
     match self.layout:
       case WGMMAFragLayout():
         self._store_untiled_wgmma(ref)
       case WGSplatFragLayout():
+        vs_unsupported()
         self._store_untiled_splat(ref)
       case WGStridedFragLayout():
+        vs_unsupported()
         self._store_untiled_wg_strided(ref)
       case TiledLayout():
-        self._store_untiled_tiled(ref)
+        self._store_untiled_tiled(ref, vector_store=vector_store)
       case _:
         raise NotImplementedError(self.layout)
 
@@ -1520,7 +1528,7 @@ class FragmentedArray:
         col = arith.addi(col_base, c(col_tile * 8 + col_idx))
         memref.store(value, ref, [row, col])
 
-  def _store_untiled_tiled(self, ref: ir.Value):
+  def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
     """Stores an array with a tiled layout. Not optimized at the moment."""
     if utils.bitwidth(self.mlir_dtype) < 8:
       raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})")
@@ -1528,7 +1536,7 @@ class FragmentedArray:
     layout = self.layout
     assert isinstance(layout, TiledLayout)
     ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset()
-    if ref_strides[layout.vector_dim] != 1:
+    if vector_store and ref_strides[layout.vector_dim] != 1:
       raise NotImplementedError(
           "Can't use vector stores with non-unit minormost stride"
       )
@@ -1549,9 +1557,21 @@ class FragmentedArray:
     ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
     # All warp tile offsets are static and can be fused into the store.
     for tile_idx, reg in np.ndenumerate(self.registers):
-      lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True))
-      reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
-      llvm.store(reg, reg_ptr)
+      if vector_store:
+        elems = [reg]
+      else:
+        index = ir.IndexType.get()
+        elems = [
+            vector.extractelement(reg, position=c(i, index))
+            for i in range(ir.VectorType(reg.type).shape[0])
+        ]
+      for i, e in enumerate(elems):
+        tile_idx_local = list(tile_idx)
+        tile_idx_local[layout.vector_dim] += i
+        tile_idx_local = list(tile_idx_local)
+        lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True))
+        reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype)
+        llvm.store(e, reg_ptr)
 
   def store_tiled(self, ref, swizzle: int | None):
     match self.layout:
@@ -485,12 +485,20 @@ def get_packed_shape(strides, shape):
 
 class WGMMALayoutTest(TestCase):
 
-  @parameterized.named_parameters(("f32", jnp.float32), ("f16", jnp.float16))
-  def test_store_untiled(self, dtype):
+  @parameterized.product(dtype=[jnp.float16, jnp.float32],
+                         tiled_layout=[False, True],
+                         transposed_smem=[False, True])
+  def test_store_untiled(self, dtype, tiled_layout, transposed_smem):
     def kernel(ctx, out, _):
       del ctx
-      iota_tensor(64, 64, dtype).store_untiled(out)
+      if transposed_smem:
+        out = memref_transpose(out, (1, 0))
+      iota_tensor(64, 64, dtype, tiled_layout=tiled_layout).store_untiled(
+          out, vector_store=not transposed_smem
+      )
     expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64)
+    if transposed_smem:
+      expected = expected.T
     iota = mgpu.as_gpu_kernel(
         kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
     )()