mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Change int4 packing from big-endian to little-endian
LLVM uses little-endian format for int4 packing. To avoid converting between these formats, we should also use little-endian in XLA. PiperOrigin-RevId: 731731530
This commit is contained in:
parent
5ae0e58a4a
commit
de4d047852
@ -40,6 +40,7 @@ from jax._src import util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith as arith_dialect
|
||||
from jax._src.lib.mlir.dialects import math as math_dialect
|
||||
@ -2053,17 +2054,22 @@ def _masked_load_lowering_rule(
|
||||
if not is_int4:
|
||||
return values
|
||||
|
||||
# XLA packs pairs of `[u]int4` values into a `uint8` value with the first
|
||||
# in the most significant bits and the second in the least significant.
|
||||
# Starting with jaxlib 0.5.2, XLA packs pairs of `[u]int4` values into a `uint8`
|
||||
# value with the first in the least significant bits and the second in the
|
||||
# most significant. Before jaxlib 0.5.2, the order was reversed.
|
||||
if is_contiguous_int4:
|
||||
even_values = arith_dialect.shrui(values, _full(values.type, 4))
|
||||
values = tt_dialect.join(even_values, values)
|
||||
msb_values = arith_dialect.shrui(values, _full(values.type, 4))
|
||||
if jaxlib_version < (0, 5, 2):
|
||||
values = tt_dialect.join(msb_values, values)
|
||||
else:
|
||||
values = tt_dialect.join(values, msb_values)
|
||||
shape = ir.RankedTensorType(values.type).shape
|
||||
values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1]))
|
||||
else:
|
||||
offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False)
|
||||
in_lsb = _mod(offsets, _full(offsets.type, 2), signed=False)
|
||||
in_msb = arith_dialect.xori(in_lsb, _full(in_lsb.type, 1))
|
||||
in_msb = _mod(offsets, _full(offsets.type, 2), signed=False)
|
||||
if jaxlib_version < (0, 5, 2):
|
||||
in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1))
|
||||
shift = _mul(in_msb, _full(in_msb.type, 4))
|
||||
shift = _ir_cast(shift, values.type, signed=False)
|
||||
values = arith_dialect.shrui(values, shift)
|
||||
|
@ -1081,8 +1081,7 @@ class FragmentedArray:
|
||||
reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
|
||||
# The algorithm here is largely the same as CUTLASS's
|
||||
# NumericArrayConverter specialization for int4 -> bf16 casts.
|
||||
# We modify it slightly, because we only extract 2 values, and we also
|
||||
# flip them to account for XLA using big-endian packing into bytes.
|
||||
# We modify it slightly, because we only extract 2 values.
|
||||
# We first shift the value by 4 bits, to put the high int4 in low bits.
|
||||
# The prmt then blends the two values together, by putting them into the
|
||||
# low bits of each 16-bit subword of our register. Then, we use the lop3
|
||||
@ -1100,7 +1099,7 @@ class FragmentedArray:
|
||||
{
|
||||
.reg .b32 s<4>;
|
||||
shr.s32 s0, $1, 4;
|
||||
prmt.b32 s1, $1, s0, 0xF0F4;
|
||||
prmt.b32 s1, $1, s0, 0xF4F0;
|
||||
lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
|
||||
mov.b32 s3, 0x43084308;
|
||||
sub.bf16x2 $0, s2, s3;
|
||||
|
Loading…
x
Reference in New Issue
Block a user