diff --git a/jax/BUILD b/jax/BUILD
index d6f100581..12eae4afd 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -799,7 +799,7 @@ pytype_strict_library(
 )
 
 # This target only supports sm_90 GPUs.
-py_library(
+py_library_providing_imports_info(
     name = "mosaic_gpu",
     srcs = glob(["experimental/mosaic/gpu/*.py"]),
     visibility = [
@@ -824,6 +824,7 @@ py_library(
         "//jaxlib/mlir:pass_manager",
         "//jaxlib/mlir:scf_dialect",
         "//jaxlib/mlir:vector_dialect",
+        "//jaxlib/mosaic/python:gpu_dialect",
     ] + py_deps("absl/flags") + py_deps("numpy"),
 )
 
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index d325b22c1..c6d7c02fb 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1244,11 +1244,10 @@ class FragmentedArray:
     is_vector_reg = ir.VectorType.isinstance(reg_type)
     reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
     [vector_len] = reg_shape  # This is meant to be a 1D assertion.
-    if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
+    if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
       new_registers = np.empty_like(self.registers)
-      empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
+      out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
       for idx, reg in np.ndenumerate(self.registers):
-        reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
         # The algorithm here is largely the same as CUTLASS's
         # NumericArrayConverter specialization for int4 -> bf16 casts.
         # We modify it slightly, because we only extract 2 values.
@@ -1262,25 +1261,41 @@ class FragmentedArray:
         # positive int4s will end up larger than negative int4s, with a bias of
         # 8. We use the sub to subtract the base (our initial exponent) and the
         # bias coming from flipping the sign bit which is 136 (0x4308 as bits).
-        new_reg_32 = llvm.inline_asm(
-            i32,
-            [reg_8],
-            """
-            {
-            .reg .b32 s<4>;
-            shr.s32 s0, $1, 4;
-            prmt.b32 s1, $1, s0, 0xF4F0;
-            lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
-            mov.b32 s3, 0x43084308;
-            sub.bf16x2 $0, s2, s3;
-            }
-            """,
-            "=r,r",
-        )
-        new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
-        new_registers[idx] = vector.bitcast(
-            ir.VectorType.get((vector_len,), new_dtype), new_vec_32
-        )
+        def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
+          assert 0 <= part < 4
+          return llvm.inline_asm(
+              i32,
+              [reg, reg_shr],
+              f"""
+              {{
+              .reg .b32 s<4>;
+              prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
+              lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
+              mov.b32 s3, 0x43084308;
+              sub.bf16x2 $0, s2, s3;
+              }}
+              """,
+              "=r,r,r",
+          )
+        offset = 0
+        out_int_regs = []
+        for group_size in (8, 4, 2):
+          int_ty = ir.IntegerType.get_signless(group_size * 4)
+          while vector_len - offset >= group_size:
+            reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
+            reg_slice_int = arith.extsi(i32, utils.bitcast(reg_slice, int_ty))
+            reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
+            out_int_regs.extend(
+                upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
+                for part in range(group_size // 2)
+            )
+            offset += group_size
+        assert offset == vector_len
+        out_vec_int = utils.vector_concat([
+            vector.splat(ir.VectorType.get((1,), i32), reg)
+            for reg in out_int_regs
+        ])
+        new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
       return FragmentedArray(
           _registers=new_registers, _layout=self.layout, _is_signed=None
       )
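
To see why the PTX above is correct: for part p, the `prmt` places byte p of `$1` and byte p of `$2` (the register pre-shifted right by 4) into the low bytes of the two 16-bit lanes, so part p handles nibbles 2p and 2p+1; the `lop3` (whose LUT `(0xf0 & 0xcc) ^ 0xaa` encodes `(a & b) ^ c`) masks each lane down to its nibble, flips the nibble's sign bit, and merges in a 0x43 exponent, leaving each bf16 lane holding exactly 136 + x; the final `sub.bf16x2` removes the 136. Below is a standalone numpy sketch that verifies this bit arithmetic; it is illustrative only, and `emulate_upcast`/`bf16_bits_to_f32` are made-up names, not part of this change:

```python
import numpy as np

def bf16_bits_to_f32(bits: int) -> float:
  # A bf16 is the upper 16 bits of the f32 with the same value.
  return float(np.array([bits << 16], dtype=np.uint32).view(np.float32)[0])

def emulate_upcast(nibbles: list[int]) -> list[float]:
  """Emulates the PTX sequence for one 32-bit register holding 8 int4s."""
  assert len(nibbles) == 8 and all(0 <= n < 16 for n in nibbles)
  reg = sum(n << (4 * i) for i, n in enumerate(nibbles))
  reg_shr = reg >> 4  # shrui: byte p of reg_shr has nibble 2p+1 in its low bits
  out = []
  for part in range(4):  # part p produces a bf16x2 holding nibbles 2p and 2p+1
    lo = (reg >> (8 * part)) & 0xFF      # prmt digit {part}: byte `part` of $1
    hi = (reg_shr >> (8 * part)) & 0xFF  # prmt digit {part+4}: byte `part` of $2
    s1 = lo | (hi << 16)  # prmt's 0xF digits sign-fill bytes 1 and 3; masked below
    # lop3 with LUT (0xf0 & 0xcc) ^ 0xaa computes (s1 & 0x000F000F) ^ 0x43084308:
    # keep each lane's nibble, flip its sign bit, merge in the 0x4300 exponent.
    s2 = (s1 & 0x000F000F) ^ 0x43084308
    # Each lane is now the bf16 for 136 + x; sub.bf16x2 subtracts 0x4308 == 136.0.
    out.append(bf16_bits_to_f32(s2 & 0xFFFF) - 136.0)
    out.append(bf16_bits_to_f32(s2 >> 16) - 136.0)
  return out

# Exhaustive over the positive and negative halves of the int4 range.
assert emulate_upcast(list(range(8))) == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
assert emulate_upcast(list(range(8, 16))) == [-8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0]
```

The (8, 4, 2) grouping loop then just decomposes an arbitrary even `vector_len` of i4s into slices that bitcast to i32/i16/i8, feeding each slice through the sequence above one `part` at a time.
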
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 1807449f9..91cb19746 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -348,6 +348,9 @@ def bitwidth_impl(ty: ir.Type):
     return ir.FloatType(ty).width
   if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
     return MBARRIER_BYTES * 8
+  if ir.VectorType.isinstance(ty):
+    vty = ir.VectorType(ty)
+    return math.prod(vty.shape) * bitwidth(vty.element_type)
   raise NotImplementedError(ty)
@@ -1220,6 +1223,12 @@ def bitcast(x: ir.Value, new_type: ir.Type):
     x_ty = ir.IntegerType(x.type)
     assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
     return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
+  if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
+    x_ty = ir.VectorType(x.type)
+    new_ty = ir.VectorType(new_type)
+    if bitwidth(x_ty) != bitwidth(new_ty):
+      raise ValueError(f"Can't bitcast {x.type} to {new_type}")
+    return vector.bitcast(new_type, x)
   raise ValueError(f"Can't bitcast {x.type} to {new_type}")
@@ -1239,3 +1248,27 @@ def vector_slice(v: ir.Value, s: slice):
     elem = llvm.extractelement(v, c(src, i32))
     result = llvm.insertelement(result, elem, c(tgt, i32))
   return result
+
+
+def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
+  index = ir.IndexType.get()
+  if not vectors:
+    raise ValueError("Cannot concatenate an empty list of vectors")
+  vty = vectors[0].type
+  if not ir.VectorType.isinstance(vty):
+    raise ValueError("Cannot concatenate non-vector values")
+  if vty.rank != 1:
+    raise NotImplementedError("Only 1D vectors are supported")
+  for v in vectors:
+    if v.type != vty:
+      raise ValueError("Cannot concatenate vectors of different types")
+  result = llvm.mlir_undef(
+      ir.VectorType.get((vty.shape[0] * len(vectors),), vty.element_type)
+  )
+  offset = 0
+  for v in vectors:
+    for i in range(vty.shape[0]):
+      elem = vector.extractelement(v, position=c(i, index))
+      result = vector.insertelement(elem, result, position=c(offset + i, index))
+    offset += vty.shape[0]
+  return result
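
The new vector-to-vector case in `bitcast` only checks that the total bitwidths agree and then defers to `vector.bitcast`, which reinterprets the underlying bits in little-endian lane order. As a rough analogy (illustrative only, not part of the diff), numpy's `view` behaves the same way on a 1D buffer, including rejecting mismatched total widths:

```python
import numpy as np

x = np.arange(4, dtype=np.int16)  # plays the role of a vector<4xi16> (64 bits)
y = x.view(np.int32)              # like a bitcast to vector<2xi32>: same bits
assert y.view(np.int16).tolist() == x.tolist()  # round trip is bit-exact

try:
  np.arange(3, dtype=np.int16).view(np.int32)  # 48 bits into 32-bit lanes
except ValueError:
  pass  # rejected, mirroring the ValueError raised for mismatched bitwidths
```
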
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 1f43b46dc..574299ab1 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -518,14 +518,15 @@ class WGMMALayoutTest(TestCase):
     )()
     np.testing.assert_array_equal(iota, expected)
 
-  @parameterized.named_parameters(
-      ("bf16_i8", jnp.bfloat16, jnp.int8),
-      ("i8_bf16", jnp.int8, jnp.bfloat16),
-      ("i8_i8", jnp.int8, jnp.int8),
-      ("i4_i4", jnp.int4, jnp.int4),
-      ("i4_bf16", jnp.int4, jnp.bfloat16),
+  @parameterized.product(
+      jax_dtype_from_to=(
+          (jnp.int8, jnp.bfloat16),
+          (jnp.int4, jnp.bfloat16),
+      ),
+      layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X),
   )
-  def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
+  def test_optimized_conversion(self, jax_dtype_from_to, layout):
+    jax_dtype_from, jax_dtype_to = jax_dtype_from_to
     mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
     mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
     m = 128
@@ -538,7 +539,7 @@ class WGMMALayoutTest(TestCase):
         smem_from,
         swizzle=128,
         is_signed=utils.is_signed(jax_dtype_from),
-        layout=fa._tiled_wgmma_layout((m, n))
+        layout=layout,
     )
     t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
     t.store_tiled(smem_to, swizzle=128)
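
The switch from `parameterized.named_parameters` to `parameterized.product` changes how cases are enumerated: `product` takes the Cartesian product of its keyword arguments, so the test now runs 2 dtype pairs x 2 layouts = 4 cases, each receiving one value per keyword. A quick sketch of the expansion (string stand-ins for the real dtypes and layouts, illustrative only):

```python
import itertools

dtype_pairs = (("int8", "bfloat16"), ("int4", "bfloat16"))
layouts = ("WGMMA_LAYOUT", "WGMMA_LAYOUT_UPCAST_2X")
cases = list(itertools.product(dtype_pairs, layouts))
assert len(cases) == 4  # each case binds (jax_dtype_from_to, layout)
```
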