mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length
We can now perform the conversion in groups of 2, 4 or even 8 elements at a time. PiperOrigin-RevId: 737626600
This commit is contained in:
parent
0ff234049b
commit
3649da56fc
@ -799,7 +799,7 @@ pytype_strict_library(
|
||||
)
|
||||
|
||||
# This target only supports sm_90 GPUs.
|
||||
py_library(
|
||||
py_library_providing_imports_info(
|
||||
name = "mosaic_gpu",
|
||||
srcs = glob(["experimental/mosaic/gpu/*.py"]),
|
||||
visibility = [
|
||||
@ -824,6 +824,7 @@ py_library(
|
||||
"//jaxlib/mlir:pass_manager",
|
||||
"//jaxlib/mlir:scf_dialect",
|
||||
"//jaxlib/mlir:vector_dialect",
|
||||
"//jaxlib/mosaic/python:gpu_dialect",
|
||||
] + py_deps("absl/flags") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
|
@ -1244,11 +1244,10 @@ class FragmentedArray:
|
||||
is_vector_reg = ir.VectorType.isinstance(reg_type)
|
||||
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
|
||||
[vector_len] = reg_shape # This is meant to be a 1D assertion.
|
||||
if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
|
||||
if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
|
||||
new_registers = np.empty_like(self.registers)
|
||||
empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
|
||||
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
|
||||
for idx, reg in np.ndenumerate(self.registers):
|
||||
reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
|
||||
# The algorithm here is largely the same as CUTLASS's
|
||||
# NumericArrayConverter specialization for int4 -> bf16 casts.
|
||||
# We modify it slightly, because we only extract 2 values.
|
||||
@ -1262,25 +1261,41 @@ class FragmentedArray:
|
||||
# positive int4s will end up larger than negative int4s, with a bias of
|
||||
# 8. We use the sub to subtract the base (our initial exponent) and the
|
||||
# bias coming from flipping the sign bit which is 136 (0x4308 as bits).
|
||||
new_reg_32 = llvm.inline_asm(
|
||||
i32,
|
||||
[reg_8],
|
||||
"""
|
||||
{
|
||||
.reg .b32 s<4>;
|
||||
shr.s32 s0, $1, 4;
|
||||
prmt.b32 s1, $1, s0, 0xF4F0;
|
||||
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
|
||||
mov.b32 s3, 0x43084308;
|
||||
sub.bf16x2 $0, s2, s3;
|
||||
}
|
||||
""",
|
||||
"=r,r",
|
||||
)
|
||||
new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
|
||||
new_registers[idx] = vector.bitcast(
|
||||
ir.VectorType.get((vector_len,), new_dtype), new_vec_32
|
||||
)
|
||||
def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
|
||||
assert 0 <= part < 4
|
||||
return llvm.inline_asm(
|
||||
i32,
|
||||
[reg, reg_shr],
|
||||
f"""
|
||||
{{
|
||||
.reg .b32 s<4>;
|
||||
prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
|
||||
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
|
||||
mov.b32 s3, 0x43084308;
|
||||
sub.bf16x2 $0, s2, s3;
|
||||
}}
|
||||
""",
|
||||
"=r,r,r",
|
||||
)
|
||||
offset = 0
|
||||
out_int_regs = []
|
||||
for group_size in (8, 4, 2):
|
||||
int_ty = ir.IntegerType.get_signless(group_size * 4)
|
||||
while vector_len - offset >= group_size:
|
||||
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
|
||||
reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty))
|
||||
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
|
||||
out_int_regs.extend(
|
||||
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
|
||||
for part in range(group_size // 2)
|
||||
)
|
||||
offset += group_size
|
||||
assert offset == vector_len
|
||||
out_vec_int = utils.vector_concat([
|
||||
vector.splat(ir.VectorType.get((1,), i32), reg)
|
||||
for reg in out_int_regs
|
||||
])
|
||||
new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
|
||||
return FragmentedArray(
|
||||
_registers=new_registers, _layout=self.layout, _is_signed=None
|
||||
)
|
||||
|
@ -348,6 +348,9 @@ def bitwidth_impl(ty: ir.Type):
|
||||
return ir.FloatType(ty).width
|
||||
if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
|
||||
return MBARRIER_BYTES * 8
|
||||
if ir.VectorType.isinstance(ty):
|
||||
vty = ir.VectorType(ty)
|
||||
return math.prod(vty.shape) * bitwidth(vty.element_type)
|
||||
raise NotImplementedError(ty)
|
||||
|
||||
|
||||
@ -1220,6 +1223,12 @@ def bitcast(x: ir.Value, new_type: ir.Type):
|
||||
x_ty = ir.IntegerType(x.type)
|
||||
assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
|
||||
return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
|
||||
if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
|
||||
x_ty = ir.VectorType(x.type)
|
||||
new_ty = ir.VectorType(new_type)
|
||||
if bitwidth(x_ty) != bitwidth(new_ty):
|
||||
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
|
||||
return vector.bitcast(new_type, x)
|
||||
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
|
||||
|
||||
|
||||
@ -1239,3 +1248,27 @@ def vector_slice(v: ir.Value, s: slice):
|
||||
elem = llvm.extractelement(v, c(src, i32))
|
||||
result = llvm.insertelement(result, elem, c(tgt, i32))
|
||||
return result
|
||||
|
||||
|
||||
def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
  """Concatenates 1D vectors of identical type into a single 1D vector.

  Args:
    vectors: A non-empty sequence of values that all share the same 1D vector
      type.

  Returns:
    A 1D vector with the same element type as the inputs and a length equal to
    the per-input length times the number of inputs. Elements are laid out in
    input order.

  Raises:
    ValueError: If ``vectors`` is empty, if the first value is not a vector, or
      if the values do not all have the same type.
    NotImplementedError: If the vectors are not 1D.
  """
  # Validate the arguments before building any types or emitting any ops.
  if not vectors:
    raise ValueError("Cannot concatenate an empty list of vectors")
  first_ty = vectors[0].type
  if not ir.VectorType.isinstance(first_ty):
    raise ValueError("Cannot concatenate non-vector values")
  # Cast explicitly (as the other helpers in this file do) so that rank, shape
  # and element_type are read from a VectorType rather than a generic type.
  vty = ir.VectorType(first_ty)
  if vty.rank != 1:
    raise NotImplementedError("Only 1D vectors are supported")
  if any(v.type != first_ty for v in vectors):
    raise ValueError("Cannot concatenate vectors of different types")

  index = ir.IndexType.get()
  length = vty.shape[0]
  # Start from an undefined vector; every lane is overwritten below.
  result = llvm.mlir_undef(
      ir.VectorType.get((length * len(vectors),), vty.element_type)
  )
  for vec_idx, v in enumerate(vectors):
    offset = vec_idx * length
    for i in range(length):
      elem = vector.extractelement(v, position=c(i, index))
      result = vector.insertelement(elem, result, position=c(offset + i, index))
  return result
|
||||
|
@ -518,14 +518,15 @@ class WGMMALayoutTest(TestCase):
|
||||
)()
|
||||
np.testing.assert_array_equal(iota, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("bf16_i8", jnp.bfloat16, jnp.int8),
|
||||
("i8_bf16", jnp.int8, jnp.bfloat16),
|
||||
("i8_i8", jnp.int8, jnp.int8),
|
||||
("i4_i4", jnp.int4, jnp.int4),
|
||||
("i4_bf16", jnp.int4, jnp.bfloat16),
|
||||
@parameterized.product(
|
||||
jax_dtype_from_to=(
|
||||
(jnp.int8, jnp.bfloat16),
|
||||
(jnp.int4, jnp.bfloat16),
|
||||
),
|
||||
layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X),
|
||||
)
|
||||
def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
|
||||
def test_optimized_conversion(self, jax_dtype_from_to, layout):
|
||||
jax_dtype_from, jax_dtype_to = jax_dtype_from_to
|
||||
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
|
||||
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
|
||||
m = 128
|
||||
@ -538,7 +539,7 @@ class WGMMALayoutTest(TestCase):
|
||||
smem_from,
|
||||
swizzle=128,
|
||||
is_signed=utils.is_signed(jax_dtype_from),
|
||||
layout=fa._tiled_wgmma_layout((m, n))
|
||||
layout=layout,
|
||||
)
|
||||
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
|
||||
t.store_tiled(smem_to, swizzle=128)
|
||||
|
Loading…
x
Reference in New Issue
Block a user