diff --git a/jax/BUILD b/jax/BUILD
index d6f100581..12eae4afd 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -799,7 +799,7 @@ pytype_strict_library(
 )
 
 # This target only supports sm_90 GPUs.
-py_library(
+py_library_providing_imports_info(
     name = "mosaic_gpu",
     srcs = glob(["experimental/mosaic/gpu/*.py"]),
     visibility = [
@@ -824,6 +824,7 @@ py_library(
         "//jaxlib/mlir:pass_manager",
         "//jaxlib/mlir:scf_dialect",
         "//jaxlib/mlir:vector_dialect",
+        "//jaxlib/mosaic/python:gpu_dialect",
     ] + py_deps("absl/flags") + py_deps("numpy"),
 )
 
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index d325b22c1..c6d7c02fb 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1244,11 +1244,10 @@ class FragmentedArray:
     is_vector_reg = ir.VectorType.isinstance(reg_type)
     reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
     [vector_len] = reg_shape  # This is meant to be a 1D assertion.
-    if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len == 2:
+    if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
       new_registers = np.empty_like(self.registers)
-      empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((1,), i32))
+      out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
       for idx, reg in np.ndenumerate(self.registers):
-        reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
         # The algorithm here is largely the same as CUTLASS's
         # NumericArrayConverter specialization for int4 -> bf16 casts.
         # We modify it slightly, because we only extract 2 values.
@@ -1262,25 +1261,50 @@ class FragmentedArray:
         # positive int4s will end up larger than negative int4s, with a bias of
         # 8. Use use the sub to subtract the base (our initial exponent) and the
         # bias coming from flipping the sign bit which is 136 (0x4308 as bits).
-        new_reg_32 = llvm.inline_asm(
-            i32,
-            [reg_8],
-            """
-            {
-            .reg .b32 s<4>;
-            shr.s32 s0, $1, 4;
-            prmt.b32 s1, $1, s0, 0xF4F0;
-            lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
-            mov.b32 s3, 0x43084308;
-            sub.bf16x2 $0, s2, s3;
-            }
-            """,
-            "=r,r",
-        )
-        new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32))
-        new_registers[idx] = vector.bitcast(
-            ir.VectorType.get((vector_len,), new_dtype), new_vec_32
-        )
+        def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
+          # Upcasts the two int4 values stored in byte `part` of `reg` to a
+          # bf16 pair packed into one i32. `reg_shr` must be `reg >> 4`, so
+          # that the high nibble is available at the same byte position.
+          assert 0 <= part < 4
+          return llvm.inline_asm(
+              i32,
+              [reg, reg_shr],
+              f"""
+              {{
+              .reg .b32 s<4>;
+              prmt.b32 s1, $1, $2, 0xF{part + 4}F{part};
+              lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
+              mov.b32 s3, 0x43084308;
+              sub.bf16x2 $0, s2, s3;
+              }}
+              """,
+              "=r,r,r",
+          )
+        offset = 0
+        out_int_regs = []
+        # Greedily split the register into groups of 8, 4 and 2 int4 elements
+        # (32, 16 and 8 bits), each of which is widened to one i32 operand.
+        for group_size in (8, 4, 2):
+          int_ty = ir.IntegerType.get_signless(group_size * 4)
+          while vector_len - offset >= group_size:
+            reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
+            reg_slice_int = utils.bitcast(reg_slice, int_ty)
+            # arith.extsi requires a strictly wider result type, so only
+            # extend the groups that are narrower than 32 bits.
+            if int_ty != i32:
+              reg_slice_int = arith.extsi(i32, reg_slice_int)
+            reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
+            out_int_regs.extend(
+                upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
+                for part in range(group_size // 2)
+            )
+            offset += group_size
+        assert offset == vector_len
+        out_vec_int = utils.vector_concat([
+            vector.splat(ir.VectorType.get((1,), i32), out_reg)
+            for out_reg in out_int_regs
+        ])
+        new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
       return FragmentedArray(
           _registers=new_registers, _layout=self.layout, _is_signed=None
       )
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 1807449f9..91cb19746 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -348,6 +348,9 @@ def bitwidth_impl(ty: ir.Type):
     return ir.FloatType(ty).width
-  if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
+  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
     return MBARRIER_BYTES * 8
+  if ir.VectorType.isinstance(ty):
+    vty = ir.VectorType(ty)
+    return math.prod(vty.shape) * bitwidth(vty.element_type)
   raise NotImplementedError(ty)
 
 
@@ -1220,6 +1223,12 @@ def bitcast(x: ir.Value, new_type: ir.Type):
     x_ty = ir.IntegerType(x.type)
     assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
     return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
+  if ir.VectorType.isinstance(x.type) and ir.VectorType.isinstance(new_type):
+    x_ty = ir.VectorType(x.type)
+    new_ty = ir.VectorType(new_type)
+    if bitwidth(x_ty) != bitwidth(new_ty):
+      raise ValueError(f"Can't bitcast {x.type} to {new_type}")
+    return vector.bitcast(new_type, x)
   raise ValueError(f"Can't bitcast {x.type} to {new_type}")
 
 
@@ -1239,3 +1248,38 @@ def vector_slice(v: ir.Value, s: slice):
     elem = llvm.extractelement(v, c(src, i32))
     result = llvm.insertelement(result, elem, c(tgt, i32))
   return result
+
+
+def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
+  """Joins equally-typed 1D vectors end to end into one 1D vector.
+
+  Args:
+    vectors: Non-empty sequence of values that all share one 1D vector type.
+
+  Returns:
+    A vector holding the elements of every input, in input order.
+
+  Raises:
+    ValueError: If `vectors` is empty, its first entry is not a vector, or
+      the entries do not all have the same type.
+    NotImplementedError: If the vectors are not 1D.
+  """
+  index = ir.IndexType.get()
+  if not vectors:
+    raise ValueError("Cannot concatenate an empty list of vectors")
+  vty = vectors[0].type
+  if not ir.VectorType.isinstance(vty):
+    raise ValueError("Cannot concatenate non-vector values")
+  if vty.rank != 1:
+    raise NotImplementedError("Only 1D vectors are supported")
+  if any(v.type != vty for v in vectors):
+    raise ValueError("Cannot concatenate vectors of different types")
+  [length] = vty.shape
+  out_ty = ir.VectorType.get((length * len(vectors),), vty.element_type)
+  result = llvm.mlir_undef(out_ty)
+  for vec_idx, v in enumerate(vectors):
+    base = vec_idx * length
+    for i in range(length):
+      elem = vector.extractelement(v, position=c(i, index))
+      result = vector.insertelement(elem, result, position=c(base + i, index))
+  return result
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 1f43b46dc..574299ab1 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -518,14 +518,15 @@ class WGMMALayoutTest(TestCase):
     )()
     np.testing.assert_array_equal(iota, expected)
 
-  @parameterized.named_parameters(
-      ("bf16_i8", jnp.bfloat16, jnp.int8),
-      ("i8_bf16", jnp.int8, jnp.bfloat16),
-      ("i8_i8", jnp.int8, jnp.int8),
-      ("i4_i4", jnp.int4, jnp.int4),
-      ("i4_bf16", jnp.int4, jnp.bfloat16),
+  @parameterized.product(
+      jax_dtype_from_to=(
+          (jnp.int8, jnp.bfloat16),
+          (jnp.int4, jnp.bfloat16),
+      ),
+      layout=(fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_UPCAST_2X),
   )
-  def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
+  def test_optimized_conversion(self, jax_dtype_from_to, layout):
+    jax_dtype_from, jax_dtype_to = jax_dtype_from_to
     mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
     mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
     m = 128
@@ -538,7 +539,7 @@ class WGMMALayoutTest(TestCase):
           smem_from,
           swizzle=128,
           is_signed=utils.is_signed(jax_dtype_from),
-          layout=fa._tiled_wgmma_layout((m, n))
+          layout=layout,
       )
       t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
       t.store_tiled(smem_to, swizzle=128)