[Mosaic GPU] Support reads/writes from SMEM to WGMMARowFragLayout arrays.
PiperOrigin-RevId: 738121106
Parent: 76d9890bb7
Commit: 54691b125a
@@ -387,7 +387,21 @@ class WGMMARowFragLayout:
   """[m] matrix, where m % 64 == 0."""
 
   def thread_idxs(self, shape):
-    raise NotImplementedError
+    index = ir.IndexType.get()
+    assert len(shape) == 1
+    assert shape[0] % 64 == 0
+    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
+    tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
+    warp_idx = arith.divui(tid_wg, c(32, index))
+    lane_id = arith.remui(tid_wg, c(32, index))
+    row_base = arith.addi(
+        arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index))
+    )
+
+    for row_group in range(0, shape[0], 64):
+      for row_subgroup in (0, 8):
+        row = arith.addi(row_base, c(row_group + row_subgroup, index))
+        yield (row,)
 
 
 @dataclasses.dataclass(frozen=True)
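For intuition, the new thread_idxs mapping can be emulated with plain Python integers: each warpgroup thread owns two rows per 64-row group, and groups of four consecutive lanes share the same rows. A minimal sketch, not part of the commit, assuming the usual 128-thread warpgroup:

# Illustrative sketch (not from the commit): emulate thread_idxs with plain
# integers instead of MLIR arith ops.
WARPGROUP_SIZE = 128

def emulated_rows(tid: int, m: int) -> list[int]:
  tid_wg = tid % WARPGROUP_SIZE
  warp_idx = tid_wg // 32
  lane_id = tid_wg % 32
  row_base = lane_id // 4 + warp_idx * 16
  return [
      row_base + row_group + row_subgroup
      for row_group in range(0, m, 64)
      for row_subgroup in (0, 8)
  ]

# Thread 0 owns rows 0 and 8 of each 64-row group:
assert emulated_rows(0, 128) == [0, 8, 64, 72]
# Lanes 0-3 map to the same rows (4-way replication within each warp):
assert emulated_rows(3, 128) == [0, 8, 64, 72]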
@@ -660,6 +674,31 @@ class FragmentedArray:
     vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
     return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)
 
+  @classmethod
+  def load_wgmma_row(
+      cls,
+      ref: ir.Value,
+      *,
+      is_signed: bool | None = None,
+  ):
+    if not ir.MemRefType.isinstance(ref.type):
+      raise TypeError(ref.type)
+
+    ref_ty = ir.MemRefType(ref.type)
+    shape = tuple(ref_ty.shape)
+    if len(shape) != 1:
+      raise ValueError("WGMMARowFragLayout requires a 1D shape")
+    if shape[0] % 64:
+      raise ValueError(
+          "WGMMARowFragLayout requires shape[0] to be a multiple of 64"
+      )
+
+    layout = WGMMARowFragLayout()
+    registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)]
+    registers = np.array(registers).reshape(-1, 2)
+    return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
+
+
   @classmethod
   def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
     layout = layout or WGSplatFragLayout(shape)
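The reshape(-1, 2) at the end mirrors the layout's structure: thread_idxs yields two row indices per 64-row group, so an [m] input produces m // 64 pairs of scalars per thread. A quick sanity check of that arithmetic (illustrative only, not part of the commit):

# Illustrative only: per-thread register count for load_wgmma_row.
m = 128                             # input shape (m,), with m % 64 == 0
scalars_per_thread = (m // 64) * 2  # two rows per 64-row group
assert scalars_per_thread == 4
# np.array(registers).reshape(-1, 2) therefore yields shape (m // 64, 2):
assert (scalars_per_thread // 2, 2) == (2, 2)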
@@ -1743,6 +1782,8 @@ class FragmentedArray:
     )
 
     match self.layout:
+      case WGMMARowFragLayout():
+        self._store_untiled_wgmma_row(ref)
       case WGSplatFragLayout():
         vs_unsupported()
         self._store_untiled_splat(ref)
@@ -1789,6 +1830,23 @@ class FragmentedArray:
       for idx, reg in zip(idxs, self.registers.flat):
         vector.store(reg, ref_, idx)
 
+  def _store_untiled_wgmma_row(self, ref: ir.Value):
+    """Stores an array with a WGMMA row layout."""
+    assert self.layout == WGMMA_ROW_LAYOUT
+    index = ir.IndexType.get()
+    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
+
+    is_first = arith.cmpi(
+        arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index)
+    )
+    # Consecutive groups of 4 threads hold the same value in this layout,
+    # therefore we only need to transfer data from one of them.
+    with utils.when(is_first):
+      for (idx,), value in zip(
+          self.layout.thread_idxs(self.shape), self.registers.flatten()
+      ):
+        memref.store(value, ref, [idx])
+
   def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True):
     """Stores an array with a tiled layout. Not optimized at the moment."""
     if utils.bitwidth(self.mlir_dtype) < 8:
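The tid % 4 == 0 predicate follows from the replication noted above: because lane_id // 4 discards the low two lane bits, four consecutive lanes hold identical row values, and only one of them needs to write. A plain-Python check that the chosen writers cover every row exactly once (illustrative, not part of the commit):

# Illustrative sketch: for a (64,) array, lanes with tid % 4 == 0 store,
# and together they cover each row exactly once.
writers = [tid for tid in range(128) if tid % 4 == 0]
rows_written = sorted(
    tid % 32 // 4 + tid // 32 * 16 + sub
    for tid in writers
    for sub in (0, 8)
)
assert len(writers) == 32               # one lane per 4-lane replica group
assert rows_written == list(range(64))  # every row stored exactly once

The final hunk below adds a round-trip test in FragmentedArrayTest that loads a row vector with load_wgmma_row and writes it back with store_untiled.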
@@ -1946,6 +1946,21 @@ class FragmentedArrayTest(TestCase):
     )(inp)
     np.testing.assert_array_equal(inp, result)
 
+  @parameterized.product(in_shape=((128,), (64,)))
+  def test_wgmma_row_load_store_with_layout(self, in_shape):
+    def kernel(ctx, *args):
+      gmem_input, gmem_output, (smem_input, smem_output) = args
+      copy(gmem_input, smem_input)
+      t = mgpu.FragmentedArray.load_wgmma_row(smem_input)
+      t.store_untiled(smem_output)
+      copy(smem_output, gmem_output)
+
+    inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
+    result = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
+    )(inp)
+    np.testing.assert_array_equal(inp, result)
+
   def test_warp_tree_reduce(self):
     def kernel(ctx, out, *_):
       del ctx