mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Add a fast type conversion from s8 vectors to bf16 vectors
Regular conversion instructions have a ridiculously low throughput on Hopper, so replacing them with some bit tricks yields a much faster implementation. Co-authored-by: Benjamin Chetioui <bchetioui@google.com> PiperOrigin-RevId: 665893696
This commit is contained in:
parent
d49d070f0e
commit
ce3ea109a4
@ -422,9 +422,50 @@ class FragmentedArray:
|
||||
|
||||
# TODO(apaszke): Support JAX dtypes here as well?
|
||||
def astype(self, new_dtype: ir.Type):
|
||||
i8 = ir.IntegerType.get_signless(8)
|
||||
i16 = ir.IntegerType.get_signless(16)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
bf16 = ir.BF16Type.get()
|
||||
|
||||
cur_dtype = self.mlir_dtype
|
||||
if cur_dtype == new_dtype:
|
||||
return self
|
||||
reg_type = self.registers.flat[0].type
|
||||
is_vector_reg = ir.VectorType.isinstance(reg_type)
|
||||
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else ()
|
||||
if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,):
|
||||
new_registers = np.empty_like(self.registers)
|
||||
for idx, reg in np.ndenumerate(self.registers):
|
||||
reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg)
|
||||
val_16 = llvm.extractelement(reg_16, c(0, i32))
|
||||
# We first embed the s8 into a bf16 with the exponent equal to
|
||||
# bias + mantissa bits. Then, we zero the msb that didn't fit into the
|
||||
# mantissa, zero out all bits other than msb, and subtract the last
|
||||
# two values from each other. This takes advantage of the fact that the
|
||||
# lsb of the exponent (msb of the second byte) is zero, which allows us
|
||||
# to losslesly pack the msb there. When 1, it doubles the value of s2,
|
||||
# making the result negative.
|
||||
new_val_32 = llvm.inline_asm(
|
||||
i32,
|
||||
[val_16],
|
||||
"""
|
||||
{
|
||||
.reg .b32 s<3>;
|
||||
prmt.b32 s0, $1, 0x43, 0x4140;
|
||||
and.b32 s1, s0, 0xff7fff7f;
|
||||
and.b32 s2, s0, 0xff80ff80;
|
||||
sub.bf16x2 $0, s1, s2;
|
||||
}
|
||||
""",
|
||||
"=r,r",
|
||||
)
|
||||
new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32))
|
||||
new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32))
|
||||
new_registers[idx] = vector.bitcast(
|
||||
ir.VectorType.get((2,), new_dtype), new_vec
|
||||
)
|
||||
return FragmentedArray(_registers=new_registers, _layout=self.layout)
|
||||
# Generic path.
|
||||
from_float = ir.FloatType.isinstance(cur_dtype)
|
||||
to_float = ir.FloatType.isinstance(new_dtype)
|
||||
from_integer = ir.IntegerType.isinstance(cur_dtype)
|
||||
|
@ -1192,6 +1192,26 @@ class FragmentedArrayTest(TestCase):
|
||||
|
||||
np.testing.assert_array_equal(result, x)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("_bf16", jnp.bfloat16)
|
||||
)
|
||||
def test_fast_i8_convert(self, jax_dtype_to):
|
||||
jax_dtype_to = jnp.dtype(jax_dtype_to)
|
||||
jax_dtype_from = jnp.dtype(jnp.int8)
|
||||
mlir_dtype_to = mlir.dtype_to_ir_type(jax_dtype_to)
|
||||
def kernel(ctx, inp, out, smem):
|
||||
del ctx, smem
|
||||
arr = mgpu.FragmentedArray.load_strided(inp)
|
||||
arr.astype(mlir_dtype_to).store_untiled(out)
|
||||
|
||||
x = jnp.arange(-128, 128, dtype=jax_dtype_from)
|
||||
reference = x.astype(jax_dtype_to)
|
||||
|
||||
result = mosaic_gpu.as_gpu_kernel(
|
||||
kernel, (1, 1, 1), (128, 1, 1), x, reference, None,
|
||||
)(x)
|
||||
np.testing.assert_array_equal(result, reference)
|
||||
|
||||
|
||||
class ProfilerTest(TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user