Reverts 83e60a9697ec20023f4e11169edf64e910b93031

PiperOrigin-RevId: 711403091
This commit is contained in:
Adam Paszke 2025-01-02 06:03:42 -08:00 committed by jax authors
parent 4a6cfebcea
commit dbe9ccd6dc
2 changed files with 9 additions and 49 deletions

View File

@ -1652,8 +1652,8 @@ def _reshape_lowering_rule(
)
def _compute_offsets_from_indices(
block_info: BlockInfo, nd_indexer: NDIndexer
def _compute_pointers_from_indices(
root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
) -> ir.Value:
full_shape = block_info.full_shape_dtype.shape
num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
@ -1732,14 +1732,7 @@ def _compute_offsets_from_indices(
dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride))
offsets = _add(offsets, dim_offsets)
return offsets
def _compute_pointers_from_indices(
root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
) -> ir.Value:
offsets = _compute_offsets_from_indices(block_info, nd_indexer)
return _add(_bcast_to(root_ptr, nd_indexer.get_indexer_shape()), offsets)
return _add(_bcast_to(root_ptr, indexer_shape), offsets)
@register_lowering(sp.get_p)
@ -1855,20 +1848,14 @@ def _masked_load_lowering_rule(
if not tt_dialect.PointerType.isinstance(ptr.type):
assert len(ctx.avals_in) == 1
return ptr
offsets = _compute_offsets_from_indices(block_info, idx)
ptr_offsets = offsets
if block_info.full_shape_dtype.dtype in (jnp.int4, jnp.uint4):
ptr_offsets = _floordiv(offsets, _full(offsets.type, 2), signed=False)
shape = idx.get_indexer_shape()
ptr = _add(_bcast_to(ptr, shape), ptr_offsets)
ptr = _compute_pointers_from_indices(ptr, block_info, idx)
if mask is not None:
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), shape)
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
if other is not None:
other = _bcast_to(_ensure_ir_value(other, other_aval), shape)
values = _load(
other = _bcast_to(
_ensure_ir_value(other, other_aval), idx.get_indexer_shape()
)
return _load(
ptr,
mask=mask,
other=other,
@ -1877,19 +1864,6 @@ def _masked_load_lowering_rule(
eviction_policy=eviction_policy,
)
if block_info.full_shape_dtype.dtype not in (jnp.int4, jnp.uint4):
return values
# XLA packs pairs of `[u]int4` values into a `uint8` value with the first
# in the most significant bits and the second in the least significant.
offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False)
in_lsb = _mod(offsets, _full(offsets.type, 2), signed=False)
in_msb = arith_dialect.xori(in_lsb, _full(in_lsb.type, 1))
shift = _mul(in_msb, _full(in_msb.type, 4))
shift = _ir_cast(shift, values.type, signed=False)
values = arith_dialect.shrui(values, shift)
return _ir_cast(values, ir.IntegerType.get_signless(4), signed=False)
@register_lowering(sp.swap_p)
def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):

View File

@ -725,20 +725,6 @@ class PallasCallTest(PallasBaseTest):
)
self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3)
@parameterized.parameters(jnp.int4, jnp.uint4)
def test_subbyte_load(self, dtype):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("`[u]int4` loads only supported on GPU.")
x = jnp.arange(-128, 128, dtype=jnp.int8)
@functools.partial(self.pallas_call, out_shape=x)
def copy_kernel(x_ref, o_ref):
o_ref[()] = x_ref[()].astype(jnp.int8)
expected = x.astype(dtype).astype(jnp.int8)
self.assertAllClose(copy_kernel(x.astype(dtype)), expected)
class PallasCallInterpretTest(PallasCallTest):
INTERPRET = True