From d4bd2570ae32fe9c7329520c8d768b042910bc77 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 18 Mar 2025 04:47:04 -0700 Subject: [PATCH] [Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts PiperOrigin-RevId: 737956598 --- .../mosaic/gpu/fragmented_array.py | 188 +++++++++++++----- tests/mosaic/gpu_test.py | 79 +++++--- 2 files changed, 190 insertions(+), 77 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index dc5ad48c4..ded17d5d4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -526,6 +526,18 @@ WGMMA_LAYOUT_UPCAST_2X = TiledLayout( lane_dims=(-4, -2, -3), vector_dim=-1, ) +# This layout should be used when upcasting 4-bit elements to 16-bit, for the +# purpose of passing them into WGMMA later. The core matrices stored by a warp +# are 8x32, because each of the 4 threads in a row holds 8 elements in a single +# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each +# group of 4 threads in order (as opposed to the swapping between 1 and 2, +# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does). +WGMMA_LAYOUT_UPCAST_4X = TiledLayout( + Tiling(((64, 32), (16, 32), (8, 32), (8,))), + warp_dim=-7, + lane_dims=(-3, -2), + vector_dim=-1, +) # This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8 # submatrix in the following way (we only show the first 4 rows for brevity): # @@ -739,58 +751,132 @@ class FragmentedArray: _layout=new_layout, _is_signed=self.is_signed, ) - if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 16 == 0: - if ( - self.layout == WGMMA_LAYOUT_UPCAST_2X - and new_layout == WGMMA_LAYOUT - and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) in {8, 16} - ): - assert shape[1] % 16 == 0 # Should be implied by the layout - new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) - is_even = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0) + if ( + self.layout == WGMMA_LAYOUT_UPCAST_2X + and new_layout == WGMMA_LAYOUT + and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16 + ): + assert shape[1] % 16 == 0 # Should be implied by the layout + new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) + is_even = arith.cmpi( + arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0) + ) + registers = self.registers + if dtype_bitwidth == 4: + if registers.shape[1] % 2: + raise NotImplementedError( + "This relayout implementation requires an even number of column" + " tiles (to pack pairs of them for efficiency)" + ) + # We pair up the consecutive column tiles, so each register is 32-bit. + # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout, + # LLVM will realize that the paired up vectors actually came from the + # same 32-bit register and it will become a no-op. + col_minor_registers = np.moveaxis(registers, 1, -1) + flat_registers = [ + utils.vector_concat((l, h)) + for l, h in zip( + col_minor_registers.flat[::2], col_minor_registers.flat[1::2] + ) + ] + registers = np.asarray(flat_registers, dtype=object).reshape( + *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2 ) - for idx, reg in np.ndenumerate(self.registers): - assert ir.VectorType(reg.type).shape == [4] - if dtype_bitwidth == 16: - # A single vector is 64-bits, but shuffles are only 32-bit wide. - # We only shuffle the half that needs to go to other thread. 
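# --- Editor's sketch (not part of the patch): a pure-Python model of the even/odd
# half-exchange that the 16-bit branch of this relayout implements with
# arith.select and shfl_bfly. Each lane keeps one 32-bit half of its 4-element
# vector and swaps the other half with its butterfly partner (lane ^ 1); afterwards
# the even lane of a pair holds both lanes' low halves and the odd lane holds both
# high halves. The helper name is hypothetical, and the model simply reads the
# partner's data, whereas the hardware only exchanges the selected 32-bit half.
def _exchange_halves_model(lane_regs):
  out = []
  for lane, reg in enumerate(lane_regs):
    low, high = reg[:2], reg[2:]
    partner = lane_regs[lane ^ 1]  # shfl_bfly(x, 1) pairs lane with lane ^ 1
    if lane % 2 == 0:   # is_even: keep own low half, receive the partner's low half
      out.append(low + partner[:2])
    else:               # odd lane: receive the partner's high half, keep own high half
      out.append(partner[2:] + high)
  return out

# Lane 0 holds a0..a3 and lane 1 holds b0..b3 of the same logical row fragment.
assert _exchange_halves_model(
    [["a0", "a1", "a2", "a3"], ["b0", "b1", "b2", "b3"]]
) == [["a0", "a1", "b0", "b1"], ["a2", "a3", "b2", "b3"]]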
- low = utils.vector_slice(reg, slice(0, 2)) - high = utils.vector_slice(reg, slice(2, 4)) - to_exchange = arith.select(is_even, high, low) - # Exchange values between even and odd threads. - exchanged = utils.shfl_bfly(to_exchange, 1) - low = arith.select(is_even, low, exchanged) - high = arith.select(is_even, exchanged, high) - elif dtype_bitwidth == 8: - # The vector is 32-bits, so we just shuffle the whole thing and - # use prmt to blend it with the local register. - exchanged = utils.shfl_bfly(reg, 1) - # Consider lanes 0 and 1, because the situation is symmetric for - # each pair. If we feed reg[lane] and exchanged[lane] (which is - # really the same as reg of the other lane) to prmt, we can index - # the elements of the result using the following indices: - # reg[0]: 0 1 2 3 reg[1]: 8 9 10 11 - # prmt[0]: 0 1 2 3 4 5 6 7 - # prmt[1]: 4 5 6 7 0 1 2 3 - # The expected outputs and their respective permutations are: - # out[0]: 0 1 8 9 out[1]: 2 3 10 11 - # prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3 - # Note that the patterns still need to be flipped, since we listed - # bytes with LSB on the left, which is the opposite of how the - # numeric constants are spelled in Python (LSB on the right). - perm = arith.select(is_even, c(0x5410), c(0x3276)) - blend = utils.prmt(reg, exchanged, perm) - low = utils.vector_slice(blend, slice(0, 2)) - high = utils.vector_slice(blend, slice(2, 4)) - else: - raise NotImplementedError(dtype_bitwidth) + registers = np.moveaxis(registers, -1, 1) + for idx, reg in np.ndenumerate(registers): + if dtype_bitwidth == 16: + assert reg.type.shape == [4] + # A single vector is 64-bits, but shuffles are only 32-bit wide. + # We only shuffle the half that needs to go to other thread. + low = utils.vector_slice(reg, slice(0, 2)) + high = utils.vector_slice(reg, slice(2, 4)) + to_exchange = arith.select(is_even, high, low) + # Exchange values between even and odd threads. + exchanged = utils.shfl_bfly(to_exchange, 1) + low = arith.select(is_even, low, exchanged) + high = arith.select(is_even, exchanged, high) new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high - assert all(r is not None for r in new_registers) - return FragmentedArray( - _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, - ) + elif dtype_bitwidth == 8: + assert reg.type.shape == [4] + # The vector is 32-bits, so we just shuffle the whole thing and + # use prmt to blend it with the local register. + exchanged = utils.shfl_bfly(reg, 1) + # Consider lanes 0 and 1, because the situation is symmetric for + # each pair. If we feed reg[lane] and exchanged[lane] (which is + # really the same as reg of the other lane) to prmt, we can index + # the elements of the result using the following indices: + # reg[0]: 0 1 2 3 reg[1]: 8 9 10 11 + # prmt[0]: 0 1 2 3 4 5 6 7 + # prmt[1]: 4 5 6 7 0 1 2 3 + # The expected outputs and their respective permutations are: + # out[0]: 0 1 8 9 out[1]: 2 3 10 11 + # prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3 + # Note that the patterns still need to be flipped, since we listed + # bytes with LSB on the left, which is the opposite of how the + # numeric constants are spelled in Python (LSB on the right). 
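# --- Editor's sketch (not part of the patch): a small model of the PTX prmt byte
# permute, used only to sanity-check the selector constants chosen in this hunk
# (0x5410/0x3276 just below for the 8-bit case, 0x6240/0x3715 in the 4-bit branch
# further down, and 0x5410/0x3276 reused in the UPCAST_4X -> UPCAST_2X branch).
# prmt concatenates the bytes of its two 32-bit operands (LSB first) and the i-th
# selector nibble picks the i-th result byte; the sign-replication mode bit is
# ignored in this model. All names here are hypothetical.
def _prmt_model(a_bytes, b_bytes, selector):
  pool = list(a_bytes) + list(b_bytes)          # bytes 0..7, as in PTX prmt
  return [pool[(selector >> (4 * i)) & 0x7] for i in range(4)]

# 8-bit case (shfl_bfly mask 1): even lane holds bytes 0..3, odd lane bytes 8..11.
even, odd = [0, 1, 2, 3], [8, 9, 10, 11]
assert _prmt_model(even, odd, 0x5410) == [0, 1, 8, 9]      # out[0]
assert _prmt_model(odd, even, 0x3276) == [2, 3, 10, 11]    # out[1]

# 4-bit case (shfl_bfly mask 1): each prmt byte is a pair of nibbles.
even = [(0, 1), (2, 3), (16, 17), (18, 19)]
odd = [(8, 9), (10, 11), (24, 25), (26, 27)]
assert _prmt_model(even, odd, 0x6240) == [(0, 1), (8, 9), (16, 17), (24, 25)]
assert _prmt_model(odd, even, 0x3715) == [(2, 3), (10, 11), (18, 19), (26, 27)]

# 4-bit UPCAST_4X -> UPCAST_2X case (shfl_bfly mask 2): lanes 0 and 2 exchange.
lane0 = [(0, 1), (2, 3), (4, 5), (6, 7)]
lane2 = [(16, 17), (18, 19), (20, 21), (22, 23)]
assert _prmt_model(lane0, lane2, 0x5410) == [(0, 1), (2, 3), (16, 17), (18, 19)]
assert _prmt_model(lane2, lane0, 0x3276) == [(4, 5), (6, 7), (20, 21), (22, 23)]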
+ perm = arith.select(is_even, c(0x5410), c(0x3276)) + blend = utils.prmt(reg, exchanged, perm) + for i in range(2): + reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2)) + new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg + else: + assert dtype_bitwidth == 4 + assert reg.type.shape == [8] # We paired up the registers above. + exchanged = utils.shfl_bfly(reg, 1) + # See comment above for a more complete explanation. + # reg[0]: 0 1 2 3 16 17 18 19 reg[1]: 8 9 10 11 24 25 26 27 + # prmt[0]: -0- -1- --2-- --3-- -4- --5-- --6-- --7-- + # prmt[1]: -4- -5- --6-- --7-- -0- --1-- --2-- --3-- + # The expected outputs and their respective permutations are: + # out[0]: 0 1 8 9 16 17 24 25 out[1]: 2 3 10 11 18 19 26 27 + # prmt[0]: -0- -4- --2-- --6-- prmt[1]: -5- --1-- --7-- --3-- + perm = arith.select(is_even, c(0x6240), c(0x3715)) + blend = utils.prmt(reg, exchanged, perm) + for i in range(4): + reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2)) + new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg + assert all(r is not None for r in new_registers) + return FragmentedArray( + _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, + ) + if ( + self.layout == WGMMA_LAYOUT_UPCAST_4X + and new_layout == WGMMA_LAYOUT_UPCAST_2X + and utils.bitwidth(self.mlir_dtype) == 4 + ): + assert shape[0] % 64 == 0 # Should be implied by the layout + assert shape[1] % 32 == 0 # Should be implied by the layout + new_registers = np.empty(new_layout.registers_shape(shape), dtype=object) + i32 = ir.IntegerType.get_signless(32) + c = lambda x: arith.constant(i32, x) + is_01 = arith.cmpi( + arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2) + ) + for idx, reg in np.ndenumerate(self.registers): + assert ir.VectorType(reg.type).shape == [8] + # The vector is 32-bits, so we just shuffle the whole thing and + # use prmt to blend it with the local register. + exchanged = utils.shfl_bfly(reg, 2) + # See comments above for conventions. Here we exchange data between + # threads with lane index related by flipping 2nd bit (e.g. 0 and 2). 
+ # reg[0]: 0 1 2 3 4 5 6 7 reg[2]: 16 17 18 19 20 21 22 23 + # prmt[0]: -0- -1- -2- -3- --4-- --5-- --6-- --7-- + # prmt[1]: -4- -5- -6- -7- --0-- --1-- --2-- --3-- + # The expected outputs and their respective permutations are: + # out[0]: 0 1 2 3 16 17 18 19 out[2]: 4 5 6 7 20 21 22 23 + # prmt[0]: -0- -1- --4-- --5-- prmt[2]: -6- -7- --2-- --3-- + perm = arith.select(is_01, c(0x5410), c(0x3276)) + blend = utils.prmt(reg, exchanged, perm) + for i in range(2): + reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4)) + new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg + assert all(r is not None for r in new_registers) + return FragmentedArray( + _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed, + ) + if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT: + return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout) if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" @@ -1288,7 +1374,9 @@ class FragmentedArray: int_ty = ir.IntegerType.get_signless(group_size * 4) while vector_len - offset >= group_size: reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size)) - reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty)) + reg_slice_int = utils.bitcast(reg_slice, int_ty) + if int_ty != i32: + reg_slice_int = arith.extsi(i32, reg_slice_int) reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32)) out_int_regs.extend( upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 91644be5c..bc56f21d0 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -14,6 +14,7 @@ # ============================================================================== from collections.abc import Sequence +import contextlib import dataclasses import enum import itertools @@ -83,6 +84,20 @@ def mlir_sum(elems): return total +@contextlib.contextmanager +def get_sass(): + prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) + os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" + try: + with jtu.capture_stdout() as output: + yield output + finally: + if prev_dump is not None: + os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump + else: + del os.environ["MOSAIC_GPU_DUMP_SASS"] + + def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): index = ir.IndexType.get() thread_id = gpu.thread_id(gpu.Dimension.x) @@ -542,7 +557,11 @@ class WGMMALayoutTest(TestCase): (jnp.int8, jnp.bfloat16), (jnp.int4, jnp.bfloat16), ), - layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X), + layout=( + fa.WGMMA_LAYOUT, + fa.WGMMA_LAYOUT_UPCAST_2X, + fa.WGMMA_LAYOUT_UPCAST_4X, + ), ) def test_optimized_conversion(self, jax_dtype_from_to, layout): jax_dtype_from, jax_dtype_to = jax_dtype_from_to @@ -2194,19 +2213,11 @@ class LayoutTest(TestCase): .transpose(0, 2, 1, 3) ) - prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) - os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" - try: - with jtu.capture_stdout() as get_sass: - iota = mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), expected, expected, - [expected, expected, mgpu.TMABarrier()], - )(expected) - finally: - if prev_dump is not None: - os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump - else: - del os.environ["MOSAIC_GPU_DUMP_SASS"] + with get_sass() as sass: + iota = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, + [expected, expected, mgpu.TMABarrier()], + )(expected) np.testing.assert_array_equal(iota, expected) # 
Verify that we don't use too many registers for the transfers. @@ -2219,7 +2230,7 @@ class LayoutTest(TestCase): expected_regs //= 2 for instr in ("STS", "LDS"): with self.subTest(instr + " count"): - addrs = re.findall(instr + r".* \[(.*)\]", get_sass()) + addrs = re.findall(instr + r".* \[(.*)\]", sass()) def get_reg(addr): if (pos := addr.find("+")) != -1: return addr[:pos] @@ -2294,30 +2305,38 @@ class LayoutTest(TestCase): )(x) np.testing.assert_array_equal(y, y_ref) - @parameterized.product( - upcast_before_layout_change=[True, False], + @parameterized.parameters( + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1), + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1), + (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5), + (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2), ) - def test_upcast_to_wgmma(self, upcast_before_layout_change): - in_dtype = jnp.dtype(jnp.int8) + def test_upcast_to_wgmma( + self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg + ): + in_dtype = jnp.dtype(in_dtype) out_dtype = jnp.dtype(jnp.int16) + out_dtype_mlir = utils.dtype_to_ir_type(out_dtype) swizzle = 128 in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits in_tiling = (8, in_col_tiling) out_col_tiling = swizzle // out_dtype.itemsize out_tiling = (8, out_col_tiling) m, n = 128, in_col_tiling * 2 + regs_per_thread = None def kernel(ctx, in_, out, smems): + nonlocal regs_per_thread smem_in, smem_out, barrier = smems ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) barrier.wait() t = mgpu.FragmentedArray.load_tiled( - smem_in, swizzle=swizzle, is_signed=True, layout=fa.WGMMA_LAYOUT_UPCAST_2X + smem_in, swizzle=swizzle, is_signed=True, layout=start_layout ) - if upcast_before_layout_change: - t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) - t = t.to_layout(fa.WGMMA_LAYOUT) - if not upcast_before_layout_change: - t = t.astype(ir.IntegerType.get_signless(16), is_signed=True) + regs_per_thread = t.registers.size + t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True) + t = t.to_layout(end_layout) + t = t.astype(out_dtype_mlir, is_signed=True) t.store_tiled(smem_out, swizzle=swizzle) mgpu.commit_shared() ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle) @@ -2326,14 +2345,20 @@ class LayoutTest(TestCase): return x.reshape( x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1] ).transpose(0, 2, 1, 3) - x = jax.random.randint(jax.random.key(42), (m, n), -128, 127, dtype=in_dtype) + in_iinfo = jnp.iinfo(in_dtype) + x = jax.random.randint( + jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32 + ).astype(in_dtype) xt = tile(x, in_tiling) y = x.astype(out_dtype) yt = tile(y, out_tiling) f = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()], ) - np.testing.assert_array_equal(f(xt), yt) + with get_sass() as sass: + yt_kernel = f(xt) + np.testing.assert_array_equal(yt_kernel, yt) + self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg) @dataclasses.dataclass(frozen=True)
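
Editor's note (not part of the patch): a rough accounting sketch of where the
shfl_per_reg factors in test_upcast_to_wgmma come from. Every relayout branch in
fragmented_array.py issues one SHFL.BFLY per 32-bit register it processes, and
the factors below are expressed per register of the starting layout. The helper
and its argument names are hypothetical.

def expected_shfl_per_start_reg(start, end, bitwidth):
  if (start, end) == ("UPCAST_2X", "WGMMA"):
    # 4-bit registers are first paired into 32-bit vectors, halving the shuffles.
    return 0.5 if bitwidth == 4 else 1
  if (start, end) == ("UPCAST_4X", "UPCAST_2X"):
    return 1                     # one shuffle per (already 32-bit) register
  if (start, end) == ("UPCAST_4X", "WGMMA"):
    # Two-step path: 1 shuffle per 4X register to reach 2X (doubling the register
    # count), then 0.5 shuffles per 2X register, i.e. 1 more per original register.
    return 1 + 2 * 0.5
  raise NotImplementedError((start, end))

assert expected_shfl_per_start_reg("UPCAST_2X", "WGMMA", 8) == 1
assert expected_shfl_per_start_reg("UPCAST_2X", "WGMMA", 16) == 1
assert expected_shfl_per_start_reg("UPCAST_2X", "WGMMA", 4) == 0.5
assert expected_shfl_per_start_reg("UPCAST_4X", "UPCAST_2X", 4) == 1
assert expected_shfl_per_start_reg("UPCAST_4X", "WGMMA", 4) == 2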