From bb96226dd8a3b4b92431edf2ca75bc82741b4be9 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Fri, 28 Feb 2025 05:40:39 -0800
Subject: [PATCH] [Mosaic GPU] Add support for small RHS tile sizes in WGMMA

This is useful for more fine-grained autotuning and can help avoid wave
quantization effects.

PiperOrigin-RevId: 732105219
---
 jax/experimental/mosaic/gpu/tcgen05.py |  19 +--
 jax/experimental/mosaic/gpu/wgmma.py   | 166 +++++++++++++++++--------
 tests/mosaic/BUILD                     |   2 +-
 tests/mosaic/gpu_test.py               |  78 +++++++++---
 4 files changed, 188 insertions(+), 77 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index 400ab965c..c8e12a72f 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -106,6 +106,7 @@ def mma(
     raise ValueError(f"B must be a memref, got: {b.type}")
   if a_swizzle != b_swizzle:
     raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
+  swizzle = a_swizzle
   if isinstance(accumulate, bool):
     accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
   num_cta = 2 if collective else 1
@@ -114,11 +115,11 @@ def mma(
       a_desc_base,
       b_desc_base,
       (m, k, n),
-      (m_mem_tiling, kn_mem_tiling),
+      (m_mem_tiling, n_mem_tiling),
       element_type,
       mma_params,
-      a_k_byte_stride,
-      b_k_byte_stride,
+      a_k_group_stride,
+      b_k_group_stride,
   ) = _wgmma._validate_mma(
       a,
       b,
@@ -127,7 +128,7 @@
   )

   # The sizes of instruction we'll be using
-  k_instr_tiling = kn_mem_tiling
+  k_instr_tiling = swizzle // utils.bytewidth(element_type)
   if (m_instr_tiling := d.layout.elements_in_tile[0]) != m_mem_tiling:
     raise ValueError(
         f"A's row tiling must be equal to {m_instr_tiling} (inferred from"
@@ -143,14 +144,14 @@ def mma(
     raise NotImplementedError("The only supported N larger than 256 is 512")

   a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-  a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)
+  a_m_group_stride = a_strides[0] * utils.bytewidth(element_type)
   b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
-  b_n_byte_stride = b_strides[1] * utils.bytewidth(element_type)
+  b_n_group_stride = b_strides[1] * utils.bytewidth(element_type)
   groups_k = k // k_instr_tiling
   groups_m = m // m_instr_tiling
   groups_n = n // n_instr_tiling
-  n_mem_tiles_per_instr = n_instr_tiling // kn_mem_tiling
+  n_mem_tiles_per_instr = n_instr_tiling // n_mem_tiling

   # TODO(apaszke): Verify that the cluster shape matches the expectation of
   # collective MMA.
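For intuition, the geometry above reduces to a few integer divisions. Below is a minimal, hypothetical Python sketch (not part of the patch; parameter names are borrowed from the surrounding code) of how the group counts and the per-group descriptor byte offsets relate; it mirrors the (mi, ni, ki) loop in the next hunk.

import itertools

def mma_group_offsets(m, k, n, swizzle, element_bytewidth,
                      m_instr_tiling, n_instr_tiling, n_mem_tiling,
                      a_m_group_stride, a_k_group_stride,
                      b_n_group_stride, b_k_group_stride):
  """Yields (a_offset, b_offset) SMEM byte offsets, one pair per MMA group."""
  # One instruction consumes a full swizzle tile worth of K elements.
  k_instr_tiling = swizzle // element_bytewidth
  groups_m = m // m_instr_tiling
  groups_n = n // n_instr_tiling
  groups_k = k // k_instr_tiling
  # With small RHS tiles, one instruction spans several N tiles in memory.
  n_mem_tiles_per_instr = n_instr_tiling // n_mem_tiling
  for mi, ni, ki in itertools.product(
      range(groups_m), range(groups_n), range(groups_k)):
    a_offset = mi * a_m_group_stride + ki * a_k_group_stride
    b_offset = ni * n_mem_tiles_per_instr * b_n_group_stride + ki * b_k_group_stride
    yield a_offset, b_offset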
@@ -162,9 +163,9 @@ def mma(
   true = arith.constant(ir.IntegerType.get_signless(1), 1)
   for mi, ni, ki in np.ndindex(groups_m, groups_n, groups_k):
-    a_offset = mi * a_m_byte_stride + ki * a_k_byte_stride
+    a_offset = mi * a_m_group_stride + ki * a_k_group_stride
     a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64))
-    b_offset = ni * n_mem_tiles_per_instr * b_n_byte_stride + ki * b_k_byte_stride
+    b_offset = ni * n_mem_tiles_per_instr * b_n_group_stride + ki * b_k_group_stride
     b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64))
     if groups_m != 1:
       raise NotImplementedError("D needs to be sliced")
diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py
index d58bd902d..0ddff1f27 100644
--- a/jax/experimental/mosaic/gpu/wgmma.py
+++ b/jax/experimental/mosaic/gpu/wgmma.py
@@ -345,90 +345,152 @@ def _validate_mma(
   if element_type not in supported_types:
     raise ValueError(a_element_type)
   element_bytewidth = bytewidth(element_type)
-  kn_tiling = swizzle // element_bytewidth
+  swizzle_elems = swizzle // element_bytewidth

   # Verify the shape and strides of B are as expected.
-  k_tiles, n_tiles, k_tiling, n_tiling = b_ty.shape
-  if k_tiling != kn_tiling:
-    raise ValueError(b_ty.shape)
-  # Note that while this technically allows n to be smaller than kn_tile,
-  # the stride checks above will still enforce that the memory region is padded.
-  # It might be possible to relax that requirement, but I haven't tested it.
-  if n_tiling > kn_tiling and n_tiling % kn_tiling:
-    raise ValueError(n_tiling, kn_tiling)
-  k = k_tiles * kn_tiling
+  b_k_tiles, n_tiles, b_k_tiling, n_tiling = b_ty.shape
+  k = b_k_tiles * b_k_tiling
   n = n_tiles * n_tiling
   b_strides, _ = b_ty.get_strides_and_offset()
   b_byte_strides = [s * element_bytewidth for s in b_strides]
   b_k_byte_stride, b_n_byte_stride, *b_tile_byte_strides = b_byte_strides
-  # TODO(apaszke): Relax tiling here! But make sure that the space between N tiles is same.
-  if b_byte_strides[1] != swizzle * kn_tiling:
-    raise ValueError(b_byte_strides)
-  if b_tile_byte_strides == [swizzle, element_bytewidth]:
+  if (
+      b_byte_strides[1] != n_tiling * b_k_tiling * element_bytewidth
+      and n_tiles != 1  # When there's only one tile, we never jump between them
+  ):
+    raise ValueError("B tiles must be contiguous along the N dimension")
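The new check above is deliberately weaker than the old one: it only requires consecutive N tiles of B to sit back to back in memory. A quick NumPy sketch of the validated stride, under assumed illustrative numbers (f16, swizzle=128, small N tile of 8; not part of the patch):

import numpy as np

elem = 2                      # f16 bytewidth
b_k_tiling, n_tiling = 64, 8  # K tile = swizzle elems, small N tile = 8
shape = (2, 16, b_k_tiling, n_tiling)  # (b_k_tiles, n_tiles, b_k_tiling, n_tiling)
byte_strides = [int(np.prod(shape[i + 1:])) * elem for i in range(4)]
# Consecutive N tiles must sit exactly one tile apart in memory:
assert byte_strides[1] == n_tiling * b_k_tiling * elem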
+  if b_tile_byte_strides == [swizzle, element_bytewidth]:  # N-fastest
     b_order = WGMMALayout.ROW_MAJOR
-  elif b_tile_byte_strides == [element_bytewidth, swizzle]:
+    # This first case (n_tiles == 1) is to allow the somewhat weird case of
+    # loading a small amount of N-fastest data that needs to be padded to a
+    # larger tile due to swizzle. In this case we allow slicing the big tile
+    # before WGMMA to avoid unnecessary compute on padding.
+    if n_tiles == 1:
+      if n_tiling % 8:
+        raise ValueError("N tile size must be a multiple of 8")
+    elif n_tiling != swizzle_elems:
+      raise ValueError(
+          "Row major RHS (N-fastest) requires the N tile size to be equal to"
+          f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}"
+      )
+    if b_k_tiling not in {32 // element_bytewidth, swizzle_elems}:
+      raise ValueError(
+          "Row major RHS (N-fastest) requires the K tile size to be either"
+          f" the swizzle tile size ({swizzle_elems}) or 32 bytes"
+          f" ({32 // element_bytewidth}), but got {b_k_tiling}"
+      )
+  elif b_tile_byte_strides == [element_bytewidth, swizzle]:  # K-fastest
     b_order = WGMMALayout.COL_MAJOR
+    if b_k_tiling != swizzle_elems:
+      raise ValueError(
+          "Column major RHS (K-fastest) requires the K tile size to be equal"
+          f" to the swizzle tile size ({swizzle_elems}), but got {b_k_tiling}"
+      )
+    # See the explanation in the N-fastest case when n_tiles == 1.
+    if n_tiles == 1:
+      if n_tiling % 8:
+        raise ValueError("N tile size must be a multiple of 8")
+    elif n_tiling not in {8, swizzle_elems}:
+      raise ValueError(
+          "Column major RHS (K-fastest) requires the N tile size to be either"
+          f" the swizzle tile size ({swizzle_elems}) or 8, but got {n_tiling}"
+      )
   else:
     raise ValueError(b_byte_strides)

   # Verify the shape and strides of A are as expected.
   if not a_in_smem:
     m = a_shape[0]
-    a_order = m_tiling = a_m_byte_stride = None
+    a_order = m_tiling = None
   else:
     a_ty = ir.MemRefType(a.type)
-    m_tiles, k_tiles, m_tiling, k_tiling = a_ty.shape
+    m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape
     m = m_tiles * m_tiling
-    if k_tiling != kn_tiling or k_tiles * k_tiling != k:
+    if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k:
       raise ValueError(a_ty.shape)
     a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
     a_byte_strides = [s * element_bytewidth for s in a_strides]
-    a_m_byte_stride = a_byte_strides[0]
     if a_byte_strides[2:] == [swizzle, element_bytewidth]:
       a_order = WGMMALayout.ROW_MAJOR
     elif a_byte_strides[2:] == [element_bytewidth, swizzle]:
       a_order = WGMMALayout.COL_MAJOR
     else:
       raise ValueError(a_byte_strides)
-  if a_order != WGMMALayout.ROW_MAJOR and m_tiling != kn_tiling:
+  if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems:
     # Not sure what the layout is like, since the tiles aren't square.
     raise NotImplementedError

   b_k_fastest = b_order == WGMMALayout.COL_MAJOR
   a_k_fastest = a_order == WGMMALayout.ROW_MAJOR
-
-  # Here "leading" refers to the fastest changing dimension.
-  # Leading byte offset (LBO)
-  #   K-fastest: ignored
-  #   MN-fastest: stride between consecutive that share the same K coordinate.
-  # Stride byte offset (SBO)
-  #   K-fastest: offset from one swizzle atom to the next
-  #   MN-fastest: offset from one swizzle atom to the next
   # This is the number of rows until consecutive repeats of the swizzle pattern.
   swizzle_pattern_rows = swizzle // 16
+  # A swizzle atom is a 2D matrix with the dimensions below.
+  swizzle_atom_bytes = swizzle_pattern_rows * 128
+
+  # Here "leading" refers to the fastest changing dimension. There are two
+  # strides we have to define per value:
+  # Leading byte offset (LBO)
+  #   K-fastest: ignored
+  #   MN-fastest: stride between consecutive swizzle atoms that share the same
+  #     K coordinate.
+  # Stride byte offset (SBO)
+  #   As far as I can tell this is just the offset between two consecutive
+  #   swizzle atoms along the non-leading dimension.
+  IGNORED = 0
   a_desc_fields = dict(
-      # TODO(apaszke): a_m_byte_stride works, but is not convincing to me.
-      # After all MMA can only consume a fixed number of bytes from LHS.
-      leading_byte_offset=16 if a_k_fastest else a_m_byte_stride,
-      stride_byte_offset=128 * swizzle_pattern_rows,
+      # I can't fully explain why WGMMA ignores LBO for A. For a_k_fastest, it
+      # is documented in the PTX docs, and my best explanation for the other
+      # case is that the instruction has a fixed shape and so it does not care
+      # about strides. It's possible that it's an artifact of the fact that we
+      # use tiling of 64.
+      leading_byte_offset=IGNORED,
+      stride_byte_offset=swizzle_atom_bytes,
       swizzle=swizzle,
       memory_space=3,
   )
+  # If B is N-fastest, all swizzle atoms within a tile share the same N
+  # coordinate, so we simply take the stride between consecutive N tiles.
+  # If B is K-fastest, all swizzle atoms within a tile share the same K
+  # coordinate, which forces us to lay out the tiles in N-fastest order or else
+  # they would have uneven strides.
   b_desc_fields = dict(
-      leading_byte_offset=16 if b_k_fastest else b_n_byte_stride,
-      stride_byte_offset=128 * swizzle_pattern_rows,
+      leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride,
+      stride_byte_offset=swizzle_atom_bytes,
       swizzle=swizzle,
       memory_space=3,
   )
+  # The K strides are the byte offsets between consecutive positions whose
+  # coordinates are all zero except for K, which the instruction advances by
+  # its K width.
   # If an input is K-fastest, we increment the descriptor by 32 bytes, since
   # that is the K-width of all MMA instructions.
-  # TODO(apaszke): I don't have a good explanation for the MN-fastest case yet.
+  if b_k_fastest:
+    b_k_wgmma_stride = 32
+    b_k_group_stride = b_k_byte_stride  # The tile has only one K swizzle atom.
+  elif b_k_tiling == swizzle_elems:
+    # When B is N-fastest and we use the large square tiling, the relevant
+    # slices all fall within the first tile. A single MMA instruction for 16-bit
+    # types reads a subtile of shape 16x(swizzle bytes), giving us the necessary
+    # expression.
+    assert n_tiling == swizzle_elems or n_tiles == 1
+    b_k_wgmma_stride = swizzle * 16
+    b_k_group_stride = b_k_byte_stride
+  else:
+    # If we use the small non-square tiling and N-fastest layout, each tile only
+    # contains a single swizzle atom along the K dimension, so we just look up
+    # the next tile.
+    b_k_wgmma_stride = b_k_byte_stride
+    wgmma_in_group = swizzle // 32
+    b_k_group_stride = b_k_byte_stride * wgmma_in_group
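In short, the branching above picks how far the B descriptor advances per WGMMA step and per full K group. A condensed sketch of just that decision (hypothetical helper using the names from above; all values in bytes; not part of the patch):

def b_k_strides(b_k_fastest, b_k_tiling, swizzle_elems, swizzle, b_k_byte_stride):
  """Returns (b_k_wgmma_stride, b_k_group_stride) in bytes."""
  if b_k_fastest:
    return 32, b_k_byte_stride  # every WGMMA step reads 32 contiguous K bytes
  if b_k_tiling == swizzle_elems:
    return swizzle * 16, b_k_byte_stride  # large square tile: group = one tile
  # Small N-fastest tile: one swizzle atom per tile, swizzle // 32 WGMMAs per group.
  return b_k_byte_stride, b_k_byte_stride * (swizzle // 32)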
   wgmma_params = dict(
       a_transpose=not a_k_fastest,
       b_transpose=not b_k_fastest,
+      # TODO(apaszke): This explanation is quite bad. We should figure out a
+      # better way to do LHS transposes.
+      # We only support swizzle=128 for M-fastest A. In this case the tile is
+      # swizzle x 64 (= swizzle elems) and so we just take a quarter of its size.
       a_k_stride=32 if a_k_fastest else swizzle * 16,
-      b_k_stride=32 if b_k_fastest else swizzle * 16,
+      b_k_stride=b_k_wgmma_stride,
       swizzle=swizzle,
       element_type=ir.FloatTF32Type.get()
       if ir.F32Type.isinstance(element_type)
@@ -436,13 +498,13 @@ def _validate_mma(
   )
   if not a_in_smem:
     wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
-    a_k_byte_stride = a_desc_base = None
+    a_k_group_stride = a_desc_base = None
   else:
     a_desc_base = create_descriptor(
         a, **a_desc_fields, const_init=descriptor_const_init
     )
     a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-    a_k_byte_stride = a_strides[1] * element_bytewidth
+    a_k_group_stride = a_strides[1] * element_bytewidth
   b_desc_base = create_descriptor(
       b, **b_desc_fields, const_init=descriptor_const_init
   )
@@ -451,11 +513,11 @@ def _validate_mma(
       a_desc_base,
       b_desc_base,
       (m, k, n),
-      (m_tiling, kn_tiling),
+      (m_tiling, n_tiling),
       element_type,
       wgmma_params,
-      a_k_byte_stride,
-      b_k_byte_stride,
+      a_k_group_stride,
+      b_k_group_stride,
   )


@@ -488,11 +550,11 @@ def wgmma(
       a_desc_base,
       b_desc_base,
       (m, k, n),
-      (m_tiling, kn_tiling),
+      (m_tiling, _),
       element_type,
       wgmma_params,
-      a_k_byte_stride,
-      b_k_byte_stride,
+      a_k_group_stride,
+      b_k_group_stride,
   ) = _validate_mma(a, b, swizzle)

   if n > 256:
@@ -505,14 +567,15 @@ def wgmma(
     )
     if a.shape[0] % 64:
       raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
-    a_m_byte_stride = None
+    a_m_group_stride = None
   else:
     if m_tiling != 64:
       raise ValueError(f"A must have rows tiled by 64, got: {m_tiling}")
     a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-    a_m_byte_stride = a_strides[0] * bytewidth(element_type)
+    a_m_group_stride = a_strides[0] * bytewidth(element_type)

-  groups_k = k // kn_tiling
+  k_group_width = swizzle // bytewidth(element_type)
+  groups_k = k // k_group_width
   groups_m = m // 64

   expected_acc_shape = (groups_m * 64, n)
@@ -530,13 +593,16 @@ def wgmma(
   for mi in range(groups_m):
     for ki in range(groups_k):
       if a_in_regs:
-        a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tiling : (ki + 1) * kn_tiling]
+        a_mk = a[
+            mi * 64 : (mi + 1) * 64,
+            ki * k_group_width : (ki + 1) * k_group_width
+        ]
       else:
         a_mk = llvm_add(
             a_desc_base,
-            c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
+            c(wgmma_encode(mi * a_m_group_stride + ki * a_k_group_stride), i64),
         )
-      b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
+      b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64))
       new_acc_regs[mi : mi + 1] = wgmma_m64(
           new_acc_regs[mi : mi + 1], a_mk, b_k, n=n, **wgmma_params
       )
diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD
index 26efd2cd7..4b231939a 100644
--- a/tests/mosaic/BUILD
+++ b/tests/mosaic/BUILD
@@ -38,7 +38,7 @@ jax_multiplatform_test(
         "gpu_h100x2",
     ],
     env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
-    shard_count = 8,
+    shard_count = 16,
     tags = ["multiaccelerator"],
     deps = [
         "//jax:mosaic_gpu",
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 874cc1276..dcd390c5f 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -652,6 +652,7 @@ class WGMMATest(TestCase):
       k_steps=(1, 2),
       swizzle=(32, 64, 128),
       jax_out_dtype=(jnp.float16, jnp.float32),
+      small_rhs_tile=(False, True,),
   )
   def test_wgmma_basic(
       self,
@@ -663,6 +664,7 @@
       rhs_transpose,
       swizzle,
       jax_out_dtype,
+      small_rhs_tile,
   ):
     if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type:
       raise self.skipTest("Only f16 input is supported for f16 output.")
@@ -693,13 +695,18 @@ class WGMMATest(TestCase):
     k = nk_tile * k_steps
     assert m % 64 == 0 and n % nk_tile == 0

+    small_nk_tile = 8 if rhs_transpose else 16
+    rhs_tiling = (
+        (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
+    )
+
     def kernel(ctx, lhs, rhs, out, scratch):
       lhs_smem, rhs_smem, barriers = scratch
       lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
       if lhs_transpose:
         assert nk_tile == 64  # Make sure we didn't have to transpose tiling.
         lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-      rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
+      rhs_transform = (mgpu.TileTransform(rhs_tiling),)
       if rhs_transpose:
         rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
       ctx.async_copy(
@@ -737,10 +744,14 @@ class WGMMATest(TestCase):
     y_shape = (n, k) if rhs_transpose else (k, n)
     y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype)
+    rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
     scratch_shape = [
-        jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype),
         jax.ShapeDtypeStruct(
-            (k // nk_tile, n // nk_tile, nk_tile, nk_tile), in_jax_dtype
+            (m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype
+        ),
+        jax.ShapeDtypeStruct(
+            (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling),
+            in_jax_dtype,
         ),
         mgpu.TMABarrier(2),
     ]
@@ -812,18 +823,23 @@ class WGMMATest(TestCase):
   @parameterized.product(
       rhs_transpose=(False, True),
       swizzle=(32, 64, 128),
+      n=(8, 16),
+      small_rhs_tile=(False, True),
   )
-  def test_narrow_n(self, rhs_transpose, swizzle):
-    m, n, k_steps = 64, 8, 2
-
+  def test_narrow_n(self, rhs_transpose, swizzle, n, small_rhs_tile):
+    m, k_steps = 64, 2
     bytewidth = 2
     nk_tile = swizzle // bytewidth
     k = nk_tile * k_steps
+    if small_rhs_tile and not rhs_transpose:
+      self.skipTest("Small tiles only supported for transposed RHS")
+
+    n_tile = 8 if small_rhs_tile else nk_tile

     def kernel(ctx, rhs, out, smem):
       rhs_smem, barrier = smem
-      gmem_slice = (ds(0, k), ds(0, nk_tile))
-      transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
+      gmem_slice = (ds(0, k), ds(0, max(n_tile, n)))
+      transform = (mgpu.TileTransform((n_tile, nk_tile)),)
       if rhs_transpose:
         gmem_slice = gmem_slice[::-1]
         transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
@@ -840,8 +856,9 @@ class WGMMATest(TestCase):
       lhs_regs = iota_tensor(m, k, jnp.float16)
       if rhs_transpose:
         rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
-      smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
-      rhs_smem = memref_slice(rhs_smem, smem_slice)
+      if not small_rhs_tile:
+        smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
+        rhs_smem = memref_slice(rhs_smem, smem_slice)
       acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
       nvvm.wgmma_commit_group_sync_aligned()
       nvvm.wgmma_wait_group_sync_aligned(0)
@@ -852,7 +869,7 @@ class WGMMATest(TestCase):
     y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
     rhs_scratch_shape = jax.ShapeDtypeStruct(
-        (k_steps, 1, nk_tile, nk_tile), jax_dtype
+        (k_steps, (n + n_tile - 1) // n_tile, n_tile, nk_tile), jax_dtype
     )
     z = mgpu.as_gpu_kernel(
         kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()),
@@ -881,6 +898,7 @@ class TCGen05Test(TestCase):
       n=(64, 128, 256, 512),  # TODO(apaszke): 192, other non-power-of-2
       k_steps=(1, 2),
       swizzle=(32, 64, 128,),
+      small_rhs_tile=(False, True),
   )
   def test_mma_basic(
       self,
@@ -892,6 +910,7 @@ class TCGen05Test(TestCase):
       rhs_transpose,
       in_jax_dtype,
       out_jax_dtype,
+      small_rhs_tile,
   ):
     if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
       raise self.skipTest("Only f16 input is supported for f16 output.")
@@ -902,13 +921,16 @@ class TCGen05Test(TestCase):
     k = nk_tile * k_steps
     assert m % m_tile == 0 and n % nk_tile == 0

+    small_nk_tile = 8 if rhs_transpose else 16
+    rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
+
     def kernel(ctx, lhs, rhs, out, scratch):
       lhs_smem, rhs_smem, barriers, acc = scratch
       lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),)
       if lhs_transpose:
         assert nk_tile == m_tile  # Make sure we didn't have to transpose tiling
         lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-      rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
+      rhs_transform = (mgpu.TileTransform(rhs_tiling),)
       if rhs_transpose:
         rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
       ctx.async_copy(
@@ -950,9 +972,15 @@ class TCGen05Test(TestCase):
     y_shape = (n, k) if rhs_transpose else (k, n)
     y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
+    rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
     scratch_shape = [
-        jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
-        jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
+        jax.ShapeDtypeStruct(
+            tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype
+        ),
+        jax.ShapeDtypeStruct(
+            (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling),
+            in_jax_dtype,
+        ),
         mgpu.TMABarrier(3),
         mgpu.TMEM((128, n), out_jax_dtype),
     ]
@@ -973,6 +1001,7 @@
       n=(128, 256),  # TODO(apaszke): 512, 192, other non-power-of-2
       k_steps=(1, 2),
       swizzle=(32, 64, 128,),
+      small_rhs_tile=(False, True),
   )
   def test_mma_collective(
       self,
@@ -984,6 +1013,7 @@
       rhs_transpose,
       in_jax_dtype,
       out_jax_dtype,
+      small_rhs_tile,
   ):
     if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
       raise self.skipTest("Only f16 input is supported for f16 output.")
@@ -997,13 +1027,20 @@ class TCGen05Test(TestCase):
     assert m % m_tma_tile == 0 and n % nk_tma_tile == 0
     index = ir.IndexType.get()

+    small_nk_tile = 8 if rhs_transpose else 16
+    rhs_tiling = (
+        (small_nk_tile, nk_tma_tile)
+        if small_rhs_tile
+        else (nk_tma_tile, nk_tma_tile)
+    )
+
     def kernel(ctx, lhs, rhs, out, scratch):
       lhs_smem, rhs_smem, barriers, acc = scratch
       lhs_transform = (mgpu.TileTransform((m_tma_tile, nk_tma_tile)),)
       if lhs_transpose:
         assert nk_tma_tile == m_tma_tile  # Make sure we didn't have to transpose tiling
         lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-      rhs_transform = (mgpu.TileTransform((nk_tma_tile, nk_tma_tile)),)
+      rhs_transform = (mgpu.TileTransform(rhs_tiling),)
       if rhs_transpose:
         rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
       block_id = gpu.cluster_block_id(gpu.Dimension.x)
@@ -1053,9 +1090,16 @@ class TCGen05Test(TestCase):
     y_shape = (n, k) if rhs_transpose else (k, n)
     y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
+    rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
    scratch_shape = [
-        jax.ShapeDtypeStruct(tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)), in_jax_dtype),
-        jax.ShapeDtypeStruct(tile_shape((k, n_block_tile), (nk_tma_tile, nk_tma_tile)), in_jax_dtype),
+        jax.ShapeDtypeStruct(
+            tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)),
+            in_jax_dtype,
+        ),
+        jax.ShapeDtypeStruct(
+            (k // rhs_tiling_t[0], n_block_tile // rhs_tiling_t[1], *rhs_tiling),
+            in_jax_dtype,
+        ),
         mgpu.TMABarrier(3),
         mgpu.TMEM((128, n), out_jax_dtype, collective=True),
     ]
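All three tests derive the RHS scratch shape the same way. A small sketch of that computation (hypothetical helper mirroring the expressions used above; not part of the patch), together with an example instance:

def rhs_scratch_shape(k, n, rhs_tiling, rhs_transpose):
  """Tiled SMEM shape for the RHS operand, e.g. rhs_tiling=(16, 64) for f16."""
  # The GMEM array is (n, k) when transposed, so the tiling transposes with it.
  rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
  return (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling)

# With swizzle=128 f16 (nk_tile=64), a small RHS tile, and no transpose:
assert rhs_scratch_shape(128, 64, (16, 64), False) == (8, 1, 16, 64)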