Merge pull request #21745 from pkgoogle:better_right_shift_doc

PiperOrigin-RevId: 641972495
This commit is contained in:
jax authors 2024-06-10 11:45:38 -07:00
commit f6ce973860

View File

@ -253,9 +253,51 @@ def bitwise_count(x: ArrayLike, /) -> Array:
# Following numpy we take the absolute value and return uint8.
return lax.population_count(abs(x)).astype('uint8')
@implements(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
r"""Right shift the bits of ``x1`` to the amount specified in ``x2``.
LAX-backend implementation of :func:`numpy.right_shift`.
Args:
x1: Input array, only accepts unsigned integer subtypes
x2: The amount of bits to shift each element in ``x1`` to the right, only accepts
integer subtypes
Returns:
An array-like object containing the right shifted elements of ``x1`` by the
amount specified in ``x2``, with the same shape as the broadcasted shape of
``x1`` and ``x2``.
Note:
If ``x1.shape != x2.shape``, they must be compatible for broadcasting to a
shared shape, this shared shape will also be the shape of the output. Right shifting
a scalar x1 by scalar x2 is equivalent to ``x1 // 2**x2``.
Example:
>>> def print_binary(x):
... return [bin(int(val)) for val in x]
>>> x1 = jnp.array([1, 2, 4, 8])
>>> print_binary(x1)
['0b1', '0b10', '0b100', '0b1000']
>>> x2 = 1
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([0, 1, 2, 4], dtype=int32)
>>> print_binary(result)
['0b0', '0b1', '0b10', '0b100']
>>> x1 = 16
>>> print_binary([x1])
['0b10000']
>>> x2 = jnp.array([1, 2, 3, 4])
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([8, 4, 2, 1], dtype=int32)
>>> print_binary(result)
['0b1000', '0b100', '0b10', '0b1']
"""
x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2)
lax_fn = lax.shift_right_logical if \
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic