Change int4 packing from big-endian to little-endian

LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.

PiperOrigin-RevId: 731731530
This commit is contained in:
Adrian Kuegel 2025-02-27 08:12:57 -08:00 committed by jax authors
parent 5ae0e58a4a
commit de4d047852
2 changed files with 14 additions and 9 deletions

View File

@ -40,6 +40,7 @@ from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow import for_loop
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import math as math_dialect
@ -2053,17 +2054,22 @@ def _masked_load_lowering_rule(
if not is_int4:
return values
# XLA packs pairs of `[u]int4` values into a `uint8` value with the first
# in the most significant bits and the second in the least significant.
# Starting with jaxlib 0.5.2, XLA packs pairs of `[u]int4` values into a
# `uint8` value with the first in the least significant bits and the second
# in the most significant. Before jaxlib 0.5.2, the order was reversed.
if is_contiguous_int4:
even_values = arith_dialect.shrui(values, _full(values.type, 4))
values = tt_dialect.join(even_values, values)
msb_values = arith_dialect.shrui(values, _full(values.type, 4))
if jaxlib_version < (0, 5, 2):
values = tt_dialect.join(msb_values, values)
else:
values = tt_dialect.join(values, msb_values)
shape = ir.RankedTensorType(values.type).shape
values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1]))
else:
offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False)
in_lsb = _mod(offsets, _full(offsets.type, 2), signed=False)
in_msb = arith_dialect.xori(in_lsb, _full(in_lsb.type, 1))
in_msb = _mod(offsets, _full(offsets.type, 2), signed=False)
if jaxlib_version < (0, 5, 2):
in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1))
shift = _mul(in_msb, _full(in_msb.type, 4))
shift = _ir_cast(shift, values.type, signed=False)
values = arith_dialect.shrui(values, shift)

View File

@ -1081,8 +1081,7 @@ class FragmentedArray:
reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
# The algorithm here is largely the same as CUTLASS's
# NumericArrayConverter specialization for int4 -> bf16 casts.
# We modify it slightly, because we only extract 2 values, and we also
# flip them to account for XLA using big-endian packing into bytes.
# We modify it slightly, because we only extract 2 values.
# We first shift the value by 4 bits, to put the high int4 in low bits.
# The prmt then blends the two values together, by putting them into the
# low bits of each 16-bit subword of our register. Then, we use the lop3
@ -1100,7 +1099,7 @@ class FragmentedArray:
{
.reg .b32 s<4>;
shr.s32 s0, $1, 4;
prmt.b32 s1, $1, s0, 0xF0F4;
prmt.b32 s1, $1, s0, 0xF4F0;
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
mov.b32 s3, 0x43084308;
sub.bf16x2 $0, s2, s3;