[Mosaic GPU] Remove sub-byte conversion restriction

XLA:GPU recently changed its endianness to little endian to better match LLVM
and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.
PiperOrigin-RevId: 737934373
This commit is contained in:
Adam Paszke 2025-03-18 03:12:23 -07:00 committed by jax authors
parent 549973dec6
commit 34cd5b0d74
2 changed files with 24 additions and 5 deletions

View File

@ -1244,6 +1244,11 @@ class FragmentedArray:
is_vector_reg = ir.VectorType.isinstance(reg_type)
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
[vector_len] = reg_shape # This is meant to be a 1D assertion.
if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8:
raise ValueError(
"Register bitwidth in target type must be divisible by 8, got"
f" {new_reg_bitwidth}"
)
if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
new_registers = np.empty_like(self.registers)
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
@ -1344,11 +1349,6 @@ class FragmentedArray:
_registers=new_registers, _layout=self.layout, _is_signed=is_signed
)
# Generic path.
# XLA packs elements into bytes in big-endian order, while LLVM assumes the
# same endianness as the target machine (which is little for NVIDIA GPUs).
# We'll need to add specialized casting routines that flip the endianness.
if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8:
raise NotImplementedError("Conversion involving sub-byte types unsupported")
from_float = ir.FloatType.isinstance(cur_dtype)
to_float = ir.FloatType.isinstance(new_dtype)
from_integer = ir.IntegerType.isinstance(cur_dtype)

View File

@ -518,6 +518,25 @@ class WGMMALayoutTest(TestCase):
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.parameters(jnp.int8, jnp.int16, jnp.int32)
def test_sub_byte_conversion(self, jax_dtype_to):
jax_dtype_from = jnp.int4
def kernel(ctx, inp, out, smem):
del ctx # Unused.
smem_inp, smem_out = smem
copy(inp, smem_inp, swizzle=16)
t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16)
t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True)
t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
x = self.prng.integers(
low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32
).astype(jax_dtype_from)
y = x.astype(jax_dtype_to)
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
np.testing.assert_array_equal(f(x), y)
@parameterized.product(
jax_dtype_from_to=(
(jnp.int8, jnp.bfloat16),