Add jax.numpy.bitwise_count()

This commit is contained in:
Jake VanderPlas 2023-10-03 13:48:16 -07:00
parent 3d848080f9
commit a09fdf6e2f
7 changed files with 49 additions and 1 deletions

View File

@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
# jax 0.4.17
* New features
* Added new {func}`jax.numpy.bitwise_count` function, matching the API of the simlar
function recently added to NumPy.
* Deprecations
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument

View File

@ -85,6 +85,7 @@ namespace; they are listed below.
bartlett
bincount
bitwise_and
bitwise_count
bitwise_not
bitwise_or
bitwise_xor

View File

@ -197,6 +197,29 @@ def arccosh(x: ArrayLike, /) -> Array:
out = _where(real(out) < 0, lax.neg(out), out)
return out
@_wraps(getattr(np, 'bitwise_count', None), module='numpy')
@jit
def bitwise_count(x: ArrayLike, /) -> Array:
# Ref: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
check_arraylike('bitwise_count', x)
a = lax.asarray(x)
if not dtypes.issubdtype(a.dtype, np.integer):
raise ValueError('bitwise_count is implemented only for integers inputs.')
nbits = np.iinfo(a.dtype).bits
if nbits not in [8, 16, 32, 64]:
raise ValueError("bitwise_count is implemented only for 8, 16, 32, and 64-bit")
if not dtypes.issubdtype(a.dtype, np.unsignedinteger):
# For signed integers, follow numpy's convention of taking the abs value.
a = abs(a).astype(f'uint{nbits}')
B0, B1, B2, B3 = np.array(
[0x5555555555555555, 0x3333333333333333, 0x0f0f0f0f0f0f0f0f, 0x0101010101010101],
dtype='uint64').astype(a.dtype)
a = a - ((a >> 1) & B0)
a = (a & B1) + ((a >> 2) & B1)
a = (a + (a >> 4)) & B2
count = a if nbits == 8 else (a * B3) >> (nbits - 8)
# Following numpy's convention, we always return uint8.
return count.astype('uint8')
@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)

View File

@ -334,6 +334,7 @@ from jax._src.numpy.ufuncs import (
arctan2 as arctan2,
arctanh as arctanh,
bitwise_and as bitwise_and,
bitwise_count as bitwise_count,
bitwise_not as bitwise_not,
bitwise_or as bitwise_or,
bitwise_xor as bitwise_xor,

View File

@ -133,6 +133,7 @@ bfloat16: Any
def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ...,
minlength: int = ..., *, length: Optional[int] = ...) -> Array: ...
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_count(x: ArrayLike, /) -> Array: ...
def bitwise_not(x: ArrayLike, /) -> Array: ...
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...

View File

@ -298,6 +298,10 @@ JAX_BITWISE_OP_RECORDS = [
op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
]
if hasattr(np, "bitwise_count"):
# Numpy versions after 1.26
JAX_BITWISE_OP_RECORDS.append(
op_record("bitwise_count", 1, int_dtypes, all_shapes, jtu.rand_fullrange, []))
JAX_OPERATOR_OVERLOADS = [
op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
@ -570,6 +574,21 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
shape=array_shapes,
dtype=int_dtypes,
)
def testBitwiseCount(self, shape, dtype):
# np.bitwise_count added after numpy 1.26, but
# np_scalar.bit_count() is available before that.
np_fun = getattr(
np, "bitwise_count",
np.vectorize(lambda x: np.ravel(x)[0].bit_count(), otypes=['uint8']))
rng = jtu.rand_fullrange(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp.bitwise_count, args_maker)
self._CompileAndCheck(jnp.bitwise_count, args_maker)
@jtu.sample_product(
[dict(dtypes=dtypes, shapes=shapes)
for shapes in filter(

View File

@ -5413,7 +5413,7 @@ _available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all
if dtype != dtypes.bfloat16]
# TODO(jakevdp): implement missing ufuncs
UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_count'}
UNIMPLEMENTED_UFUNCS = {'spacing'}
def _all_numpy_ufuncs() -> Iterator[str]: