[Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length

We can now perform the conversion in groups of 2, 4 or even 8 elements at a time.

PiperOrigin-RevId: 737626600
This commit is contained in:
Adam Paszke 2025-03-17 08:36:36 -07:00 committed by jax authors
parent 0ff234049b
commit 3649da56fc
4 changed files with 81 additions and 31 deletions

View File

@ -799,7 +799,7 @@ pytype_strict_library(
)
# This target only supports sm_90 GPUs.
py_library(
py_library_providing_imports_info(
name = "mosaic_gpu",
srcs = glob(["experimental/mosaic/gpu/*.py"]),
visibility = [
@ -824,6 +824,7 @@ py_library(
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:vector_dialect",
"//jaxlib/mosaic/python:gpu_dialect",
] + py_deps("absl/flags") + py_deps("numpy"),
)

View File

@ -1244,11 +1244,10 @@ class FragmentedArray:
is_vector_reg = ir.VectorType.isinstance(reg_type)
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
[vector_len] = reg_shape # This is meant to be a 1D assertion.
if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
new_registers = np.empty_like(self.registers)
empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
for idx, reg in np.ndenumerate(self.registers):
reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
# The algorithm here is largely the same as CUTLASS's
# NumericArrayConverter specialization for int4 -> bf16 casts.
# We modify it slightly, because we only extract 2 values.
@ -1262,25 +1261,41 @@ class FragmentedArray:
# positive int4s will end up larger than negative int4s, with a bias of
# 8. We use the sub to subtract the base (our initial exponent) and the
# bias coming from flipping the sign bit which is 136 (0x4308 as bits).
new_reg_32 = llvm.inline_asm(
i32,
[reg_8],
"""
{
.reg .b32 s<4>;
shr.s32 s0, $1, 4;
prmt.b32 s1, $1, s0, 0xF4F0;
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
mov.b32 s3, 0x43084308;
sub.bf16x2 $0, s2, s3;
}
""",
"=r,r",
)
new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
new_registers[idx] = vector.bitcast(
ir.VectorType.get((vector_len,), new_dtype), new_vec_32
)
# Emits inline PTX that converts one pair of signed int4 values into a packed
# bf16x2 word, returned as an i32.
#   reg:     i32 operand ($1); the low nibble of its byte `part` becomes the
#            low bf16 of the result.
#   reg_shr: i32 operand ($2); the low nibble of its byte `part` becomes the
#            high bf16 of the result. The caller passes `reg` logically
#            shifted right by 4 bits, so this is the neighbouring int4.
#   part:    index (0-3) of the byte to convert within the 32-bit operands.
def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
# `part` and `part + 4` are each formatted as a single hex digit of the prmt
# selector below, so `part` must address one of the four bytes of a register.
assert 0 <= part < 4
# PTX steps:
#   prmt: gathers byte `part` of $1 into byte 0 and byte `part` of $2
#         (selector digit `part + 4`) into byte 2; bytes 1 and 3 use selector
#         digit 0xF and are discarded by the mask in the next step.
#   lop3: the immediate LUT (0xf0 & 0xcc) ^ 0xaa encodes (a & b) ^ c, i.e.
#         (s1 & 0x000F000F) ^ 0x43084308 -- keeps the low nibble of each half
#         and XORs it into the bf16 bit pattern 0x4308.
#   sub:  subtracts 0x4308 from both bf16 lanes to remove the base and bias.
return llvm.inline_asm(
i32,
[reg, reg_shr],
f"""
{{
.reg .b32 s<4>;
prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
mov.b32 s3, 0x43084308;
sub.bf16x2 $0, s2, s3;
}}
""",
# One 32-bit register output ($0) and two 32-bit register inputs ($1, $2).
"=r,r,r",
)
offset = 0
out_int_regs = []
for group_size in (8, 4, 2):
int_ty = ir.IntegerType.get_signless(group_size * 4)
while vector_len - offset >= group_size:
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
reg_slice_int = utils.bitcast(reg_slice, int_ty)
if int_ty != i32:
  # arith.extsi requires the result type to be strictly wider than the
  # operand, so only extend the 2- and 4-element groups (i8 / i16). An
  # 8-element group already fills a full i32 and must be passed through
  # unchanged, otherwise op verification fails.
  reg_slice_int = arith.extsi(i32, reg_slice_int)
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
out_int_regs.extend(
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
for part in range(group_size // 2)
)
offset += group_size
assert offset == vector_len
out_vec_int = utils.vector_concat([
vector.splat(ir.VectorType.get((1,), i32), reg)
for reg in out_int_regs
])
new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
return FragmentedArray(
_registers=new_registers, _layout=self.layout, _is_signed=None
)

View File

@ -348,6 +348,9 @@ def bitwidth_impl(ty: ir.Type):
return ir.FloatType(ty).width
if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
return MBARRIER_BYTES * 8
if ir.VectorType.isinstance(ty):
vty = ir.VectorType(ty)
return math.prod(vty.shape) * bitwidth(vty.element_type)
raise NotImplementedError(ty)
@ -1220,6 +1223,12 @@ def bitcast(x: ir.Value, new_type: ir.Type):
x_ty = ir.IntegerType(x.type)
assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
x_ty = ir.VectorType(x.type)
new_ty = ir.VectorType(new_type)
if bitwidth(x_ty) != bitwidth(new_ty):
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
return vector.bitcast(new_type, x)
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
@ -1239,3 +1248,27 @@ def vector_slice(v: ir.Value, s: slice):
elem = llvm.extractelement(v, c(src, i32))
result = llvm.insertelement(result, elem, c(tgt, i32))
return result
def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
  """Concatenates 1D vectors of identical type into a single 1D vector.

  Args:
    vectors: A non-empty sequence of values that all share the same 1D vector
      type.

  Returns:
    A vector with the same element type whose length is the per-input length
    times ``len(vectors)``, holding the inputs' elements in order.

  Raises:
    ValueError: If ``vectors`` is empty, the first value is not a vector, or
      the values do not all have the same type.
    NotImplementedError: If the vectors are not 1D.
  """
  index = ir.IndexType.get()
  if not vectors:
    raise ValueError("Cannot concatenate an empty list of vectors")
  first_ty = vectors[0].type
  if not ir.VectorType.isinstance(first_ty):
    raise ValueError("Cannot concatenate non-vector values")
  if first_ty.rank != 1:
    raise NotImplementedError("Only 1D vectors are supported")
  if any(vec.type != first_ty for vec in vectors):
    raise ValueError("Cannot concatenate vectors of different types")
  lanes = first_ty.shape[0]
  # Start from an undefined vector and fill it one element at a time.
  result = llvm.mlir_undef(
      ir.VectorType.get((lanes * len(vectors),), first_ty.element_type)
  )
  for vec_idx, vec in enumerate(vectors):
    base = vec_idx * lanes  # Position of this input's first element.
    for lane in range(lanes):
      lane_value = vector.extractelement(vec, position=c(lane, index))
      result = vector.insertelement(
          lane_value, result, position=c(base + lane, index)
      )
  return result

View File

@ -518,14 +518,15 @@ class WGMMALayoutTest(TestCase):
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.named_parameters(
("bf16_i8", jnp.bfloat16, jnp.int8),
("i8_bf16", jnp.int8, jnp.bfloat16),
("i8_i8", jnp.int8, jnp.int8),
("i4_i4", jnp.int4, jnp.int4),
("i4_bf16", jnp.int4, jnp.bfloat16),
@parameterized.product(
jax_dtype_from_to=(
(jnp.int8, jnp.bfloat16),
(jnp.int4, jnp.bfloat16),
),
layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X),
)
def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
def test_optimized_conversion(self, jax_dtype_from_to, layout):
jax_dtype_from, jax_dtype_to = jax_dtype_from_to
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
m = 128
@ -538,7 +539,7 @@ class WGMMALayoutTest(TestCase):
smem_from,
swizzle=128,
is_signed=utils.is_signed(jax_dtype_from),
layout=fa._tiled_wgmma_layout((m, n))
layout=layout,
)
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
t.store_tiled(smem_to, swizzle=128)