jax.lax: improve docs for exp & log functions

This commit is contained in:
Jake VanderPlas 2025-02-04 09:33:52 -08:00
parent 09ee37a41d
commit f180353d78

View File

@ -374,7 +374,7 @@ def floor(x: ArrayLike) -> Array:
This function lowers directly to the `stablehlo.floor`_ operation.
Args:
x: input array. Must be have floating-point type.
x: input array. Must have floating-point type.
Returns:
Array of same shape and dtype as ``x``, containing values rounded
@ -395,7 +395,7 @@ def ceil(x: ArrayLike) -> Array:
This function lowers directly to the `stablehlo.ceil`_ operation.
Args:
x: input array. Must be have floating-point type.
x: input array. Must have floating-point type.
Returns:
Array of same shape and dtype as ``x``, containing values rounded
@ -461,29 +461,126 @@ def round(x: ArrayLike,
def is_finite(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{isfinite}`.
For each element x returns `True` if and only if x is not :math:`\pm\infty` or
:math:`\mathit{NaN}`.
This function lowers directly to the `stablehlo.is_finite`_ operation.
Args:
x: input array. Must have floating-point type.
Returns:
Array of boolean dtype with the same shape as ``x``, containing ``False`` where
``x`` is :math:`\pm\infty` or :math:`\mathit{NaN}`, and ``True`` otherwise.
See also:
- :func:`jax.numpy.isinf`: return True where array is infinite.
- :func:`jax.numpy.isnan`: return True where array is NaN.
.. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite
"""
return is_finite_p.bind(x)
def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`."""
r"""Elementwise exponential: :math:`e^x`.
This function lowers directly to the `stablehlo.exponential`_ operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential.
See also:
- :func:`jax.lax.exp2`: elementwise base-2 exponentional: :math:`2^x`.
- :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
"""
return exp_p.bind(x)
def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`."""
r"""Elementwise base-2 exponential: :math:`2^x`.
This function is implemented in terms of the `stablehlo.exponential`_
and `stablehlo.multiply`_ operations.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
base-2 exponential.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
- :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
"""
return exp2_p.bind(x)
def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`."""
r"""Elementwise :math:`e^{x} - 1`.
This function lowers directly to the `stablehlo.exponential_minus_one`_
operation. Compared to the naive expression ``lax.exp(x) - 1``, it is
more accurate for ``x`` near zero.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential minus 1.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
- :func:`jax.lax.log1p`: elementwise :math:`\mathrm{log}(1 + x)`.
.. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one
"""
return expm1_p.bind(x)
def log(x: ArrayLike) -> Array:
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`."""
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
This function lowers directly to the `stablehlo.log`_ operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
.. _stablehlo.log: https://openxla.org/stablehlo/spec#log
"""
return log_p.bind(x)
def log1p(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{log}(1 + x)`."""
r"""Elementwise :math:`\mathrm{log}(1 + x)`..
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
for ``x`` near zero.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm of ``x + 1``.
See also:
- :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
- :func:`jax.lax.log`: elementwise natural logarithm :math:`\mathrm{log}(x)`.
.. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one
"""
return log1p_p.bind(x)
def tanh(x: ArrayLike) -> Array: