[Mosaic GPU] Add support for small RHS tile sizes in WGMMA

This is useful for more fine-grained autotuning and can help avoid wave quantization effects.

PiperOrigin-RevId: 732105219
Parent commit: 1bc36e623b
This commit: bb96226dd8
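For reference, a minimal Python sketch (not part of the diff) of what the "small RHS tile" option changes, using the same arithmetic as test_wgmma_basic below and assuming f16 inputs with a 128-byte swizzle:

swizzle = 128                          # bytes
bytewidth = 2                          # f16
nk_tile = swizzle // bytewidth         # 64 elements per swizzle atom
rhs_transpose = True
for small_rhs_tile in (False, True):
    small_nk_tile = 8 if rhs_transpose else 16
    rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
    print(small_rhs_tile, rhs_tiling)  # False -> (64, 64), True -> (8, 64)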
@@ -106,6 +106,7 @@ def mma(
     raise ValueError(f"B must be a memref, got: {b.type}")
   if a_swizzle != b_swizzle:
     raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
+  swizzle = a_swizzle
   if isinstance(accumulate, bool):
     accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
   num_cta = 2 if collective else 1

@@ -114,11 +115,11 @@ def mma(
       a_desc_base,
       b_desc_base,
       (m, k, n),
-      (m_mem_tiling, kn_mem_tiling),
+      (m_mem_tiling, n_mem_tiling),
       element_type,
       mma_params,
-      a_k_byte_stride,
-      b_k_byte_stride,
+      a_k_group_stride,
+      b_k_group_stride,
   ) = _wgmma._validate_mma(
       a,
       b,

@@ -127,7 +128,7 @@ def mma(
   )

   # The sizes of instruction we'll be using
-  k_instr_tiling = kn_mem_tiling
+  k_instr_tiling = swizzle // utils.bytewidth(element_type)
   if (m_instr_tiling := d.layout.elements_in_tile[0]) != m_mem_tiling:
     raise ValueError(
         f"A's row tiling must be equal to {m_instr_tiling} (inferred from"

@@ -143,14 +144,14 @@ def mma(
     raise NotImplementedError("The only supported N larger than 256 is 512")

   a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-  a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)
+  a_m_group_stride = a_strides[0] * utils.bytewidth(element_type)
   b_strides, _ = ir.MemRefType(b.type).get_strides_and_offset()
-  b_n_byte_stride = b_strides[1] * utils.bytewidth(element_type)
+  b_n_group_stride = b_strides[1] * utils.bytewidth(element_type)

   groups_k = k // k_instr_tiling
   groups_m = m // m_instr_tiling
   groups_n = n // n_instr_tiling
-  n_mem_tiles_per_instr = n_instr_tiling // kn_mem_tiling
+  n_mem_tiles_per_instr = n_instr_tiling // n_mem_tiling

   # TODO(apaszke): Verify that the cluster shape matches the expectation of
   # collective MMA.

@@ -162,9 +163,9 @@ def mma(

   true = arith.constant(ir.IntegerType.get_signless(1), 1)
   for mi, ni, ki in np.ndindex(groups_m, groups_n, groups_k):
-    a_offset = mi * a_m_byte_stride + ki * a_k_byte_stride
+    a_offset = mi * a_m_group_stride + ki * a_k_group_stride
     a_mk = arith.addi(a_desc_base, utils.c(_wgmma.wgmma_encode(a_offset), i64))
-    b_offset = ni * n_mem_tiles_per_instr * b_n_byte_stride + ki * b_k_byte_stride
+    b_offset = ni * n_mem_tiles_per_instr * b_n_group_stride + ki * b_k_group_stride
     b_nk = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(b_offset), i64))
     if groups_m != 1:
       raise NotImplementedError("D needs to be sliced")
@@ -345,90 +345,152 @@ def _validate_mma(
   if element_type not in supported_types:
     raise ValueError(a_element_type)
   element_bytewidth = bytewidth(element_type)
-  kn_tiling = swizzle // element_bytewidth
+  swizzle_elems = swizzle // element_bytewidth

   # Verify the shape and strides of B are as expected.
-  k_tiles, n_tiles, k_tiling, n_tiling = b_ty.shape
-  if k_tiling != kn_tiling:
-    raise ValueError(b_ty.shape)
-  # Note that while this technically allows n to be smaller than kn_tile,
-  # the stride checks above will still enforce that the memory region is padded.
-  # It might be possible to relax that requirement, but I haven't tested it.
-  if n_tiling > kn_tiling and n_tiling % kn_tiling:
-    raise ValueError(n_tiling, kn_tiling)
-  k = k_tiles * kn_tiling
+  b_k_tiles, n_tiles, b_k_tiling, n_tiling = b_ty.shape
+  k = b_k_tiles * b_k_tiling
   n = n_tiles * n_tiling

   b_strides, _ = b_ty.get_strides_and_offset()
   b_byte_strides = [s * element_bytewidth for s in b_strides]
   b_k_byte_stride, b_n_byte_stride, *b_tile_byte_strides = b_byte_strides
-  # TODO(apaszke): Relax tiling here! But make sure that the space between N tiles is same.
-  if b_byte_strides[1] != swizzle * kn_tiling:
-    raise ValueError(b_byte_strides)
-  if b_tile_byte_strides == [swizzle, element_bytewidth]:
+  if (
+      b_byte_strides[1] != n_tiling * b_k_tiling * element_bytewidth
+      and n_tiles != 1  # When there's only one tile, we never jump between them
+  ):
+    raise ValueError("B tiles must be contiguous along the N dimension")
+  if b_tile_byte_strides == [swizzle, element_bytewidth]:  # N-fastest
     b_order = WGMMALayout.ROW_MAJOR
-  elif b_tile_byte_strides == [element_bytewidth, swizzle]:
+    # This first case (n_tiles == 1) is to allow the somewhat weird case of
+    # loading a small amount of N-fastest data, that needs to be padded to a
+    # larger tile due to swizzle. In this case we allow slicing the big tile
+    # before WGMMA to avoid unnecessary compute on padding.
+    if n_tiles == 1:
+      if n_tiling % 8:
+        raise ValueError("N tile size must be a multiple of 8")
+    elif n_tiling != swizzle_elems:
+      raise ValueError(
+          "Row major RHS (N-fastest) requires the N tile size to be equal to"
+          f" the swizzle tile size ({swizzle_elems}), but got {n_tiling}"
+      )
+    if b_k_tiling not in {32 // element_bytewidth, swizzle_elems}:
+      raise ValueError(
+          "Row major RHS (N-fastest) requires the K tile size to be either"
+          f" the swizzle tile size ({swizzle_elems}) or 32 bytes"
+          f" ({32 // element_bytewidth}), but got {b_k_tiling}"
+      )
+  elif b_tile_byte_strides == [element_bytewidth, swizzle]:  # K-fastest
     b_order = WGMMALayout.COL_MAJOR
+    if b_k_tiling != swizzle_elems:
+      raise ValueError(
+          "Column major RHS (K-fastest) requires the K tile size to be equal"
+          f" to the swizzle tile size ({swizzle_elems}), but got {b_k_tiling}"
+      )
+    # See the explanation in the N-fastest case when n_tiles == 1.
+    if n_tiles == 1:
+      if n_tiling % 8:
+        raise ValueError("N tile size must be a multiple of 8")
+    elif n_tiling not in {8, swizzle_elems}:
+      raise ValueError(
+          "Column major RHS (K-fastest) requires the N tile size to be either"
+          f" to the swizzle tile size ({swizzle_elems}) or 8, but got {n_tiling}"
+      )
   else:
     raise ValueError(b_byte_strides)

   # Verify the shape and strides of A are as expected.
   if not a_in_smem:
     m = a_shape[0]
-    a_order = m_tiling = a_m_byte_stride = None
+    a_order = m_tiling = None
   else:
     a_ty = ir.MemRefType(a.type)
-    m_tiles, k_tiles, m_tiling, k_tiling = a_ty.shape
+    m_tiles, a_k_tiles, m_tiling, a_k_tiling = a_ty.shape
     m = m_tiles * m_tiling
-    if k_tiling != kn_tiling or k_tiles * k_tiling != k:
+    if a_k_tiling != swizzle_elems or a_k_tiles * a_k_tiling != k:
       raise ValueError(a_ty.shape)
     a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
     a_byte_strides = [s * element_bytewidth for s in a_strides]
-    a_m_byte_stride = a_byte_strides[0]
     if a_byte_strides[2:] == [swizzle, element_bytewidth]:
       a_order = WGMMALayout.ROW_MAJOR
     elif a_byte_strides[2:] == [element_bytewidth, swizzle]:
       a_order = WGMMALayout.COL_MAJOR
     else:
       raise ValueError(a_byte_strides)
-    if a_order != WGMMALayout.ROW_MAJOR and m_tiling != kn_tiling:
+    if a_order != WGMMALayout.ROW_MAJOR and m_tiling != swizzle_elems:
       # Not sure what the layout is like, since the tiles aren't square.
       raise NotImplementedError

   b_k_fastest = b_order == WGMMALayout.COL_MAJOR
   a_k_fastest = a_order == WGMMALayout.ROW_MAJOR
-
-  # Here "leading" refers to the fastest changing dimension.
-  # Leading byte offset (LBO)
-  #   K-fastest: ignored
-  #   MN-fastest: stride between consecutive that share the same K coordinate.
-  # Stride byte offset (SBO)
-  #   K-fastest: offset from one swizzle atom to the next
-  #   MN-fastest: offset from one swizzle atom to the next
+  # This is the number of rows until consecutive repeats of the swizzle pattern.
   swizzle_pattern_rows = swizzle // 16
+  # A swizzle atom is a 2D matrix with the dimensions below.
+  swizzle_atom_bytes = swizzle_pattern_rows * 128
+  # Here "leading" refers to the fastest changing dimension. There are two
+  # strides we have to define per value:
+  # Leading byte offset (LBO)
+  #   K-fastest: ignored
+  #   MN-fastest: stride between consecutive swizzle atoms that share the same
+  #     K coordinate.
+  # Stride byte offset (SBO)
+  #   As far as I can tell this is just the offset between two consecutive
+  #   swizzle atoms along the non-leading dimension.
+  IGNORED = 0
   a_desc_fields = dict(
-      # TODO(apaszke): a_m_byte_stride works, but is not convincing to me.
-      # After all MMA can only consume a fixed number of bytes from LHS.
-      leading_byte_offset=16 if a_k_fastest else a_m_byte_stride,
-      stride_byte_offset=128 * swizzle_pattern_rows,
+      # I can't fully explain why WGMMA ignores LBO for A. For a_k_fastest, it
+      # is documented in the PTX docs, and my best explanation for the other
+      # case is that the instruction has a fixed shape and so it does not care
+      # about strides. It's possible that it's an artifact of the fact that we
+      # use tiling of 64.
+      leading_byte_offset=IGNORED,
+      stride_byte_offset=swizzle_atom_bytes,
       swizzle=swizzle,
       memory_space=3,
   )
+  # If B is N-fastest, all swizzle atoms within a tile share the same N
+  # coordinate, so we simply take the stride between consecutive N tiles.
+  # If B is K-fastest, all swizzle atoms within a tile share the same K
+  # coordinate, which forces us to lay out the tiles in N-fastest order or else
+  # they would have uneven strides.
   b_desc_fields = dict(
-      leading_byte_offset=16 if b_k_fastest else b_n_byte_stride,
-      stride_byte_offset=128 * swizzle_pattern_rows,
+      leading_byte_offset=IGNORED if b_k_fastest else b_n_byte_stride,
+      stride_byte_offset=swizzle_atom_bytes,
       swizzle=swizzle,
       memory_space=3,
   )
+  # The K strides indicate the stride between the consecutive places where all
+  # coordinates are 0 except for K being incremented by the instruction width.
+  # If an input is K-fastest, we increment the descriptor by 32 bytes, since
+  # that is the K-width of all MMA instructions.
+  # TODO(apaszke): I don't have a good explanation for the MN-fastest case yet.
+  if b_k_fastest:
+    b_k_wgmma_stride = 32
+    b_k_group_stride = b_k_byte_stride  # The tile has only one K swizzle atom.
+  elif b_k_tiling == swizzle_elems:
+    # When B is N-fastest and we use the large square tiling, the relevant
+    # slices all fall within the first tile. A single MMA instruction for 16-bit
+    # types reads a subtile of shape 16x(swizzle bytes), giving us the necessary
+    # expression.
+    assert n_tiling == swizzle_elems or n_tiles == 1
+    b_k_wgmma_stride = swizzle * 16
+    b_k_group_stride = b_k_byte_stride
+  else:
+    # If we use the small non-square tiling and N-fastest layout, each tile only
+    # contains a single swizzle atom with the K coordinate, so we just look up
+    # the next tile.
+    b_k_wgmma_stride = b_k_byte_stride
+    wgmma_in_group = swizzle // 32
+    b_k_group_stride = b_k_byte_stride * wgmma_in_group
   wgmma_params = dict(
       a_transpose=not a_k_fastest,
       b_transpose=not b_k_fastest,
       # TODO(apaszke): This explanation is quite bad. We should better figure
       # out how to do LHS transposes.
       # We only support swizzle=128 for M-fastest A. In this case the tile is
       # swizzle x 64 (= swizzle elems) and so we just take a quarter of its size.
       a_k_stride=32 if a_k_fastest else swizzle * 16,
-      b_k_stride=32 if b_k_fastest else swizzle * 16,
+      b_k_stride=b_k_wgmma_stride,
       swizzle=swizzle,
       element_type=ir.FloatTF32Type.get()
       if ir.F32Type.isinstance(element_type)
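As a side note (illustrative only, not part of the diff), the descriptor quantities introduced above all derive from the swizzle. For the swizzle values the tests exercise, assuming 2-byte (f16) elements:

element_bytewidth = 2                                  # f16
for swizzle in (32, 64, 128):
    swizzle_elems = swizzle // element_bytewidth       # K extent of one swizzle atom
    swizzle_pattern_rows = swizzle // 16               # rows before the swizzle pattern repeats
    swizzle_atom_bytes = swizzle_pattern_rows * 128    # SBO used for both descriptors
    print(swizzle, swizzle_elems, swizzle_pattern_rows, swizzle_atom_bytes)
# swizzle=32  -> 16 elems, 2 rows,  256 bytes
# swizzle=64  -> 32 elems, 4 rows,  512 bytes
# swizzle=128 -> 64 elems, 8 rows, 1024 bytes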
@@ -436,13 +498,13 @@ def _validate_mma(
   )
   if not a_in_smem:
     wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
-    a_k_byte_stride = a_desc_base = None
+    a_k_group_stride = a_desc_base = None
   else:
     a_desc_base = create_descriptor(
         a, **a_desc_fields, const_init=descriptor_const_init
     )
     a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-    a_k_byte_stride = a_strides[1] * element_bytewidth
+    a_k_group_stride = a_strides[1] * element_bytewidth
   b_desc_base = create_descriptor(
       b, **b_desc_fields, const_init=descriptor_const_init
   )

@@ -451,11 +513,11 @@ def _validate_mma(
       a_desc_base,
       b_desc_base,
       (m, k, n),
-      (m_tiling, kn_tiling),
+      (m_tiling, n_tiling),
       element_type,
       wgmma_params,
-      a_k_byte_stride,
-      b_k_byte_stride,
+      a_k_group_stride,
+      b_k_group_stride,
   )


@@ -488,11 +550,11 @@ def wgmma(
       a_desc_base,
       b_desc_base,
       (m, k, n),
-      (m_tiling, kn_tiling),
+      (m_tiling, _),
       element_type,
       wgmma_params,
-      a_k_byte_stride,
-      b_k_byte_stride,
+      a_k_group_stride,
+      b_k_group_stride,
   ) = _validate_mma(a, b, swizzle)

   if n > 256:

@@ -505,14 +567,15 @@ def wgmma(
       )
     if a.shape[0] % 64:
       raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
-    a_m_byte_stride = None
+    a_m_group_stride = None
   else:
     if m_tiling != 64:
       raise ValueError(f"A must have rows tiled by 64, got: {m_tiling}")
     a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
-    a_m_byte_stride = a_strides[0] * bytewidth(element_type)
+    a_m_group_stride = a_strides[0] * bytewidth(element_type)

-  groups_k = k // kn_tiling
+  k_group_width = swizzle // bytewidth(element_type)
+  groups_k = k // k_group_width
   groups_m = m // 64

   expected_acc_shape = (groups_m * 64, n)

@@ -530,13 +593,16 @@ def wgmma(
   for mi in range(groups_m):
     for ki in range(groups_k):
       if a_in_regs:
-        a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tiling : (ki + 1) * kn_tiling]
+        a_mk = a[
+            mi * 64 : (mi + 1) * 64,
+            ki * k_group_width : (ki + 1) * k_group_width
+        ]
       else:
         a_mk = llvm_add(
             a_desc_base,
-            c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
+            c(wgmma_encode(mi * a_m_group_stride + ki * a_k_group_stride), i64),
         )
-      b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
+      b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_group_stride), i64))
       new_acc_regs[mi : mi + 1] = wgmma_m64(
           new_acc_regs[mi : mi + 1], a_mk, b_k, n=n, **wgmma_params
       )
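For concreteness (a plain-Python sketch, not part of the diff), the new k_group_width drives the loop bounds in wgmma() above. Assuming f16 elements, swizzle=128, m=128 and k=128:

swizzle, elem_bytes = 128, 2
m, k = 128, 128
k_group_width = swizzle // elem_bytes      # 64
groups_k = k // k_group_width              # 2
groups_m = m // 64                         # 2
for mi in range(groups_m):
    for ki in range(groups_k):
        # Slice of the register-resident LHS consumed by one WGMMA group.
        rows = (mi * 64, (mi + 1) * 64)
        cols = (ki * k_group_width, (ki + 1) * k_group_width)
        print(rows, cols)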
@@ -38,7 +38,7 @@ jax_multiplatform_test(
         "gpu_h100x2",
     ],
     env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
-    shard_count = 8,
+    shard_count = 16,
     tags = ["multiaccelerator"],
     deps = [
         "//jax:mosaic_gpu",
@@ -652,6 +652,7 @@ class WGMMATest(TestCase):
       k_steps=(1, 2),
       swizzle=(32, 64, 128),
       jax_out_dtype=(jnp.float16, jnp.float32),
+      small_rhs_tile=(False, True,),
   )
   def test_wgmma_basic(
       self,

@@ -663,6 +664,7 @@ class WGMMATest(TestCase):
       rhs_transpose,
       swizzle,
       jax_out_dtype,
+      small_rhs_tile,
   ):
     if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type:
       raise self.skipTest("Only f16 input is supported for f16 output.")

@@ -693,13 +695,18 @@ class WGMMATest(TestCase):
     k = nk_tile * k_steps
     assert m % 64 == 0 and n % nk_tile == 0

+    small_nk_tile = 8 if rhs_transpose else 16
+    rhs_tiling = (
+        (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
+    )
+
     def kernel(ctx, lhs, rhs, out, scratch):
       lhs_smem, rhs_smem, barriers = scratch
       lhs_transform = (mgpu.TileTransform((64, nk_tile)),)
       if lhs_transpose:
         assert nk_tile == 64  # Make sure we didn't have to transpose tiling.
         lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-      rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
+      rhs_transform = (mgpu.TileTransform(rhs_tiling),)
       if rhs_transpose:
         rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
       ctx.async_copy(

@@ -737,10 +744,14 @@ class WGMMATest(TestCase):
     y_shape = (n, k) if rhs_transpose else (k, n)
     y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype)
+    rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
     scratch_shape = [
-        jax.ShapeDtypeStruct((m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype),
         jax.ShapeDtypeStruct(
-            (k // nk_tile, n // nk_tile, nk_tile, nk_tile), in_jax_dtype
+            (m // 64, k // nk_tile, 64, nk_tile), in_jax_dtype
         ),
+        jax.ShapeDtypeStruct(
+            (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling),
+            in_jax_dtype,
+        ),
         mgpu.TMABarrier(2),
     ]
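Worked through for one parameter combination (illustrative only): with swizzle=128 (so nk_tile=64), k_steps=2 (so k=128), n=128, rhs_transpose=True and small_rhs_tile=True, the new RHS scratch shape in test_wgmma_basic above comes out as:

nk_tile, k, n = 64, 128, 128
rhs_transpose, small_rhs_tile = True, True
small_nk_tile = 8 if rhs_transpose else 16
rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)  # (8, 64)
rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling                 # (64, 8)
scratch = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling)
print(scratch)  # (2, 16, 8, 64)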
@@ -812,18 +823,23 @@ class WGMMATest(TestCase):
   @parameterized.product(
       rhs_transpose=(False, True),
       swizzle=(32, 64, 128),
+      n=(8, 16),
+      small_rhs_tile=(False, True),
   )
-  def test_narrow_n(self, rhs_transpose, swizzle):
-    m, n, k_steps = 64, 8, 2
-
+  def test_narrow_n(self, rhs_transpose, swizzle, n, small_rhs_tile):
+    m, k_steps = 64, 2
     bytewidth = 2
     nk_tile = swizzle // bytewidth
     k = nk_tile * k_steps
+    if small_rhs_tile and not rhs_transpose:
+      self.skipTest("Small tiles only supported for transposed RHS")
+
+    n_tile = 8 if small_rhs_tile else nk_tile

     def kernel(ctx, rhs, out, smem):
       rhs_smem, barrier = smem
-      gmem_slice = (ds(0, k), ds(0, nk_tile))
-      transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
+      gmem_slice = (ds(0, k), ds(0, max(n_tile, n)))
+      transform = (mgpu.TileTransform((n_tile, nk_tile)),)
       if rhs_transpose:
         gmem_slice = gmem_slice[::-1]
         transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)

@@ -840,8 +856,9 @@ class WGMMATest(TestCase):
       lhs_regs = iota_tensor(m, k, jnp.float16)
       if rhs_transpose:
         rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
-      smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
-      rhs_smem = memref_slice(rhs_smem, smem_slice)
+      if not small_rhs_tile:
+        smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
+        rhs_smem = memref_slice(rhs_smem, smem_slice)
       acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
       nvvm.wgmma_commit_group_sync_aligned()
       nvvm.wgmma_wait_group_sync_aligned(0)

@@ -852,7 +869,7 @@ class WGMMATest(TestCase):
     y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
     rhs_scratch_shape = jax.ShapeDtypeStruct(
-        (k_steps, 1, nk_tile, nk_tile), jax_dtype
+        (k_steps, (n + n_tile - 1) // n_tile, n_tile, nk_tile), jax_dtype
     )
     z = mgpu.as_gpu_kernel(
         kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()),
@@ -881,6 +898,7 @@ class TCGen05Test(TestCase):
       n=(64, 128, 256, 512),  # TODO(apaszke): 192, other non-power-of-2
       k_steps=(1, 2),
       swizzle=(32, 64, 128,),
+      small_rhs_tile=(False, True),
   )
   def test_mma_basic(
       self,

@@ -892,6 +910,7 @@ class TCGen05Test(TestCase):
       rhs_transpose,
       in_jax_dtype,
       out_jax_dtype,
+      small_rhs_tile,
   ):
     if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
       raise self.skipTest("Only f16 input is supported for f16 output.")

@@ -902,13 +921,16 @@ class TCGen05Test(TestCase):
     k = nk_tile * k_steps
     assert m % m_tile == 0 and n % nk_tile == 0

+    small_nk_tile = 8 if rhs_transpose else 16
+    rhs_tiling = (small_nk_tile, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
+
     def kernel(ctx, lhs, rhs, out, scratch):
       lhs_smem, rhs_smem, barriers, acc = scratch
       lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),)
       if lhs_transpose:
         assert nk_tile == m_tile  # Make sure we didn't have to transpose tiling
         lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-      rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),)
+      rhs_transform = (mgpu.TileTransform(rhs_tiling),)
       if rhs_transpose:
         rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
       ctx.async_copy(

@@ -950,9 +972,15 @@ class TCGen05Test(TestCase):
     y_shape = (n, k) if rhs_transpose else (k, n)
     y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
+    rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
     scratch_shape = [
-        jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
-        jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
+        jax.ShapeDtypeStruct(
+            tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype
+        ),
+        jax.ShapeDtypeStruct(
+            (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling),
+            in_jax_dtype,
+        ),
         mgpu.TMABarrier(3),
         mgpu.TMEM((128, n), out_jax_dtype),
     ]

@@ -973,6 +1001,7 @@ class TCGen05Test(TestCase):
       n=(128, 256),  # TODO(apaszke): 512, 192, other non-power-of-2
       k_steps=(1, 2),
       swizzle=(32, 64, 128,),
+      small_rhs_tile=(False, True),
   )
   def test_mma_collective(
       self,

@@ -984,6 +1013,7 @@ class TCGen05Test(TestCase):
       rhs_transpose,
       in_jax_dtype,
       out_jax_dtype,
+      small_rhs_tile,
   ):
     if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
       raise self.skipTest("Only f16 input is supported for f16 output.")

@@ -997,13 +1027,20 @@ class TCGen05Test(TestCase):
     assert m % m_tma_tile == 0 and n % nk_tma_tile == 0
     index = ir.IndexType.get()

+    small_nk_tile = 8 if rhs_transpose else 16
+    rhs_tiling = (
+        (small_nk_tile, nk_tma_tile)
+        if small_rhs_tile
+        else (nk_tma_tile, nk_tma_tile)
+    )
+
     def kernel(ctx, lhs, rhs, out, scratch):
       lhs_smem, rhs_smem, barriers, acc = scratch
       lhs_transform = (mgpu.TileTransform((m_tma_tile, nk_tma_tile)),)
       if lhs_transpose:
         assert nk_tma_tile == m_tma_tile  # Make sure we didn't have to transpose tiling
         lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
-      rhs_transform = (mgpu.TileTransform((nk_tma_tile, nk_tma_tile)),)
+      rhs_transform = (mgpu.TileTransform(rhs_tiling),)
       if rhs_transpose:
         rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
       block_id = gpu.cluster_block_id(gpu.Dimension.x)

@@ -1053,9 +1090,16 @@ class TCGen05Test(TestCase):
     y_shape = (n, k) if rhs_transpose else (k, n)
     y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
     out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
+    rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
     scratch_shape = [
-        jax.ShapeDtypeStruct(tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)), in_jax_dtype),
-        jax.ShapeDtypeStruct(tile_shape((k, n_block_tile), (nk_tma_tile, nk_tma_tile)), in_jax_dtype),
+        jax.ShapeDtypeStruct(
+            tile_shape((m_block_tile, k), (m_tma_tile, nk_tma_tile)),
+            in_jax_dtype,
+        ),
+        jax.ShapeDtypeStruct(
+            (k // rhs_tiling_t[0], n_block_tile // rhs_tiling_t[1], *rhs_tiling),
+            in_jax_dtype,
+        ),
         mgpu.TMABarrier(3),
         mgpu.TMEM((128, n), out_jax_dtype, collective=True),
     ]