Merge pull request #27198 from jakevdp:lax-docs

PiperOrigin-RevId: 738038116
This commit is contained in:
jax authors 2025-03-18 09:38:58 -07:00
commit 30941480a1

View File

@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
"""
return tanh_p.bind(x)
@export
def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
of HLO arithmetic operations.
Args:
x: input array. Must have floating point or complex dtype.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
logistic/sigmoid function.
See also:
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
"""
return logistic_p.bind(x)
@export
@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
"""
return xor_p.bind(x, y)
@export
def population_count(x: ArrayLike) -> Array:
r"""Elementwise popcount, count the number of set bits in each element."""
r"""Elementwise popcount, count the number of set bits in each element.
This function lowers directly to the `stablehlo.popcnt`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
set bits in the input.
See also:
- :func:`jax.lax.clz`: Elementwise count leading zeros.
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
"""
return population_count_p.bind(x)
@export
def clz(x: ArrayLike) -> Array:
r"""Elementwise count-leading-zeros."""
r"""Elementwise count-leading-zeros.
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
set bits in the input.
See also:
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
"""
return clz_p.bind(x)
@export
@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
"""
return div_p.bind(x, y)
@export
def rem(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise remainder: :math:`x \bmod y`.
The sign of the result is taken from the dividend,
and the absolute value of the result is always
less than the divisor's absolute value.
This function lowers directly to the `stablehlo.remainder`_ operation.
The sign of the result is taken from the dividend, and the absolute value
of the result is always less than the divisor's absolute value.
Integer division overflow
(remainder by zero or remainder of INT_SMIN with -1)
Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
produces an implementation defined value.
Args:
x, y: Input arrays. Must have matching int or float dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the remainder.
See also:
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
sign semantics.
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
"""
return rem_p.bind(x, y)
@export
def max(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
This function lowers directly to the `stablehlo.maximum`_ operation for
non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
maximum.
See also:
- :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum.
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
- :func:`jax.lax.min`: elementwise minimum.
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
"""
return max_p.bind(x, y)
@export
def min(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
This function lowers directly to the `stablehlo.minimum`_ operation for
non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
minimum.
See also:
- :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum.
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
- :func:`jax.lax.max`: elementwise maximum.
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
"""
return min_p.bind(x, y)
@export
@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
"""
return lt_p.bind(x, y)
@export
def convert_element_type(operand: ArrayLike,
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
"""Elementwise cast.
Wraps XLA's `ConvertElementType
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
operator, which performs an elementwise conversion from one type to another.
Similar to a C++ `static_cast`.
This function lowers directly to the `stablehlo.convert`_ operation, which
performs an elementwise conversion from one type to another, similar to a
C++ ``static_cast``.
Args:
operand: an array or scalar value to be cast.
new_dtype: a NumPy dtype representing the target type.
new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
or a valid dtype name) representing the target dtype.
Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
.. note::
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
the appropriate 32-bit type will be used in its place.
If the input is a JAX array and the input dtype and output dtype match, then
the input array will be returned unmodified.
See also:
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
- :meth:`jax.Array.astype`: dtype casting as an array method.
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
"""
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
@ -1500,12 +1615,11 @@ def _convert_element_type(
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
sharding=sharding)
@export
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise bitcast.
Wraps XLA's `BitcastConvertType
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
operator, which performs a bit cast from one type to another.
This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
The output shape depends on the size of the input and output dtypes with
the following logic::
@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
Returns:
An array of shape `output_shape` (see above) and type `new_dtype`,
constructed from the same bits as operand.
See also:
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype)
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)