[Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts

PiperOrigin-RevId: 737956598
Adam Paszke 2025-03-18 04:47:04 -07:00 committed by jax authors
parent 38d52a19ef
commit d4bd2570ae
2 changed files with 190 additions and 77 deletions

View File

@@ -526,6 +526,18 @@ WGMMA_LAYOUT_UPCAST_2X = TiledLayout(
    lane_dims=(-4, -2, -3),
    vector_dim=-1,
)
# This layout should be used when upcasting 4-bit elements to 16-bit, for the
# purpose of passing them into WGMMA later. The core matrices stored by a warp
# are 8x32, because each of the 4 threads in a row holds 8 elements in a
# single vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to
# each group of 4 threads in order (as opposed to the swapping of columns 1
# and 2, 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
    Tiling(((64, 32), (16, 32), (8, 32), (8,))),
    warp_dim=-7,
    lane_dims=(-3, -2),
    vector_dim=-1,
)
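# Editorial sketch (not part of the commit): within one 8x32 core matrix of
# WGMMA_LAYOUT_UPCAST_4X, row r is covered by lanes 4*r .. 4*r+3, and the lane
# at position p within that group holds the 8 consecutive 4-bit elements in
# columns 8*p .. 8*p+7, i.e. exactly one 32-bit register. A hypothetical
# helper spelling out that column assignment:
#
#   def upcast_4x_lane_columns(lane: int) -> list[int]:
#     p = lane % 4  # position of the lane within its group of 4
#     return [8 * p + i for i in range(8)]
#
#   upcast_4x_lane_columns(1) == [8, 9, 10, 11, 12, 13, 14, 15]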
# This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores an 8x8
# submatrix in the following way (we only show the first 4 rows for brevity):
#
@@ -739,58 +751,132 @@ class FragmentedArray:
          _layout=new_layout,
          _is_signed=self.is_signed,
      )
    if (
        self.layout == WGMMA_LAYOUT_UPCAST_2X
        and new_layout == WGMMA_LAYOUT
        and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
    ):
      assert shape[1] % 16 == 0  # Should be implied by the layout
      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
      is_even = arith.cmpi(
          arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
      )
      registers = self.registers
      if dtype_bitwidth == 4:
        if registers.shape[1] % 2:
          raise NotImplementedError(
              "This relayout implementation requires an even number of column"
              " tiles (to pack pairs of them for efficiency)"
          )
        # We pair up the consecutive column tiles, so each register is 32-bit.
        # If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
        # LLVM will realize that the paired-up vectors actually came from the
        # same 32-bit register and the concatenation will become a no-op.
        col_minor_registers = np.moveaxis(registers, 1, -1)
        flat_registers = [
            utils.vector_concat((l, h))
            for l, h in zip(
                col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
            )
        ]
        registers = np.asarray(flat_registers, dtype=object).reshape(
            *col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
        )
        registers = np.moveaxis(registers, -1, 1)
      for idx, reg in np.ndenumerate(registers):
        if dtype_bitwidth == 16:
          assert reg.type.shape == [4]
          # A single vector is 64 bits, but shuffles are only 32 bits wide.
          # We only shuffle the half that needs to go to the other thread.
          low = utils.vector_slice(reg, slice(0, 2))
          high = utils.vector_slice(reg, slice(2, 4))
          to_exchange = arith.select(is_even, high, low)
          # Exchange values between even and odd threads.
          exchanged = utils.shfl_bfly(to_exchange, 1)
          low = arith.select(is_even, low, exchanged)
          high = arith.select(is_even, exchanged, high)
          new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
          new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
        elif dtype_bitwidth == 8:
          assert reg.type.shape == [4]
          # The vector is 32 bits, so we just shuffle the whole thing and
          # use prmt to blend it with the local register.
          exchanged = utils.shfl_bfly(reg, 1)
          # Consider lanes 0 and 1, because the situation is symmetric for
          # each pair. If we feed reg[lane] and exchanged[lane] (which is
          # really the same as reg of the other lane) to prmt, we can index
          # the elements of the result using the following indices:
          # reg[0]:  0 1 2 3    reg[1]:  8 9 10 11
          # prmt[0]: 0 1 2 3  4 5 6 7
          # prmt[1]: 4 5 6 7  0 1 2 3
          # The expected outputs and their respective permutations are:
          # out[0]:  0 1 8 9    out[1]:  2 3 10 11
          # prmt[0]: 0 1 4 5    prmt[1]: 6 7 2 3
          # Note that the patterns still need to be flipped, since we listed
          # bytes with LSB on the left, which is the opposite of how the
          # numeric constants are spelled in Python (LSB on the right).
          perm = arith.select(is_even, c(0x5410), c(0x3276))
          blend = utils.prmt(reg, exchanged, perm)
          for i in range(2):
            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
            new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
        else:
          assert dtype_bitwidth == 4
          assert reg.type.shape == [8]  # We paired up the registers above.
          exchanged = utils.shfl_bfly(reg, 1)
          # See the comment above for a more complete explanation.
          # reg[0]:  0 1 2 3 16 17 18 19    reg[1]: 8 9 10 11 24 25 26 27
          # prmt[0]: -0- -1- --2-- --3--  -4- --5-- --6-- --7--
          # prmt[1]: -4- -5- --6-- --7--  -0- --1-- --2-- --3--
          # The expected outputs and their respective permutations are:
          # out[0]:  0 1 8 9 16 17 24 25    out[1]: 2 3 10 11 18 19 26 27
          # prmt[0]: -0- -4- --2-- --6--    prmt[1]: -5- --1-- --7-- --3--
          perm = arith.select(is_even, c(0x6240), c(0x3715))
          blend = utils.prmt(reg, exchanged, perm)
          for i in range(4):
            reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
            new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
      assert all(r is not None for r in new_registers)
      return FragmentedArray(
          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
      )
    if (
        self.layout == WGMMA_LAYOUT_UPCAST_4X
        and new_layout == WGMMA_LAYOUT_UPCAST_2X
        and utils.bitwidth(self.mlir_dtype) == 4
    ):
      assert shape[0] % 64 == 0  # Should be implied by the layout
      assert shape[1] % 32 == 0  # Should be implied by the layout
      new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
      i32 = ir.IntegerType.get_signless(32)
      c = lambda x: arith.constant(i32, x)
      is_01 = arith.cmpi(
          arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
      )
      for idx, reg in np.ndenumerate(self.registers):
        assert ir.VectorType(reg.type).shape == [8]
        # The vector is 32 bits, so we just shuffle the whole thing and
        # use prmt to blend it with the local register.
        exchanged = utils.shfl_bfly(reg, 2)
        # See the comments above for conventions. Here we exchange data between
        # threads whose lane indices differ in the 2nd bit (e.g. 0 and 2).
        # reg[0]:  0 1 2 3 4 5 6 7    reg[2]: 16 17 18 19 20 21 22 23
        # prmt[0]: -0- -1- -2- -3-  --4-- --5-- --6-- --7--
        # prmt[1]: -4- -5- -6- -7-  --0-- --1-- --2-- --3--
        # The expected outputs and their respective permutations are:
        # out[0]:  0 1 2 3 16 17 18 19    out[2]: 4 5 6 7 20 21 22 23
        # prmt[0]: -0- -1- --4-- --5--    prmt[2]: -6- -7- --2-- --3--
        perm = arith.select(is_01, c(0x5410), c(0x3276))
        blend = utils.prmt(reg, exchanged, perm)
        for i in range(2):
          reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
          new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
      assert all(r is not None for r in new_registers)
      return FragmentedArray(
          _registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
      )
    if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
      return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
    if not isinstance(self.layout, WGSplatFragLayout):
      raise NotImplementedError(
          f"Cannot convert from {self.layout} to {new_layout}"
@@ -1288,7 +1374,9 @@ class FragmentedArray:
        int_ty = ir.IntegerType.get_signless(group_size * 4)
        while vector_len - offset >= group_size:
          reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
          reg_slice_int = utils.bitcast(reg_slice, int_ty)
          if int_ty != i32:
            reg_slice_int = arith.extsi(i32, reg_slice_int)
          reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
          out_int_regs.extend(
              upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
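
The shrui by 4 above is what lets a single upcast routine serve both nibble
positions: after the shift, the odd 4-bit elements sit at the same bit offsets
the even ones occupied. A small plain-Python sketch of that packing arithmetic
(editorial, not from the commit):

    packed = 0
    for i, v in enumerate([1, 2, 3, 4, 5, 6, 7, 0]):  # 8 int4 values, LSB-first
      packed |= (v & 0xF) << (4 * i)
    shifted = packed >> 4
    evens = [(packed >> (8 * i)) & 0xF for i in range(4)]   # elements 0, 2, 4, 6
    odds = [(shifted >> (8 * i)) & 0xF for i in range(4)]   # elements 1, 3, 5, 7
    assert evens == [1, 3, 5, 7] and odds == [2, 4, 6, 0]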

View File

@@ -14,6 +14,7 @@
# ==============================================================================
from collections.abc import Sequence
import contextlib
import dataclasses
import enum
import itertools
@@ -83,6 +84,20 @@ def mlir_sum(elems):
  return total


@contextlib.contextmanager
def get_sass():
  prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
  os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
  try:
    with jtu.capture_stdout() as output:
      yield output
  finally:
    if prev_dump is not None:
      os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
    else:
      del os.environ["MOSAIC_GPU_DUMP_SASS"]
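
# Editorial usage sketch (not part of the commit): this is the pattern the
# tests below follow. SASS dumping is enabled via MOSAIC_GPU_DUMP_SASS, and the
# dump is read back from the captured stdout once the block exits:
#
#   with get_sass() as sass:
#     f(x)  # compile and run some previously defined kernel
#   assert "SHFL.BFLY" in sass()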

def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
  index = ir.IndexType.get()
  thread_id = gpu.thread_id(gpu.Dimension.x)
@@ -542,7 +557,11 @@ class WGMMALayoutTest(TestCase):
          (jnp.int8, jnp.bfloat16),
          (jnp.int4, jnp.bfloat16),
      ),
      layout=(
          fa.WGMMA_LAYOUT,
          fa.WGMMA_LAYOUT_UPCAST_2X,
          fa.WGMMA_LAYOUT_UPCAST_4X,
      ),
  )
  def test_optimized_conversion(self, jax_dtype_from_to, layout):
    jax_dtype_from, jax_dtype_to = jax_dtype_from_to
@@ -2194,19 +2213,11 @@ class LayoutTest(TestCase):
        .transpose(0, 2, 1, 3)
    )
    with get_sass() as sass:
      iota = mgpu.as_gpu_kernel(
          kernel, (1, 1, 1), (128, 1, 1), expected, expected,
          [expected, expected, mgpu.TMABarrier()],
      )(expected)
    np.testing.assert_array_equal(iota, expected)
    # Verify that we don't use too many registers for the transfers.
@@ -2219,7 +2230,7 @@ class LayoutTest(TestCase):
      expected_regs //= 2
    for instr in ("STS", "LDS"):
      with self.subTest(instr + " count"):
        addrs = re.findall(instr + r".* \[(.*)\]", sass())
        def get_reg(addr):
          if (pos := addr.find("+")) != -1:
            return addr[:pos]
@@ -2294,30 +2305,38 @@ class LayoutTest(TestCase):
    )(x)
    np.testing.assert_array_equal(y, y_ref)
  @parameterized.parameters(
      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1),
      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1),
      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1),
      (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5),
      (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
  )
  def test_upcast_to_wgmma(
      self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
  ):
    in_dtype = jnp.dtype(in_dtype)
    out_dtype = jnp.dtype(jnp.int16)
    out_dtype_mlir = utils.dtype_to_ir_type(out_dtype)
    swizzle = 128
    in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits
    in_tiling = (8, in_col_tiling)
    out_col_tiling = swizzle // out_dtype.itemsize
    out_tiling = (8, out_col_tiling)
    m, n = 128, in_col_tiling * 2
    regs_per_thread = None

    def kernel(ctx, in_, out, smems):
      nonlocal regs_per_thread
      smem_in, smem_out, barrier = smems
      ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
      barrier.wait()
      t = mgpu.FragmentedArray.load_tiled(
          smem_in, swizzle=swizzle, is_signed=True, layout=start_layout
      )
      regs_per_thread = t.registers.size
      t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True)
      t = t.to_layout(end_layout)
      t = t.astype(out_dtype_mlir, is_signed=True)
      t.store_tiled(smem_out, swizzle=swizzle)
      mgpu.commit_shared()
      ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
@@ -2326,14 +2345,20 @@ class LayoutTest(TestCase):
      return x.reshape(
          x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1]
      ).transpose(0, 2, 1, 3)
    in_iinfo = jnp.iinfo(in_dtype)
    x = jax.random.randint(
        jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32
    ).astype(in_dtype)
    xt = tile(x, in_tiling)
    y = x.astype(out_dtype)
    yt = tile(y, out_tiling)
    f = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()],
    )
    with get_sass() as sass:
      yt_kernel = f(xt)
    np.testing.assert_array_equal(yt_kernel, yt)
    self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
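
# Editorial sanity check (not part of the test): the expected SHFL.BFLY count
# for the int4 WGMMA_LAYOUT_UPCAST_2X -> WGMMA_LAYOUT case, i.e. the
# shfl_per_reg == 0.5 entry in the parameters above.
#
#   threads, m, swizzle, bits = 128, 128, 128, 4
#   n = 2 * (8 * swizzle // bits)             # two column tiles -> 512 columns
#   elems_per_thread = m * n // threads       # 512 elements per thread
#   regs_per_thread = elems_per_thread // 4   # 128 vectors of 4 x int4
#   shfls = regs_per_thread // 2              # registers get paired up first
#   shfls / regs_per_thread == 0.5            # matches shfl_per_reg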
@dataclasses.dataclass(frozen=True)