diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index c6d7c02fb..dc5ad48c4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1244,6 +1244,11 @@ class FragmentedArray: is_vector_reg = ir.VectorType.isinstance(reg_type) reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,) [vector_len] = reg_shape # This is meant to be a 1D assertion. + if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8: + raise ValueError( + "Register bitwidth in target type must be divisible by 8, got" + f" {new_reg_bitwidth}" + ) if cur_dtype == i4 and self.is_signed and new_dtype == bf16: new_registers = np.empty_like(self.registers) out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) @@ -1344,11 +1349,6 @@ class FragmentedArray: _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) # Generic path. - # XLA packs elements into bytes in big-endian order, while LLVM assumes the - # same endianness as the target machine (which is little for NVIDIA GPUs). - # We'll need to add specialized casting routines that flip the endianness. - if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8: - raise NotImplementedError("Conversion involving sub-byte types unsupported") from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) from_integer = ir.IntegerType.isinstance(cur_dtype) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 574299ab1..91644be5c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -518,6 +518,25 @@ class WGMMALayoutTest(TestCase): )() np.testing.assert_array_equal(iota, expected) + @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32) + def test_sub_byte_conversion(self, jax_dtype_to): + jax_dtype_from = jnp.int4 + def kernel(ctx, inp, out, smem): + del ctx # Unused. + smem_inp, smem_out = smem + copy(inp, smem_inp, swizzle=16) + t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16) + t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True) + t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) + copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) + + x = self.prng.integers( + low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32 + ).astype(jax_dtype_from) + y = x.astype(jax_dtype_to) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y)) + np.testing.assert_array_equal(f(x), y) + @parameterized.product( jax_dtype_from_to=( (jnp.int8, jnp.bfloat16),