[Mosaic GPU] Support reads/writes from SMEM to WGMMARowFragLayout arrays.

PiperOrigin-RevId: 738121106
Gleb Pobudzey 2025-03-18 13:22:10 -07:00 committed by jax authors
parent 76d9890bb7
commit 54691b125a
2 changed files with 74 additions and 1 deletion


@@ -387,7 +387,21 @@ class WGMMARowFragLayout:
   """[m] matrix, where m % 64 == 0."""
 
   def thread_idxs(self, shape):
-    raise NotImplementedError
+    index = ir.IndexType.get()
+    assert len(shape) == 1
+    assert shape[0] % 64 == 0
+    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
+    tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
+    warp_idx = arith.divui(tid_wg, c(32, index))
+    lane_id = arith.remui(tid_wg, c(32, index))
+    row_base = arith.addi(
+        arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index))
+    )
+
+    for row_group in range(0, shape[0], 64):
+      for row_subgroup in (0, 8):
+        row = arith.addi(row_base, c(row_group + row_subgroup, index))
+        yield (row,)
 
 
 @dataclasses.dataclass(frozen=True)
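The index arithmetic above encodes the WGMMA row fragment layout: each of the four warps in a 128-thread warpgroup owns a 16-row slab, every group of 4 consecutive lanes shares one row, and each thread also covers the position 8 rows below its base row within every 64-row block. A minimal pure-Python sketch of the same mapping (illustrative only, assuming WARPGROUP_SIZE == 128; not part of the Mosaic GPU API):

  def wgmma_row_indices(tid: int, m: int) -> list[int]:
    # Mirrors thread_idxs above, with plain integers instead of MLIR index values.
    assert m % 64 == 0
    tid_wg = tid % 128       # thread index within the 128-thread warpgroup
    warp_idx = tid_wg // 32  # which of the 4 warps
    lane_id = tid_wg % 32    # lane within the warp
    row_base = lane_id // 4 + warp_idx * 16
    return [
        row_base + row_group + row_subgroup
        for row_group in range(0, m, 64)
        for row_subgroup in (0, 8)
    ]

  # Thread 0 of a 128-row array holds rows 0, 8, 64 and 72.
  assert wgmma_row_indices(0, 128) == [0, 8, 64, 72]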
@@ -660,6 +674,31 @@ class FragmentedArray:
     vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
     return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)
 
+  @classmethod
+  def load_wgmma_row(
+      cls,
+      ref: ir.Value,
+      *,
+      is_signed: bool | None = None,
+  ):
+    if not ir.MemRefType.isinstance(ref.type):
+      raise TypeError(ref.type)
+
+    ref_ty = ir.MemRefType(ref.type)
+    shape = tuple(ref_ty.shape)
+    if len(shape) != 1:
+      raise ValueError("WGMMARowFragLayout requires a 1D shape")
+    if shape[0] % 64:
+      raise ValueError(
+          "WGMMARowFragLayout requires shape[0] to be a multiple of 64"
+      )
+
+    layout = WGMMARowFragLayout()
+    registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)]
+    registers = np.array(registers).reshape(-1, 2)
+    return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
+
   @classmethod
   def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
     layout = layout or WGSplatFragLayout(shape)
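Since thread_idxs yields two rows per 64-row block (the +0 and +8 sub-rows), the reshape(-1, 2) above leaves each thread with one register row per 64-row block and one column per sub-row. A small NumPy sketch of the resulting per-thread register shape (illustrative only):

  import numpy as np

  m = 128                             # assumed 1D shape, m % 64 == 0
  scalars_per_thread = (m // 64) * 2  # two scalars per 64-row block
  registers = np.arange(scalars_per_thread).reshape(-1, 2)
  assert registers.shape == (m // 64, 2)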
@@ -1743,6 +1782,8 @@ class FragmentedArray:
     )
 
     match self.layout:
+      case WGMMARowFragLayout():
+        self._store_untiled_wgmma_row(ref)
       case WGSplatFragLayout():
         vs_unsupported()
         self._store_untiled_splat(ref)
@@ -1789,6 +1830,23 @@ class FragmentedArray:
     for idx, reg in zip(idxs, self.registers.flat):
       vector.store(reg, ref_, idx)
 
+  def _store_untiled_wgmma_row(self, ref: ir.Value):
+    """Stores an array with a WGMMA row layout."""
+    assert self.layout == WGMMA_ROW_LAYOUT
+    index = ir.IndexType.get()
+    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
+
+    is_first = arith.cmpi(
+        arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index)
+    )
+    # Consecutive groups of 4 threads hold the same value in this layout,
+    # therefore we only need to transfer data from one of them.
+    with utils.when(is_first):
+      for (idx,), value in zip(
+          self.layout.thread_idxs(self.shape), self.registers.flatten()
+      ):
+        memref.store(value, ref, [idx])
+
   def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
     """Stores an array with a tiled layout. Not optimized at the moment."""
     if utils.bitwidth(self.mlir_dtype) < 8:
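Because every group of 4 consecutive lanes holds identical row values, the store is predicated on tid % 4 == 0 so that each output row is written exactly once. A pure-Python sketch of that writer assignment (illustrative only, assuming a 128-thread warpgroup):

  def row_writers(m: int) -> dict[int, int]:
    # Maps each output row to the single warpgroup thread that stores it under
    # the tid % 4 == 0 predicate above. Assumes m % 64 == 0.
    writers = {}
    for tid in range(128):
      if tid % 4 != 0:
        continue  # redundant lanes are masked off by utils.when(is_first)
      row_base = (tid % 32) // 4 + (tid // 32) * 16
      for row_group in range(0, m, 64):
        for row_subgroup in (0, 8):
          row = row_base + row_group + row_subgroup
          assert row not in writers  # every row has exactly one writer
          writers[row] = tid
    return writers

  # All 64 rows are covered, each by a distinct quad leader.
  assert sorted(row_writers(64)) == list(range(64))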


@@ -1946,6 +1946,21 @@ class FragmentedArrayTest(TestCase):
     )(inp)
     np.testing.assert_array_equal(inp, result)
 
+  @parameterized.product(in_shape=((128,), (64,)))
+  def test_wgmma_row_load_store_with_layout(self, in_shape):
+    def kernel(ctx, *args):
+      gmem_input, gmem_output, (smem_input, smem_output) = args
+      copy(gmem_input, smem_input)
+      t = mgpu.FragmentedArray.load_wgmma_row(smem_input)
+      t.store_untiled(smem_output)
+      copy(smem_output, gmem_output)
+
+    inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
+    result = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
+    )(inp)
+    np.testing.assert_array_equal(inp, result)
+
   def test_warp_tree_reduce(self):
     def kernel(ctx, out, *_):
       del ctx