mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts
PiperOrigin-RevId: 737956598
This commit is contained in:
parent
38d52a19ef
commit
d4bd2570ae
@ -526,6 +526,18 @@ WGMMA_LAYOUT_UPCAST_2X = TiledLayout(
|
||||
lane_dims=(-4, -2, -3),
|
||||
vector_dim=-1,
|
||||
)
|
||||
# This layout should be used when upcasting 4-bit elements to 16-bit, for the
|
||||
# purpose of passing them into WGMMA later. The core matrices stored by a warp
|
||||
# are 8x32, because each of the 4 threads in a row holds 8 elements in a single
|
||||
# vector. Note that unlike WGMMA_LAYOUT_UPCAST_2X, we assign columns to each
|
||||
# group of 4 threads in order (as opposed to the swapping between 1 and 2,
|
||||
# 5 and 6, etc. that WGMMA_LAYOUT_UPCAST_2X does).
|
||||
WGMMA_LAYOUT_UPCAST_4X = TiledLayout(
|
||||
Tiling(((64, 32), (16, 32), (8, 32), (8,))),
|
||||
warp_dim=-7,
|
||||
lane_dims=(-3, -2),
|
||||
vector_dim=-1,
|
||||
)
|
||||
# This tiled layout is similar to WGMMA_LAYOUT. There, each warp stores a 8x8
|
||||
# submatrix in the following way (we only show the first 4 rows for brevity):
|
||||
#
|
||||
@ -739,58 +751,132 @@ class FragmentedArray:
|
||||
_layout=new_layout,
|
||||
_is_signed=self.is_signed,
|
||||
)
|
||||
if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 16 == 0:
|
||||
if (
|
||||
self.layout == WGMMA_LAYOUT_UPCAST_2X
|
||||
and new_layout == WGMMA_LAYOUT
|
||||
and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) in {8, 16}
|
||||
):
|
||||
assert shape[1] % 16 == 0 # Should be implied by the layout
|
||||
new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
|
||||
is_even = arith.cmpi(
|
||||
arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
|
||||
if (
|
||||
self.layout == WGMMA_LAYOUT_UPCAST_2X
|
||||
and new_layout == WGMMA_LAYOUT
|
||||
and (dtype_bitwidth := utils.bitwidth(self.mlir_dtype)) <= 16
|
||||
):
|
||||
assert shape[1] % 16 == 0 # Should be implied by the layout
|
||||
new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
|
||||
is_even = arith.cmpi(
|
||||
arith.CmpIPredicate.eq, arith.remui(utils.thread_idx(), c(2)), c(0)
|
||||
)
|
||||
registers = self.registers
|
||||
if dtype_bitwidth == 4:
|
||||
if registers.shape[1] % 2:
|
||||
raise NotImplementedError(
|
||||
"This relayout implementation requires an even number of column"
|
||||
" tiles (to pack pairs of them for efficiency)"
|
||||
)
|
||||
# We pair up the consecutive column tiles, so each register is 32-bit.
|
||||
# If this layout originated from a WGMMA_LAYOUT_UPCAST_4X layout,
|
||||
# LLVM will realize that the paired up vectors actually came from the
|
||||
# same 32-bit register and it will become a no-op.
|
||||
col_minor_registers = np.moveaxis(registers, 1, -1)
|
||||
flat_registers = [
|
||||
utils.vector_concat((l, h))
|
||||
for l, h in zip(
|
||||
col_minor_registers.flat[::2], col_minor_registers.flat[1::2]
|
||||
)
|
||||
]
|
||||
registers = np.asarray(flat_registers, dtype=object).reshape(
|
||||
*col_minor_registers.shape[:-1], col_minor_registers.shape[-1] // 2
|
||||
)
|
||||
for idx, reg in np.ndenumerate(self.registers):
|
||||
assert ir.VectorType(reg.type).shape == [4]
|
||||
if dtype_bitwidth == 16:
|
||||
# A single vector is 64-bits, but shuffles are only 32-bit wide.
|
||||
# We only shuffle the half that needs to go to other thread.
|
||||
low = utils.vector_slice(reg, slice(0, 2))
|
||||
high = utils.vector_slice(reg, slice(2, 4))
|
||||
to_exchange = arith.select(is_even, high, low)
|
||||
# Exchange values between even and odd threads.
|
||||
exchanged = utils.shfl_bfly(to_exchange, 1)
|
||||
low = arith.select(is_even, low, exchanged)
|
||||
high = arith.select(is_even, exchanged, high)
|
||||
elif dtype_bitwidth == 8:
|
||||
# The vector is 32-bits, so we just shuffle the whole thing and
|
||||
# use prmt to blend it with the local register.
|
||||
exchanged = utils.shfl_bfly(reg, 1)
|
||||
# Consider lanes 0 and 1, because the situation is symmetric for
|
||||
# each pair. If we feed reg[lane] and exchanged[lane] (which is
|
||||
# really the same as reg of the other lane) to prmt, we can index
|
||||
# the elements of the result using the following indices:
|
||||
# reg[0]: 0 1 2 3 reg[1]: 8 9 10 11
|
||||
# prmt[0]: 0 1 2 3 4 5 6 7
|
||||
# prmt[1]: 4 5 6 7 0 1 2 3
|
||||
# The expected outputs and their respective permutations are:
|
||||
# out[0]: 0 1 8 9 out[1]: 2 3 10 11
|
||||
# prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3
|
||||
# Note that the patterns still need to be flipped, since we listed
|
||||
# bytes with LSB on the left, which is the opposite of how the
|
||||
# numeric constants are spelled in Python (LSB on the right).
|
||||
perm = arith.select(is_even, c(0x5410), c(0x3276))
|
||||
blend = utils.prmt(reg, exchanged, perm)
|
||||
low = utils.vector_slice(blend, slice(0, 2))
|
||||
high = utils.vector_slice(blend, slice(2, 4))
|
||||
else:
|
||||
raise NotImplementedError(dtype_bitwidth)
|
||||
registers = np.moveaxis(registers, -1, 1)
|
||||
for idx, reg in np.ndenumerate(registers):
|
||||
if dtype_bitwidth == 16:
|
||||
assert reg.type.shape == [4]
|
||||
# A single vector is 64-bits, but shuffles are only 32-bit wide.
|
||||
# We only shuffle the half that needs to go to other thread.
|
||||
low = utils.vector_slice(reg, slice(0, 2))
|
||||
high = utils.vector_slice(reg, slice(2, 4))
|
||||
to_exchange = arith.select(is_even, high, low)
|
||||
# Exchange values between even and odd threads.
|
||||
exchanged = utils.shfl_bfly(to_exchange, 1)
|
||||
low = arith.select(is_even, low, exchanged)
|
||||
high = arith.select(is_even, exchanged, high)
|
||||
new_registers[(idx[0], idx[1] * 2, *idx[2:-1])] = low
|
||||
new_registers[(idx[0], idx[1] * 2 + 1, *idx[2:-1])] = high
|
||||
assert all(r is not None for r in new_registers)
|
||||
return FragmentedArray(
|
||||
_registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
|
||||
)
|
||||
elif dtype_bitwidth == 8:
|
||||
assert reg.type.shape == [4]
|
||||
# The vector is 32-bits, so we just shuffle the whole thing and
|
||||
# use prmt to blend it with the local register.
|
||||
exchanged = utils.shfl_bfly(reg, 1)
|
||||
# Consider lanes 0 and 1, because the situation is symmetric for
|
||||
# each pair. If we feed reg[lane] and exchanged[lane] (which is
|
||||
# really the same as reg of the other lane) to prmt, we can index
|
||||
# the elements of the result using the following indices:
|
||||
# reg[0]: 0 1 2 3 reg[1]: 8 9 10 11
|
||||
# prmt[0]: 0 1 2 3 4 5 6 7
|
||||
# prmt[1]: 4 5 6 7 0 1 2 3
|
||||
# The expected outputs and their respective permutations are:
|
||||
# out[0]: 0 1 8 9 out[1]: 2 3 10 11
|
||||
# prmt[0]: 0 1 4 5 prmt[1]: 6 7 2 3
|
||||
# Note that the patterns still need to be flipped, since we listed
|
||||
# bytes with LSB on the left, which is the opposite of how the
|
||||
# numeric constants are spelled in Python (LSB on the right).
|
||||
perm = arith.select(is_even, c(0x5410), c(0x3276))
|
||||
blend = utils.prmt(reg, exchanged, perm)
|
||||
for i in range(2):
|
||||
reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
|
||||
new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
|
||||
else:
|
||||
assert dtype_bitwidth == 4
|
||||
assert reg.type.shape == [8] # We paired up the registers above.
|
||||
exchanged = utils.shfl_bfly(reg, 1)
|
||||
# See comment above for a more complete explanation.
|
||||
# reg[0]: 0 1 2 3 16 17 18 19 reg[1]: 8 9 10 11 24 25 26 27
|
||||
# prmt[0]: -0- -1- --2-- --3-- -4- --5-- --6-- --7--
|
||||
# prmt[1]: -4- -5- --6-- --7-- -0- --1-- --2-- --3--
|
||||
# The expected outputs and their respective permutations are:
|
||||
# out[0]: 0 1 8 9 16 17 24 25 out[1]: 2 3 10 11 18 19 26 27
|
||||
# prmt[0]: -0- -4- --2-- --6-- prmt[1]: -5- --1-- --7-- --3--
|
||||
perm = arith.select(is_even, c(0x6240), c(0x3715))
|
||||
blend = utils.prmt(reg, exchanged, perm)
|
||||
for i in range(4):
|
||||
reg = utils.vector_slice(blend, slice(i * 2, i * 2 + 2))
|
||||
new_registers[(idx[0], idx[1] * 4 + i, *idx[2:-1])] = reg
|
||||
assert all(r is not None for r in new_registers)
|
||||
return FragmentedArray(
|
||||
_registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
|
||||
)
|
||||
if (
|
||||
self.layout == WGMMA_LAYOUT_UPCAST_4X
|
||||
and new_layout == WGMMA_LAYOUT_UPCAST_2X
|
||||
and utils.bitwidth(self.mlir_dtype) == 4
|
||||
):
|
||||
assert shape[0] % 64 == 0 # Should be implied by the layout
|
||||
assert shape[1] % 32 == 0 # Should be implied by the layout
|
||||
new_registers = np.empty(new_layout.registers_shape(shape), dtype=object)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
c = lambda x: arith.constant(i32, x)
|
||||
is_01 = arith.cmpi(
|
||||
arith.CmpIPredicate.ult, arith.remui(utils.thread_idx(), c(4)), c(2)
|
||||
)
|
||||
for idx, reg in np.ndenumerate(self.registers):
|
||||
assert ir.VectorType(reg.type).shape == [8]
|
||||
# The vector is 32-bits, so we just shuffle the whole thing and
|
||||
# use prmt to blend it with the local register.
|
||||
exchanged = utils.shfl_bfly(reg, 2)
|
||||
# See comments above for conventions. Here we exchange data between
|
||||
# threads with lane index related by flipping 2nd bit (e.g. 0 and 2).
|
||||
# reg[0]: 0 1 2 3 4 5 6 7 reg[2]: 16 17 18 19 20 21 22 23
|
||||
# prmt[0]: -0- -1- -2- -3- --4-- --5-- --6-- --7--
|
||||
# prmt[1]: -4- -5- -6- -7- --0-- --1-- --2-- --3--
|
||||
# The expected outputs and their respective permutations are:
|
||||
# out[0]: 0 1 2 3 16 17 18 19 out[2]: 4 5 6 7 20 21 22 23
|
||||
# prmt[0]: -0- -1- --4-- --5-- prmt[2]: -6- -7- --2-- --3--
|
||||
perm = arith.select(is_01, c(0x5410), c(0x3276))
|
||||
blend = utils.prmt(reg, exchanged, perm)
|
||||
for i in range(2):
|
||||
reg = utils.vector_slice(blend, slice(i * 4, i * 4 + 4))
|
||||
new_registers[(idx[0], idx[1] * 2 + i, *idx[2:-1])] = reg
|
||||
assert all(r is not None for r in new_registers)
|
||||
return FragmentedArray(
|
||||
_registers=new_registers, _layout=new_layout, _is_signed=self.is_signed,
|
||||
)
|
||||
if self.layout == WGMMA_LAYOUT_UPCAST_4X and new_layout == WGMMA_LAYOUT:
|
||||
return self.to_layout(WGMMA_LAYOUT_UPCAST_2X).to_layout(new_layout)
|
||||
if not isinstance(self.layout, WGSplatFragLayout):
|
||||
raise NotImplementedError(
|
||||
f"Cannot convert from {self.layout} to {new_layout}"
|
||||
@ -1288,7 +1374,9 @@ class FragmentedArray:
|
||||
int_ty = ir.IntegerType.get_signless(group_size * 4)
|
||||
while vector_len - offset >= group_size:
|
||||
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
|
||||
reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty))
|
||||
reg_slice_int = utils.bitcast(reg_slice, int_ty)
|
||||
if int_ty != i32:
|
||||
reg_slice_int = arith.extsi(i32, reg_slice_int)
|
||||
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
|
||||
out_int_regs.extend(
|
||||
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
|
||||
|
@ -14,6 +14,7 @@
|
||||
# ==============================================================================
|
||||
|
||||
from collections.abc import Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import enum
|
||||
import itertools
|
||||
@ -83,6 +84,20 @@ def mlir_sum(elems):
|
||||
return total
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_sass():
|
||||
prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
|
||||
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
|
||||
try:
|
||||
with jtu.capture_stdout() as output:
|
||||
yield output
|
||||
finally:
|
||||
if prev_dump is not None:
|
||||
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
|
||||
else:
|
||||
del os.environ["MOSAIC_GPU_DUMP_SASS"]
|
||||
|
||||
|
||||
def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
|
||||
index = ir.IndexType.get()
|
||||
thread_id = gpu.thread_id(gpu.Dimension.x)
|
||||
@ -542,7 +557,11 @@ class WGMMALayoutTest(TestCase):
|
||||
(jnp.int8, jnp.bfloat16),
|
||||
(jnp.int4, jnp.bfloat16),
|
||||
),
|
||||
layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X),
|
||||
layout=(
|
||||
fa.WGMMA_LAYOUT,
|
||||
fa.WGMMA_LAYOUT_UPCAST_2X,
|
||||
fa.WGMMA_LAYOUT_UPCAST_4X,
|
||||
),
|
||||
)
|
||||
def test_optimized_conversion(self, jax_dtype_from_to, layout):
|
||||
jax_dtype_from, jax_dtype_to = jax_dtype_from_to
|
||||
@ -2194,19 +2213,11 @@ class LayoutTest(TestCase):
|
||||
.transpose(0, 2, 1, 3)
|
||||
)
|
||||
|
||||
prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
|
||||
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
|
||||
try:
|
||||
with jtu.capture_stdout() as get_sass:
|
||||
iota = mgpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
|
||||
[expected, expected, mgpu.TMABarrier()],
|
||||
)(expected)
|
||||
finally:
|
||||
if prev_dump is not None:
|
||||
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
|
||||
else:
|
||||
del os.environ["MOSAIC_GPU_DUMP_SASS"]
|
||||
with get_sass() as sass:
|
||||
iota = mgpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
|
||||
[expected, expected, mgpu.TMABarrier()],
|
||||
)(expected)
|
||||
np.testing.assert_array_equal(iota, expected)
|
||||
|
||||
# Verify that we don't use too many registers for the transfers.
|
||||
@ -2219,7 +2230,7 @@ class LayoutTest(TestCase):
|
||||
expected_regs //= 2
|
||||
for instr in ("STS", "LDS"):
|
||||
with self.subTest(instr + " count"):
|
||||
addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
|
||||
addrs = re.findall(instr + r".* \[(.*)\]", sass())
|
||||
def get_reg(addr):
|
||||
if (pos := addr.find("+")) != -1:
|
||||
return addr[:pos]
|
||||
@ -2294,30 +2305,38 @@ class LayoutTest(TestCase):
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, y_ref)
|
||||
|
||||
@parameterized.product(
|
||||
upcast_before_layout_change=[True, False],
|
||||
@parameterized.parameters(
|
||||
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int8, 1),
|
||||
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int8, jnp.int16, 1),
|
||||
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X, jnp.int4, jnp.int4, 1),
|
||||
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5),
|
||||
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
|
||||
)
|
||||
def test_upcast_to_wgmma(self, upcast_before_layout_change):
|
||||
in_dtype = jnp.dtype(jnp.int8)
|
||||
def test_upcast_to_wgmma(
|
||||
self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
|
||||
):
|
||||
in_dtype = jnp.dtype(in_dtype)
|
||||
out_dtype = jnp.dtype(jnp.int16)
|
||||
out_dtype_mlir = utils.dtype_to_ir_type(out_dtype)
|
||||
swizzle = 128
|
||||
in_col_tiling = 8 * swizzle // jnp.iinfo(in_dtype).bits
|
||||
in_tiling = (8, in_col_tiling)
|
||||
out_col_tiling = swizzle // out_dtype.itemsize
|
||||
out_tiling = (8, out_col_tiling)
|
||||
m, n = 128, in_col_tiling * 2
|
||||
regs_per_thread = None
|
||||
def kernel(ctx, in_, out, smems):
|
||||
nonlocal regs_per_thread
|
||||
smem_in, smem_out, barrier = smems
|
||||
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
|
||||
barrier.wait()
|
||||
t = mgpu.FragmentedArray.load_tiled(
|
||||
smem_in, swizzle=swizzle, is_signed=True, layout=fa.WGMMA_LAYOUT_UPCAST_2X
|
||||
smem_in, swizzle=swizzle, is_signed=True, layout=start_layout
|
||||
)
|
||||
if upcast_before_layout_change:
|
||||
t = t.astype(ir.IntegerType.get_signless(16), is_signed=True)
|
||||
t = t.to_layout(fa.WGMMA_LAYOUT)
|
||||
if not upcast_before_layout_change:
|
||||
t = t.astype(ir.IntegerType.get_signless(16), is_signed=True)
|
||||
regs_per_thread = t.registers.size
|
||||
t = t.astype(utils.dtype_to_ir_type(cast_dtype), is_signed=True)
|
||||
t = t.to_layout(end_layout)
|
||||
t = t.astype(out_dtype_mlir, is_signed=True)
|
||||
t.store_tiled(smem_out, swizzle=swizzle)
|
||||
mgpu.commit_shared()
|
||||
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
|
||||
@ -2326,14 +2345,20 @@ class LayoutTest(TestCase):
|
||||
return x.reshape(
|
||||
x.shape[0] // tiling[0], tiling[0], x.shape[1] // tiling[1], tiling[1]
|
||||
).transpose(0, 2, 1, 3)
|
||||
x = jax.random.randint(jax.random.key(42), (m, n), -128, 127, dtype=in_dtype)
|
||||
in_iinfo = jnp.iinfo(in_dtype)
|
||||
x = jax.random.randint(
|
||||
jax.random.key(42), (m, n), in_iinfo.min, in_iinfo.max, dtype=jnp.int32
|
||||
).astype(in_dtype)
|
||||
xt = tile(x, in_tiling)
|
||||
y = x.astype(out_dtype)
|
||||
yt = tile(y, out_tiling)
|
||||
f = mgpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()],
|
||||
)
|
||||
np.testing.assert_array_equal(f(xt), yt)
|
||||
with get_sass() as sass:
|
||||
yt_kernel = f(xt)
|
||||
np.testing.assert_array_equal(yt_kernel, yt)
|
||||
self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user