mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add jax.numpy.bitwise_count()
This commit is contained in:
parent
3d848080f9
commit
a09fdf6e2f
@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
# jax 0.4.17
|
||||
|
||||
* New features
|
||||
* Added new {func}`jax.numpy.bitwise_count` function, matching the API of the simlar
|
||||
function recently added to NumPy.
|
||||
* Deprecations
|
||||
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
|
||||
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
|
||||
|
@ -85,6 +85,7 @@ namespace; they are listed below.
|
||||
bartlett
|
||||
bincount
|
||||
bitwise_and
|
||||
bitwise_count
|
||||
bitwise_not
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
|
@ -197,6 +197,29 @@ def arccosh(x: ArrayLike, /) -> Array:
|
||||
out = _where(real(out) < 0, lax.neg(out), out)
|
||||
return out
|
||||
|
||||
@_wraps(getattr(np, 'bitwise_count', None), module='numpy')
|
||||
@jit
|
||||
def bitwise_count(x: ArrayLike, /) -> Array:
|
||||
# Ref: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
|
||||
check_arraylike('bitwise_count', x)
|
||||
a = lax.asarray(x)
|
||||
if not dtypes.issubdtype(a.dtype, np.integer):
|
||||
raise ValueError('bitwise_count is implemented only for integers inputs.')
|
||||
nbits = np.iinfo(a.dtype).bits
|
||||
if nbits not in [8, 16, 32, 64]:
|
||||
raise ValueError("bitwise_count is implemented only for 8, 16, 32, and 64-bit")
|
||||
if not dtypes.issubdtype(a.dtype, np.unsignedinteger):
|
||||
# For signed integers, follow numpy's convention of taking the abs value.
|
||||
a = abs(a).astype(f'uint{nbits}')
|
||||
B0, B1, B2, B3 = np.array(
|
||||
[0x5555555555555555, 0x3333333333333333, 0x0f0f0f0f0f0f0f0f, 0x0101010101010101],
|
||||
dtype='uint64').astype(a.dtype)
|
||||
a = a - ((a >> 1) & B0)
|
||||
a = (a & B1) + ((a >> 2) & B1)
|
||||
a = (a + (a >> 4)) & B2
|
||||
count = a if nbits == 8 else (a * B3) >> (nbits - 8)
|
||||
# Following numpy's convention, we always return uint8.
|
||||
return count.astype('uint8')
|
||||
|
||||
@_wraps(np.right_shift, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
|
@ -334,6 +334,7 @@ from jax._src.numpy.ufuncs import (
|
||||
arctan2 as arctan2,
|
||||
arctanh as arctanh,
|
||||
bitwise_and as bitwise_and,
|
||||
bitwise_count as bitwise_count,
|
||||
bitwise_not as bitwise_not,
|
||||
bitwise_or as bitwise_or,
|
||||
bitwise_xor as bitwise_xor,
|
||||
|
@ -133,6 +133,7 @@ bfloat16: Any
|
||||
def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ...,
|
||||
minlength: int = ..., *, length: Optional[int] = ...) -> Array: ...
|
||||
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def bitwise_count(x: ArrayLike, /) -> Array: ...
|
||||
def bitwise_not(x: ArrayLike, /) -> Array: ...
|
||||
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
|
@ -298,6 +298,10 @@ JAX_BITWISE_OP_RECORDS = [
|
||||
op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes,
|
||||
jtu.rand_fullrange, []),
|
||||
]
|
||||
if hasattr(np, "bitwise_count"):
|
||||
# Numpy versions after 1.26
|
||||
JAX_BITWISE_OP_RECORDS.append(
|
||||
op_record("bitwise_count", 1, int_dtypes, all_shapes, jtu.rand_fullrange, []))
|
||||
|
||||
JAX_OPERATOR_OVERLOADS = [
|
||||
op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []),
|
||||
@ -570,6 +574,21 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=array_shapes,
|
||||
dtype=int_dtypes,
|
||||
)
|
||||
def testBitwiseCount(self, shape, dtype):
|
||||
# np.bitwise_count added after numpy 1.26, but
|
||||
# np_scalar.bit_count() is available before that.
|
||||
np_fun = getattr(
|
||||
np, "bitwise_count",
|
||||
np.vectorize(lambda x: np.ravel(x)[0].bit_count(), otypes=['uint8']))
|
||||
rng = jtu.rand_fullrange(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp.bitwise_count, args_maker)
|
||||
self._CompileAndCheck(jnp.bitwise_count, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(dtypes=dtypes, shapes=shapes)
|
||||
for shapes in filter(
|
||||
|
@ -5413,7 +5413,7 @@ _available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all
|
||||
if dtype != dtypes.bfloat16]
|
||||
|
||||
# TODO(jakevdp): implement missing ufuncs
|
||||
UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_count'}
|
||||
UNIMPLEMENTED_UFUNCS = {'spacing'}
|
||||
|
||||
|
||||
def _all_numpy_ufuncs() -> Iterator[str]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user