mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Reverts 83e60a9697ec20023f4e11169edf64e910b93031
PiperOrigin-RevId: 711403091
This commit is contained in:
parent
4a6cfebcea
commit
dbe9ccd6dc
@ -1652,8 +1652,8 @@ def _reshape_lowering_rule(
|
||||
)
|
||||
|
||||
|
||||
def _compute_offsets_from_indices(
|
||||
block_info: BlockInfo, nd_indexer: NDIndexer
|
||||
def _compute_pointers_from_indices(
|
||||
root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
|
||||
) -> ir.Value:
|
||||
full_shape = block_info.full_shape_dtype.shape
|
||||
num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
|
||||
@ -1732,14 +1732,7 @@ def _compute_offsets_from_indices(
|
||||
dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride))
|
||||
offsets = _add(offsets, dim_offsets)
|
||||
|
||||
return offsets
|
||||
|
||||
|
||||
def _compute_pointers_from_indices(
|
||||
root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer
|
||||
) -> ir.Value:
|
||||
offsets = _compute_offsets_from_indices(block_info, nd_indexer)
|
||||
return _add(_bcast_to(root_ptr, nd_indexer.get_indexer_shape()), offsets)
|
||||
return _add(_bcast_to(root_ptr, indexer_shape), offsets)
|
||||
|
||||
|
||||
@register_lowering(sp.get_p)
|
||||
@ -1855,20 +1848,14 @@ def _masked_load_lowering_rule(
|
||||
if not tt_dialect.PointerType.isinstance(ptr.type):
|
||||
assert len(ctx.avals_in) == 1
|
||||
return ptr
|
||||
|
||||
offsets = _compute_offsets_from_indices(block_info, idx)
|
||||
ptr_offsets = offsets
|
||||
|
||||
if block_info.full_shape_dtype.dtype in (jnp.int4, jnp.uint4):
|
||||
ptr_offsets = _floordiv(offsets, _full(offsets.type, 2), signed=False)
|
||||
|
||||
shape = idx.get_indexer_shape()
|
||||
ptr = _add(_bcast_to(ptr, shape), ptr_offsets)
|
||||
ptr = _compute_pointers_from_indices(ptr, block_info, idx)
|
||||
if mask is not None:
|
||||
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), shape)
|
||||
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
|
||||
if other is not None:
|
||||
other = _bcast_to(_ensure_ir_value(other, other_aval), shape)
|
||||
values = _load(
|
||||
other = _bcast_to(
|
||||
_ensure_ir_value(other, other_aval), idx.get_indexer_shape()
|
||||
)
|
||||
return _load(
|
||||
ptr,
|
||||
mask=mask,
|
||||
other=other,
|
||||
@ -1877,19 +1864,6 @@ def _masked_load_lowering_rule(
|
||||
eviction_policy=eviction_policy,
|
||||
)
|
||||
|
||||
if block_info.full_shape_dtype.dtype not in (jnp.int4, jnp.uint4):
|
||||
return values
|
||||
|
||||
# XLA packs pairs of `[u]int4` values into a `uint8` value with the first
|
||||
# in the most significant bits and the second in the least significant.
|
||||
offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False)
|
||||
in_lsb = _mod(offsets, _full(offsets.type, 2), signed=False)
|
||||
in_msb = arith_dialect.xori(in_lsb, _full(in_lsb.type, 1))
|
||||
shift = _mul(in_msb, _full(in_msb.type, 4))
|
||||
shift = _ir_cast(shift, values.type, signed=False)
|
||||
values = arith_dialect.shrui(values, shift)
|
||||
return _ir_cast(values, ir.IntegerType.get_signless(4), signed=False)
|
||||
|
||||
|
||||
@register_lowering(sp.swap_p)
|
||||
def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
|
||||
|
@ -725,20 +725,6 @@ class PallasCallTest(PallasBaseTest):
|
||||
)
|
||||
self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3)
|
||||
|
||||
@parameterized.parameters(jnp.int4, jnp.uint4)
|
||||
def test_subbyte_load(self, dtype):
|
||||
if not jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("`[u]int4` loads only supported on GPU.")
|
||||
|
||||
x = jnp.arange(-128, 128, dtype=jnp.int8)
|
||||
|
||||
@functools.partial(self.pallas_call, out_shape=x)
|
||||
def copy_kernel(x_ref, o_ref):
|
||||
o_ref[()] = x_ref[()].astype(jnp.int8)
|
||||
|
||||
expected = x.astype(dtype).astype(jnp.int8)
|
||||
self.assertAllClose(copy_kernel(x.astype(dtype)), expected)
|
||||
|
||||
|
||||
class PallasCallInterpretTest(PallasCallTest):
|
||||
INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user