diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 8b8fdaceb..5daed8416 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -387,7 +387,21 @@ class WGMMARowFragLayout: """[m] matrix, where m % 64 == 0.""" def thread_idxs(self, shape): - raise NotImplementedError + index = ir.IndexType.get() + assert len(shape) == 1 + assert shape[0] % 64 == 0 + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) + warp_idx = arith.divui(tid_wg, c(32, index)) + lane_id = arith.remui(tid_wg, c(32, index)) + row_base = arith.addi( + arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index)) + ) + + for row_group in range(0, shape[0], 64): + for row_subgroup in (0, 8): + row = arith.addi(row_base, c(row_group + row_subgroup, index)) + yield (row,) @dataclasses.dataclass(frozen=True) @@ -660,6 +674,31 @@ class FragmentedArray: vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) + @classmethod + def load_wgmma_row( + cls, + ref: ir.Value, + *, + is_signed: bool | None = None, + ): + if not ir.MemRefType.isinstance(ref.type): + raise TypeError(ref.type) + + ref_ty = ir.MemRefType(ref.type) + shape = tuple(ref_ty.shape) + if len(shape) != 1: + raise ValueError("WGMMARowFragLayout requires a 1D shape") + if shape[0] % 64: + raise ValueError( + "WGMMARowFragLayout requires shape[0] to be a multiple of 64" + ) + + layout = WGMMARowFragLayout() + registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)] + registers = np.array(registers).reshape(-1, 2) + return cls(_registers=registers, _layout=layout, _is_signed=is_signed) + + @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) @@ -1743,6 +1782,8 @@ class FragmentedArray: ) match self.layout: + case WGMMARowFragLayout(): + self._store_untiled_wgmma_row(ref) case WGSplatFragLayout(): vs_unsupported() self._store_untiled_splat(ref) @@ -1789,6 +1830,23 @@ class FragmentedArray: for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) + def _store_untiled_wgmma_row(self, ref: ir.Value): + """Stores an array with a WGMMA row layout.""" + assert self.layout == WGMMA_ROW_LAYOUT + index = ir.IndexType.get() + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + + is_first = arith.cmpi( + arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index) + ) + # Consecutive groups of 4 threads hold the same value in this layout, + # therefore we only need to transfer data from one of them. + with utils.when(is_first): + for (idx,), value in zip( + self.layout.thread_idxs(self.shape), self.registers.flatten() + ): + memref.store(value, ref, [idx]) + def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): """Stores an array with a tiled layout. Not optimized at the moment.""" if utils.bitwidth(self.mlir_dtype) < 8: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1fcd68641..e7bd7fad3 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1946,6 +1946,21 @@ class FragmentedArrayTest(TestCase): )(inp) np.testing.assert_array_equal(inp, result) + @parameterized.product(in_shape=((128,), (64,))) + def test_wgmma_row_load_store_with_layout(self, in_shape): + def kernel(ctx, *args): + gmem_input, gmem_output, (smem_input, smem_output) = args + copy(gmem_input, smem_input) + t = mgpu.FragmentedArray.load_wgmma_row(smem_input) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(inp, result) + def test_warp_tree_reduce(self): def kernel(ctx, out, *_): del ctx