Add population_count primitive to lax (#2753)

* add population_count primitive (needs new jaxlib)

fixes #2263

* Add popcount docs

* Add population_count to lax_reference

* Use int prng (since we're only testing uints)

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Jamie Townsend 2020-04-28 06:32:52 +01:00 committed by GitHub
parent 2d96cfb266
commit 75617be803
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 0 deletions

View File

@ -38,6 +38,7 @@ Operators
bitwise_and
bitwise_or
bitwise_xor
population_count
broadcast
broadcasted_iota
broadcast_in_dim

View File

@ -278,6 +278,10 @@ def bitwise_xor(x: Array, y: Array) -> Array:
r"""Elementwise exclusive OR: :math:`x \oplus y`."""
return xor_p.bind(x, y)
def population_count(x: Array) -> Array:
r"""Elementwise popcount, count the number of set bits in each element."""
return population_count_p.bind(x)
def add(x: Array, y: Array) -> Array:
r"""Elementwise addition: :math:`x + y`."""
return add_p.bind(x, y)
@ -2023,6 +2027,8 @@ ad.defjvp_zero(or_p)
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)
population_count_p = standard_unop(_bool_or_int, 'population_count')
def _add_transpose(t, x, y):
# The following linearity assertion is morally true, but because in some cases we
# instantiate zeros for convenience, it doesn't always hold.

View File

@ -111,6 +111,31 @@ shift_left = onp.left_shift
shift_right_arithmetic = onp.right_shift
# TODO shift_right_logical
def population_count(x):
assert x.dtype in (onp.uint32, onp.uint64)
m = [
0x5555555555555555, # binary: 0101...
0x3333333333333333, # binary: 00110011..
0x0f0f0f0f0f0f0f0f, # binary: 4 zeros, 4 ones ...
0x00ff00ff00ff00ff, # binary: 8 zeros, 8 ones ...
0x0000ffff0000ffff, # binary: 16 zeros, 16 ones ...
0x00000000ffffffff, # binary: 32 zeros, 32 ones
]
if x.dtype == onp.uint32:
m = list(map(onp.uint32, m[:-1]))
else:
m = list(map(onp.uint64, m))
x = (x & m[0]) + ((x >> 1) & m[0]) # put count of each 2 bits into those 2 bits
x = (x & m[1]) + ((x >> 2) & m[1]) # put count of each 4 bits into those 4 bits
x = (x & m[2]) + ((x >> 4) & m[2]) # put count of each 8 bits into those 8 bits
x = (x & m[3]) + ((x >> 8) & m[3]) # put count of each 16 bits into those 16 bits
x = (x & m[4]) + ((x >> 16) & m[4]) # put count of each 32 bits into those 32 bits
if x.dtype == onp.uint64:
x = (x & m[5]) + ((x >> 32) & m[5]) # put count of each 64 bits into those 64 bits
return x
eq = onp.equal
ne = onp.not_equal
ge = onp.greater_equal

View File

@ -152,6 +152,7 @@ LAX_OPS = [
op_record("bitwise_not", 1, bool_dtypes, jtu.rand_small),
op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small),
op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small),
op_record("population_count", 1, uint_dtypes, partial(jtu.rand_int, 1 << 32)),
op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small),
op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),