diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 099eed25f..0a2f699b4 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -40,6 +40,7 @@ from jax._src import util
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax.control_flow import for_loop
+from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith as arith_dialect
 from jax._src.lib.mlir.dialects import math as math_dialect
@@ -2053,17 +2054,22 @@ def _masked_load_lowering_rule(
   if not is_int4:
     return values
 
-  # XLA packs pairs of `[u]int4` values into a `uint8` value with the first
-  # in the most significant bits and the second in the least significant.
+  # Starting with jaxlib 0.5.2, XLA packs pairs of `[u]int4` values into a
+  # `uint8` value with the first in the least significant bits and the second
+  # in the most significant. Before jaxlib 0.5.2, the order was reversed.
   if is_contiguous_int4:
-    even_values = arith_dialect.shrui(values, _full(values.type, 4))
-    values = tt_dialect.join(even_values, values)
+    msb_values = arith_dialect.shrui(values, _full(values.type, 4))
+    if jaxlib_version < (0, 5, 2):
+      values = tt_dialect.join(msb_values, values)
+    else:
+      values = tt_dialect.join(values, msb_values)
     shape = ir.RankedTensorType(values.type).shape
     values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1]))
   else:
     offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False)
-    in_lsb = _mod(offsets, _full(offsets.type, 2), signed=False)
-    in_msb = arith_dialect.xori(in_lsb, _full(in_lsb.type, 1))
+    in_msb = _mod(offsets, _full(offsets.type, 2), signed=False)
+    if jaxlib_version < (0, 5, 2):
+      in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1))
     shift = _mul(in_msb, _full(in_msb.type, 4))
     shift = _ir_cast(shift, values.type, signed=False)
     values = arith_dialect.shrui(values, shift)
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index 6bb8d6781..51300012a 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1081,8 +1081,7 @@ class FragmentedArray:
         reg_8 = vector.bitcast(ir.VectorType.get((1,), i8), reg)
         # The algorithm here is largely the same as CUTLASS's
         # NumericArrayConverter specialization for int4 -> bf16 casts.
-        # We modify it slightly, because we only extract 2 values, and we also
-        # flip them to account for XLA using big-endian packing into bytes.
+        # We modify it slightly, because we only extract 2 values.
         # We first shift the value by 4 bits, to put the high int4 in low bits.
         # The prmt then blends the two values together, by putting them into the
         # low bits of each 16-bit subword of our register. Then, we use the lop3
@@ -1100,7 +1099,7 @@
             {
             .reg .b32 s<4>;
             shr.s32 s0, $1, 4;
-            prmt.b32 s1, $1, s0, 0xF0F4;
+            prmt.b32 s1, $1, s0, 0xF4F0;
             lop3.b32 s2, s1, 0x000F000F, 0x43084308, (0xf0 & 0xcc) ^ 0xaa;
             mov.b32 s3, 0x43084308;
             sub.bf16x2 $0, s2, s3;
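
The packing-order change in the Pallas Triton lowering is easiest to see outside the IR builder. Below is a NumPy sketch of the contiguous load path under the new convention (jaxlib >= 0.5.2): `unpack_int4_le` is a hypothetical helper, not part of the patch, and mirrors the `shrui` + `join` + `_reshape` sequence above.

```python
import numpy as np

def unpack_int4_le(packed: np.ndarray) -> np.ndarray:
  """Unpacks little-endian-packed signed int4 from uint8 storage."""
  lsb = packed & 0x0F  # element 2k lives in the low nibble of byte k
  msb = packed >> 4    # element 2k+1 in the high nibble (the shrui above)
  # Interleave the pairs, mirroring tt_dialect.join(values, msb_values)
  # followed by the _reshape that flattens the trailing axis.
  out = np.stack([lsb, msb], axis=-1).reshape(*packed.shape[:-1], -1)
  out = out.astype(np.int8)
  return np.where(out >= 8, out - 16, out)  # sign-extend from 4 bits

packed = np.array([0x21, 0x43], dtype=np.uint8)  # packs 1, 2, 3, 4
assert unpack_int4_le(packed).tolist() == [1, 2, 3, 4]
```

Under the pre-0.5.2 big-endian order, the interleave would be `[msb, lsb]` instead, which is exactly what the `jaxlib_version < (0, 5, 2)` branch preserves.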
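The non-contiguous (gather) path makes the same choice per element: the parity of the int4 offset decides which nibble holds the value, and multiplying the parity bit by 4 turns it into a shift. A sketch under the same assumptions, again with a hypothetical helper name:

```python
import numpy as np

def gather_int4_le(packed: np.ndarray, offsets: np.ndarray) -> np.ndarray:
  """Gathers little-endian-packed signed int4 values at arbitrary offsets."""
  bytes_ = packed[offsets // 2]  # int4 element i lives in byte i // 2
  in_msb = offsets % 2           # odd offsets sit in the high nibble
  vals = (bytes_ >> (in_msb * 4)) & 0x0F  # the _mul + shrui sequence above
  vals = vals.astype(np.int8)
  return np.where(vals >= 8, vals - 16, vals)  # sign-extend from 4 bits

packed = np.array([0x21, 0x43], dtype=np.uint8)  # elements 1, 2, 3, 4
assert gather_int4_le(packed, np.array([2, 0, 3])).tolist() == [3, 1, 4]
```

The removed `xori` was the old-order correction: with big-endian packing, even offsets were the ones in the high nibble, so the parity bit had to be flipped. The patch keeps that flip only behind the `jaxlib_version < (0, 5, 2)` gate.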
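On the Mosaic GPU side, the same convention flip lands in the `prmt` selector: `0xF4F0` reads byte 0 of the unshifted operand into the low 16-bit half and byte 0 of the shifted operand into the high half, so the low nibble becomes element 0. The sketch below (hypothetical `int4x2_to_bf16x2`, not the lowering itself) emulates the patched PTX for one packed byte; the decode of the bf16 constants assumes the intermediate stays in the 0x4300..0x430F range, which the 0x000F000F mask guarantees.

```python
def int4x2_to_bf16x2(byte: int) -> tuple[float, float]:
  """Emulates the patched PTX: one packed byte -> two bf16 values."""
  signed = byte - 256 if byte & 0x80 else byte
  s0 = signed >> 4                    # shr.s32: arithmetic shift by 4
  halves = [byte & 0xFF, s0 & 0xFF]   # prmt 0xF4F0: low byte of each half
  out = []
  for h in halves:
    # lop3 with lut (0xf0 & 0xcc) ^ 0xaa computes (x & 0x000F) ^ 0x4308 per
    # 16-bit half: the `^ 8` rebiases two's-complement int4 to [0, 15], and
    # 0x43xx is a bf16 with exponent 134, i.e. bf16(0x4300 + k) == 128.0 + k.
    s2 = (h & 0x0F) ^ 0x4308
    out.append((128 + (s2 & 0x7F)) - 136.0)  # sub.bf16x2: subtract bf16 136
  return tuple(out)

assert int4x2_to_bf16x2(0x2F) == (-1.0, 2.0)  # low nibble -1, high nibble 2
```

The subtraction is exact: bf16 values 0x4300..0x437F are the consecutive integers 128..255, so `136 + v` minus `136` recovers `v` with no rounding.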