[Mosaic GPU] Add a fast type conversion from s8 vectors to bf16 vectors

Regular conversion instructions have a ridiculously low throughput on Hopper,
so replacing them with some bit tricks yields a much faster implementation.

Co-authored-by: Benjamin Chetioui <bchetioui@google.com>
PiperOrigin-RevId: 665893696
This commit is contained in:
Adam Paszke 2024-08-21 08:36:24 -07:00 committed by jax authors
parent d49d070f0e
commit ce3ea109a4
2 changed files with 61 additions and 0 deletions

View File

@ -422,9 +422,50 @@ class FragmentedArray:
# TODO(apaszke): Support JAX dtypes here as well?
def astype(self, new_dtype: ir.Type):
i8 = ir.IntegerType.get_signless(8)
i16 = ir.IntegerType.get_signless(16)
i32 = ir.IntegerType.get_signless(32)
bf16 = ir.BF16Type.get()
cur_dtype = self.mlir_dtype
if cur_dtype == new_dtype:
return self
reg_type = self.registers.flat[0].type
is_vector_reg = ir.VectorType.isinstance(reg_type)
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else ()
if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,):
new_registers = np.empty_like(self.registers)
for idx, reg in np.ndenumerate(self.registers):
reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
val_16 = llvm.extractelement(reg_16, c(0, i32))
# We first embed the s8 into a bf16 with the exponent equal to
# bias + mantissa bits. Then, we zero the msb that didn't fit into the
# mantissa, zero out all bits other than msb, and subtract the last
# two values from each other. This takes advantage of the fact that the
# lsb of the exponent (msb of the second byte) is zero, which allows us
# to losslesly pack the msb there. When 1, it doubles the value of s2,
# making the result negative.
new_val_32 = llvm.inline_asm(
i32,
[val_16],
"""
{
.reg .b32 s<3>;
prmt.b32 s0, $1, 0x43, 0x4140;
and.b32 s1, s0, 0xff7fff7f;
and.b32 s2, s0, 0xff80ff80;
sub.bf16x2 $0, s1, s2;
}
""",
"=r,r",
)
new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32))
new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32))
new_registers[idx] = vector.bitcast(
ir.VectorType.get((2,), new_dtype), new_vec
)
return FragmentedArray(_registers=new_registers, _layout=self.layout)
# Generic path.
from_float = ir.FloatType.isinstance(cur_dtype)
to_float = ir.FloatType.isinstance(new_dtype)
from_integer = ir.IntegerType.isinstance(cur_dtype)

View File

@ -1192,6 +1192,26 @@ class FragmentedArrayTest(TestCase):
np.testing.assert_array_equal(result, x)
@parameterized.named_parameters(
("_bf16", jnp.bfloat16)
)
def test_fast_i8_convert(self, jax_dtype_to):
jax_dtype_to = jnp.dtype(jax_dtype_to)
jax_dtype_from = jnp.dtype(jnp.int8)
mlir_dtype_to = mlir.dtype_to_ir_type(jax_dtype_to)
def kernel(ctx, inp, out, smem):
del ctx, smem
arr = mgpu.FragmentedArray.load_strided(inp)
arr.astype(mlir_dtype_to).store_untiled(out)
x = jnp.arange(-128, 128, dtype=jax_dtype_from)
reference = x.astype(jax_dtype_to)
result = mosaic_gpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, reference, None,
)(x)
np.testing.assert_array_equal(result, reference)
class ProfilerTest(TestCase):