Merge pull request #26550 from jakevdp:lax-docs

PiperOrigin-RevId: 728251791
This commit is contained in:
jax authors 2025-02-18 09:56:53 -08:00
commit 63bc22d653

View File

@ -913,20 +913,106 @@ def cbrt(x: ArrayLike) -> Array:
"""
return cbrt_p.bind(x)
@export
def bitwise_not(x: ArrayLike) -> Array:
r"""Elementwise NOT: :math:`\neg x`."""
r"""Elementwise NOT: :math:`\neg x`.
This function lowers directly to the `stablehlo.not`_ operation.
Args:
x: Input array. Must have boolean or integer dtype.
Returns:
An array of the same shape and dtype as ``x`` containing the bitwise
inversion of each entry.
See also:
- :func:`jax.numpy.invert`: NumPy wrapper for this API, also accessible
via the ``~x`` operator on JAX arrays.
- :func:`jax.lax.bitwise_and`: Elementwise AND.
- :func:`jax.lax.bitwise_or`: Elementwise OR.
- :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.
.. _stablehlo.not: https://openxla.org/stablehlo/spec#not
"""
return not_p.bind(x)
@export
def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise AND: :math:`x \wedge y`."""
r"""Elementwise AND: :math:`x \wedge y`.
This function lowers directly to the `stablehlo.and`_ operation.
Args:
x, y: Input arrays. Must have matching boolean or integer dtypes.
If neither is a scalar, ``x`` and ``y`` must have the same number
of dimensions and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the bitwise
AND of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.bitwise_and`: NumPy wrapper for this API, also accessible
via the ``x & y`` operator on JAX arrays.
- :func:`jax.lax.bitwise_not`: Elementwise NOT.
- :func:`jax.lax.bitwise_or`: Elementwise OR.
- :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.
.. _stablehlo.and: https://openxla.org/stablehlo/spec#and
"""
return and_p.bind(x, y)
@export
def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise OR: :math:`x \vee y`."""
r"""Elementwise OR: :math:`x \vee y`.
This function lowers directly to the `stablehlo.or`_ operation.
Args:
x, y: Input arrays. Must have matching boolean or integer dtypes.
If neither is a scalar, ``x`` and ``y`` must have the same number
of dimensions and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the bitwise
OR of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.invert`: NumPy wrapper for this API, also accessible
via the ``x | y`` operator on JAX arrays.
- :func:`jax.lax.bitwise_not`: Elementwise NOT.
- :func:`jax.lax.bitwise_and`: Elementwise AND.
- :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.
.. _stablehlo.or: https://openxla.org/stablehlo/spec#or
"""
return or_p.bind(x, y)
@export
def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise exclusive OR: :math:`x \oplus y`."""
r"""Elementwise exclusive OR: :math:`x \oplus y`.
This function lowers directly to the `stablehlo.xor`_ operation.
Args:
x, y: Input arrays. Must have matching boolean or integer dtypes.
If neither is a scalar, ``x`` and ``y`` must have the same number
of dimensions and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the bitwise
XOR of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.bitwise_xor`: NumPy wrapper for this API, also accessible
via the ``x ^ y`` operator on JAX arrays.
- :func:`jax.lax.bitwise_not`: Elementwise NOT.
- :func:`jax.lax.bitwise_and`: Elementwise AND.
- :func:`jax.lax.bitwise_or`: Elementwise OR.
.. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor
"""
return xor_p.bind(x, y)
def population_count(x: ArrayLike) -> Array:
@ -985,16 +1071,81 @@ def min(x: ArrayLike, y: ArrayLike) -> Array:
`(real, imaginary)` pairs."""
return min_p.bind(x, y)
@export
def shift_left(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise left shift: :math:`x \ll y`."""
r"""Elementwise left shift: :math:`x \ll y`.
This function lowers directly to the `stablehlo.shift_left`_ operation.
Args:
x, y: Input arrays. Must have matching integer dtypes. If neither is a
scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the element-wise
left shift of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.left_shift`: NumPy wrapper for this API, also accessible
via the ``x << y`` operator on JAX arrays.
- :func:`jax.lax.shift_right_arithmetic`: Elementwise arithmetic right shift.
- :func:`jax.lax.shift_right_logical`: Elementwise logical right shift.
.. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left
"""
return shift_left_p.bind(x, y)
@export
def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise arithmetic right shift: :math:`x \gg y`."""
r"""Elementwise arithmetic right shift: :math:`x \gg y`.
This function lowers directly to the `stablehlo.shift_right_arithmetic`_ operation.
Args:
x, y: Input arrays. Must have matching integer dtypes. If neither is a
scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the element-wise
arithmetic right shift of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.right_shift`: NumPy wrapper for this API when applied to
signed integers, also accessible via the ``x >> y`` operator on JAX arrays
with signed integer dtype.
- :func:`jax.lax.shift_left`: Elementwise left shift.
- :func:`jax.lax.shift_right_logical`: Elementwise logical right shift.
.. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic
"""
return shift_right_arithmetic_p.bind(x, y)
@export
def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise logical right shift: :math:`x \gg y`."""
r"""Elementwise logical right shift: :math:`x \gg y`.
This function lowers directly to the `stablehlo.shift_right_logical`_ operation.
Args:
x, y: Input arrays. Must have matching integer dtypes. If neither is a
scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the element-wise
logical right shift of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.right_shift`: NumPy wrapper for this API when applied to
unsigned integers, also accessible via the ``x >> y`` operator on JAX arrays
with unsigned integer dtype.
- :func:`jax.lax.shift_left`: Elementwise left shift.
- :func:`jax.lax.shift_right_arithmetic`: Elementwise arithmetic right shift.
.. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical
"""
return shift_right_logical_p.bind(x, y)
def eq(x: ArrayLike, y: ArrayLike) -> Array: