diff --git a/CHANGELOG.md b/CHANGELOG.md index 86fa67d9a..4b3a7ef8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list # jax 0.4.17 +* New features + * Added new {func}`jax.numpy.bitwise_count` function, matching the API of the simlar + function recently added to NumPy. * Deprecations * Removed the deprecated module `jax.abstract_arrays` and all its contents. * Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index a595fd1f7..4194baff1 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -85,6 +85,7 @@ namespace; they are listed below. bartlett bincount bitwise_and + bitwise_count bitwise_not bitwise_or bitwise_xor diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index b0dcd8a0d..d0f27cdff 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -197,6 +197,29 @@ def arccosh(x: ArrayLike, /) -> Array: out = _where(real(out) < 0, lax.neg(out), out) return out +@_wraps(getattr(np, 'bitwise_count', None), module='numpy') +@jit +def bitwise_count(x: ArrayLike, /) -> Array: + # Ref: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + check_arraylike('bitwise_count', x) + a = lax.asarray(x) + if not dtypes.issubdtype(a.dtype, np.integer): + raise ValueError('bitwise_count is implemented only for integers inputs.') + nbits = np.iinfo(a.dtype).bits + if nbits not in [8, 16, 32, 64]: + raise ValueError("bitwise_count is implemented only for 8, 16, 32, and 64-bit") + if not dtypes.issubdtype(a.dtype, np.unsignedinteger): + # For signed integers, follow numpy's convention of taking the abs value. + a = abs(a).astype(f'uint{nbits}') + B0, B1, B2, B3 = np.array( + [0x5555555555555555, 0x3333333333333333, 0x0f0f0f0f0f0f0f0f, 0x0101010101010101], + dtype='uint64').astype(a.dtype) + a = a - ((a >> 1) & B0) + a = (a & B1) + ((a >> 2) & B1) + a = (a + (a >> 4)) & B2 + count = a if nbits == 8 else (a * B3) >> (nbits - 8) + # Following numpy's convention, we always return uint8. + return count.astype('uint8') @_wraps(np.right_shift, module='numpy') @partial(jit, inline=True) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 5f69f3d8c..63839b26b 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -334,6 +334,7 @@ from jax._src.numpy.ufuncs import ( arctan2 as arctan2, arctanh as arctanh, bitwise_and as bitwise_and, + bitwise_count as bitwise_count, bitwise_not as bitwise_not, bitwise_or as bitwise_or, bitwise_xor as bitwise_xor, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 1766d745e..8a8e36c94 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -133,6 +133,7 @@ bfloat16: Any def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ..., minlength: int = ..., *, length: Optional[int] = ...) -> Array: ... def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_not(x: ArrayLike, /) -> Array: ... def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 391d1f3fb..4ae817041 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -298,6 +298,10 @@ JAX_BITWISE_OP_RECORDS = [ op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_fullrange, []), ] +if hasattr(np, "bitwise_count"): + # Numpy versions after 1.26 + JAX_BITWISE_OP_RECORDS.append( + op_record("bitwise_count", 1, int_dtypes, all_shapes, jtu.rand_fullrange, [])) JAX_OPERATOR_OVERLOADS = [ op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), @@ -570,6 +574,21 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase): self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + shape=array_shapes, + dtype=int_dtypes, + ) + def testBitwiseCount(self, shape, dtype): + # np.bitwise_count added after numpy 1.26, but + # np_scalar.bit_count() is available before that. + np_fun = getattr( + np, "bitwise_count", + np.vectorize(lambda x: np.ravel(x)[0].bit_count(), otypes=['uint8'])) + rng = jtu.rand_fullrange(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp.bitwise_count, args_maker) + self._CompileAndCheck(jnp.bitwise_count, args_maker) + @jtu.sample_product( [dict(dtypes=dtypes, shapes=shapes) for shapes in filter( diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index bbf47dbae..a1d07e208 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5413,7 +5413,7 @@ _available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all if dtype != dtypes.bfloat16] # TODO(jakevdp): implement missing ufuncs -UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_count'} +UNIMPLEMENTED_UFUNCS = {'spacing'} def _all_numpy_ufuncs() -> Iterator[str]: