doc: improve docs for jax.lax trig functions

This commit is contained in:
Jake VanderPlas 2025-02-06 11:09:42 -08:00
parent 7025b7d116
commit 2fb750e0ab
2 changed files with 147 additions and 14 deletions

View File

@ -90,6 +90,7 @@ Operators
erfc
erf_inv
exp
exp2
expand_dims
expm1
fft

View File

@ -458,6 +458,7 @@ def round(x: ArrayLike,
rounding_method = RoundingMethod(rounding_method)
return round_p.bind(x, rounding_method=rounding_method)
@export
def is_finite(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{isfinite}`.
@ -478,6 +479,7 @@ def is_finite(x: ArrayLike) -> Array:
"""
return is_finite_p.bind(x)
@export
def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`.
@ -488,7 +490,7 @@ def exp(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential.
exponential.
See also:
- :func:`jax.lax.exp2`: elementwise base-2 exponentional: :math:`2^x`.
@ -509,7 +511,7 @@ def exp2(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
base-2 exponential.
base-2 exponential.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
@ -520,6 +522,7 @@ def exp2(x: ArrayLike) -> Array:
"""
return exp2_p.bind(x)
@export
def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`.
@ -532,7 +535,7 @@ def expm1(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential minus 1.
exponential minus 1.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
@ -542,6 +545,7 @@ def expm1(x: ArrayLike) -> Array:
"""
return expm1_p.bind(x)
@export
def log(x: ArrayLike) -> Array:
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
@ -552,7 +556,7 @@ def log(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm.
natural logarithm.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
@ -561,8 +565,9 @@ def log(x: ArrayLike) -> Array:
"""
return log_p.bind(x)
@export
def log1p(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{log}(1 + x)`..
r"""Elementwise :math:`\mathrm{log}(1 + x)`.
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
@ -573,7 +578,7 @@ def log1p(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm of ``x + 1``.
natural logarithm of ``x + 1``.
See also:
- :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
@ -591,17 +596,76 @@ def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
return logistic_p.bind(x)
@export
def sin(x: ArrayLike) -> Array:
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
r"""Elementwise sine: :math:`\mathrm{sin}(x)`.
For floating-point inputs, this function lowers directly to the
`stablehlo.sine`_ operation. For complex inputs, it lowers to a
sequence of HLO operations implementing the complex sine.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
sine.
See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.asin`: elementwise arc sine.
.. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
"""
return sin_p.bind(x)
@export
def cos(x: ArrayLike) -> Array:
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.
For floating-point inputs, this function lowers directly to the
`stablehlo.cosine`_ operation. For complex inputs, it lowers to a
sequence of HLO operations implementing the complex cosine.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
cosine.
See also:
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.acos`: elementwise arc cosine.
.. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
"""
return cos_p.bind(x)
@export
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise arc tangent of two variables:
:math:`\mathrm{atan}({x \over y})`."""
r"""Elementwise two-term arc tangent: :math:`\mathrm{atan}({x \over y})`.
This function lowers directly to the `stablehlo.atan2`_ operation.
Args:
x, y: input arrays. Must have a matching floating-point or complex dtypes. If
neither is a scalar, the two arrays must have the same number of dimensions
and be broadcast-compatible.
Returns:
Array of the same shape and dtype as ``x`` and ``y`` containing the element-wise
arc tangent of :math:`x \over y`, respecting the quadrant indicated by the sign
of each input.
See also:
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.atan`: elementwise one-term arc tangent.
.. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
"""
return atan2_p.bind(x, y)
def real(x: ArrayLike) -> Array:
@ -2473,20 +2537,88 @@ def reciprocal(x: ArrayLike) -> Array:
r"""Elementwise reciprocal: :math:`1 \over x`."""
return integer_pow(x, -1)
@export
def tan(x: ArrayLike) -> Array:
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.
This function lowers directly to the `stablehlo.tangent`_ operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
tangent.
See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.atan`: elementwise arc tangent.
- :func:`jax.lax.atan2`: elementwise 2-term arc tangent.
.. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
"""
return tan_p.bind(x)
@export
def asin(x: ArrayLike) -> Array:
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.
This function lowers directly to the ``chlo.asin`` operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arc sine.
See also:
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.acos`: elementwise arc cosine.
- :func:`jax.lax.atan`: elementwise arc tangent.
"""
return asin_p.bind(x)
@export
def acos(x: ArrayLike) -> Array:
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.
This function lowers directly to the ``chlo.acos`` operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arc cosine.
See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.asin`: elementwise arc sine.
- :func:`jax.lax.atan`: elementwise arc tangent.
"""
return acos_p.bind(x)
@export
def atan(x: ArrayLike) -> Array:
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.
This function lowers directly to the ``chlo.atan`` operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arc tangent.
See also:
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.acos`: elementwise arc cosine.
- :func:`jax.lax.asin`: elementwise arc sine.
- :func:`jax.lax.atan2`: elementwise 2-term arc tangent.
"""
return atan_p.bind(x)
def sinh(x: ArrayLike) -> Array: