mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add population_count primitive to lax (#2753)
* add population_count primitive (needs new jaxlib) fixes #2263 * Add popcount docs * Add population_count to lax_reference * Use int prng (since we're only testing uints) Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
2d96cfb266
commit
75617be803
@ -38,6 +38,7 @@ Operators
|
||||
bitwise_and
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
population_count
|
||||
broadcast
|
||||
broadcasted_iota
|
||||
broadcast_in_dim
|
||||
|
@ -278,6 +278,10 @@ def bitwise_xor(x: Array, y: Array) -> Array:
|
||||
r"""Elementwise exclusive OR: :math:`x \oplus y`."""
|
||||
return xor_p.bind(x, y)
|
||||
|
||||
def population_count(x: Array) -> Array:
|
||||
r"""Elementwise popcount, count the number of set bits in each element."""
|
||||
return population_count_p.bind(x)
|
||||
|
||||
def add(x: Array, y: Array) -> Array:
|
||||
r"""Elementwise addition: :math:`x + y`."""
|
||||
return add_p.bind(x, y)
|
||||
@ -2023,6 +2027,8 @@ ad.defjvp_zero(or_p)
|
||||
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
|
||||
ad.defjvp_zero(xor_p)
|
||||
|
||||
population_count_p = standard_unop(_bool_or_int, 'population_count')
|
||||
|
||||
def _add_transpose(t, x, y):
|
||||
# The following linearity assertion is morally true, but because in some cases we
|
||||
# instantiate zeros for convenience, it doesn't always hold.
|
||||
|
@ -111,6 +111,31 @@ shift_left = onp.left_shift
|
||||
shift_right_arithmetic = onp.right_shift
|
||||
# TODO shift_right_logical
|
||||
|
||||
def population_count(x):
|
||||
assert x.dtype in (onp.uint32, onp.uint64)
|
||||
m = [
|
||||
0x5555555555555555, # binary: 0101...
|
||||
0x3333333333333333, # binary: 00110011..
|
||||
0x0f0f0f0f0f0f0f0f, # binary: 4 zeros, 4 ones ...
|
||||
0x00ff00ff00ff00ff, # binary: 8 zeros, 8 ones ...
|
||||
0x0000ffff0000ffff, # binary: 16 zeros, 16 ones ...
|
||||
0x00000000ffffffff, # binary: 32 zeros, 32 ones
|
||||
]
|
||||
|
||||
if x.dtype == onp.uint32:
|
||||
m = list(map(onp.uint32, m[:-1]))
|
||||
else:
|
||||
m = list(map(onp.uint64, m))
|
||||
|
||||
x = (x & m[0]) + ((x >> 1) & m[0]) # put count of each 2 bits into those 2 bits
|
||||
x = (x & m[1]) + ((x >> 2) & m[1]) # put count of each 4 bits into those 4 bits
|
||||
x = (x & m[2]) + ((x >> 4) & m[2]) # put count of each 8 bits into those 8 bits
|
||||
x = (x & m[3]) + ((x >> 8) & m[3]) # put count of each 16 bits into those 16 bits
|
||||
x = (x & m[4]) + ((x >> 16) & m[4]) # put count of each 32 bits into those 32 bits
|
||||
if x.dtype == onp.uint64:
|
||||
x = (x & m[5]) + ((x >> 32) & m[5]) # put count of each 64 bits into those 64 bits
|
||||
return x
|
||||
|
||||
eq = onp.equal
|
||||
ne = onp.not_equal
|
||||
ge = onp.greater_equal
|
||||
|
@ -152,6 +152,7 @@ LAX_OPS = [
|
||||
op_record("bitwise_not", 1, bool_dtypes, jtu.rand_small),
|
||||
op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small),
|
||||
op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small),
|
||||
op_record("population_count", 1, uint_dtypes, partial(jtu.rand_int, 1 << 32)),
|
||||
|
||||
op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small),
|
||||
op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),
|
||||
|
Loading…
x
Reference in New Issue
Block a user