[Mosaic GPU] Remove sub-byte conversion restriction
XLA:GPU recently changed its endianness to little endian to better match LLVM and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.

PiperOrigin-RevId: 737934373

parent 549973dec6
commit 34cd5b0d74
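
For context on the packing change: with 4-bit elements, the two conventions differ in which nibble of each byte holds element 0. A minimal numpy sketch of the two layouts (illustrative only, not part of the diff; all names here are made up):

import numpy as np

# Four int4 values, one per array entry, to be packed two per byte.
vals = np.array([1, 2, 3, 4], dtype=np.uint8)

# Little endian (LLVM / CUDA ecosystem): element 0 in the low nibble.
little = vals[0::2] | (vals[1::2] << 4)  # -> [0x21, 0x43]

# Big endian (old XLA:GPU convention): element 0 in the high nibble.
big = (vals[0::2] << 4) | vals[1::2]     # -> [0x12, 0x34]

print([hex(b) for b in little], [hex(b) for b in big])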
@@ -1244,6 +1244,11 @@ class FragmentedArray:
     is_vector_reg = ir.VectorType.isinstance(reg_type)
     reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
     [vector_len] = reg_shape  # This is meant to be a 1D assertion.
+    if (new_reg_bitwidth := utils.bitwidth(new_dtype) * vector_len) % 8:
+      raise ValueError(
+          "Register bitwidth in target type must be divisible by 8, got"
+          f" {new_reg_bitwidth}"
+      )
     if cur_dtype == i4 and self.is_signed and new_dtype == bf16:
       new_registers = np.empty_like(self.registers)
       out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
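
The guard added above is the only new requirement: the packed payload of each target register must span a whole number of bytes. A standalone restatement of that arithmetic (hypothetical helper that mirrors the check; this is not a Mosaic GPU API):

def check_reg_bitwidth(target_elem_bits: int, vector_len: int) -> int:
  # Mirrors: utils.bitwidth(new_dtype) * vector_len must be divisible by 8.
  new_reg_bitwidth = target_elem_bits * vector_len
  if new_reg_bitwidth % 8:
    raise ValueError(
        "Register bitwidth in target type must be divisible by 8, got"
        f" {new_reg_bitwidth}"
    )
  return new_reg_bitwidth

check_reg_bitwidth(4, 2)    # int4 x 2 lanes = 8 bits: allowed
check_reg_bitwidth(16, 2)   # bf16 x 2 lanes = 32 bits: allowed
# check_reg_bitwidth(4, 1)  # int4 x 1 lane = 4 bits: raises ValueError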
@@ -1344,11 +1349,6 @@ class FragmentedArray:
           _registers=new_registers, _layout=self.layout, _is_signed=is_signed
       )
     # Generic path.
-    # XLA packs elements into bytes in big-endian order, while LLVM assumes the
-    # same endianness as the target machine (which is little for NVIDIA GPUs).
-    # We'll need to add specialized casting routines that flip the endianness.
-    if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8:
-      raise NotImplementedError("Conversion involving sub-byte types unsupported")
     from_float = ir.FloatType.isinstance(cur_dtype)
     to_float = ir.FloatType.isinstance(new_dtype)
     from_integer = ir.IntegerType.isinstance(cur_dtype)
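
The deleted comment captured why the restriction existed: reading bytes under the wrong packing convention permutes adjacent sub-byte elements, so the generic path could not simply reinterpret XLA's big-endian-packed data. A small numpy illustration of the mismatch (illustrative only):

import numpy as np

packed = np.array([0x21], dtype=np.uint8)  # [1, 2] packed little endian

lo, hi = packed & 0xF, packed >> 4
print(np.stack([lo, hi], -1).ravel())  # [1 2]: little-endian unpack, correct
print(np.stack([hi, lo], -1).ravel())  # [2 1]: big-endian unpack swaps pairs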
@@ -518,6 +518,25 @@ class WGMMALayoutTest(TestCase):
     )()
     np.testing.assert_array_equal(iota, expected)

+  @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32)
+  def test_sub_byte_conversion(self, jax_dtype_to):
+    jax_dtype_from = jnp.int4
+    def kernel(ctx, inp, out, smem):
+      del ctx  # Unused.
+      smem_inp, smem_out = smem
+      copy(inp, smem_inp, swizzle=16)
+      t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16)
+      t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True)
+      t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
+      copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize)
+
+    x = self.prng.integers(
+        low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32
+    ).astype(jax_dtype_from)
+    y = x.astype(jax_dtype_to)
+    f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y))
+    np.testing.assert_array_equal(f(x), y)
+
   @parameterized.product(
       jax_dtype_from_to=(
           (jnp.int8, jnp.bfloat16),
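
A note on the swizzle arguments in the new test: they appear chosen so each swizzled chunk spans the same 32 elements regardless of dtype, i.e. 16 bytes for the int4 input (two elements per byte) and 32 * itemsize bytes for the byte-aligned outputs. A quick check of that arithmetic (illustrative):

import jax.numpy as jnp

assert 32 * 4 // 8 == 16  # 32 int4 elements -> 16 bytes (input swizzle)
for dt in (jnp.int8, jnp.int16, jnp.int32):
  print(jnp.dtype(dt).name, 32 * jnp.dtype(dt).itemsize)  # 32, 64, 128 bytes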