mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
doc: improve docs for jax.lax trig functions
This commit is contained in:
parent
7025b7d116
commit
2fb750e0ab
@ -90,6 +90,7 @@ Operators
|
||||
erfc
|
||||
erf_inv
|
||||
exp
|
||||
exp2
|
||||
expand_dims
|
||||
expm1
|
||||
fft
|
||||
|
@ -458,6 +458,7 @@ def round(x: ArrayLike,
|
||||
rounding_method = RoundingMethod(rounding_method)
|
||||
return round_p.bind(x, rounding_method=rounding_method)
|
||||
|
||||
@export
|
||||
def is_finite(x: ArrayLike) -> Array:
|
||||
r"""Elementwise :math:`\mathrm{isfinite}`.
|
||||
|
||||
@ -478,6 +479,7 @@ def is_finite(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return is_finite_p.bind(x)
|
||||
|
||||
@export
|
||||
def exp(x: ArrayLike) -> Array:
|
||||
r"""Elementwise exponential: :math:`e^x`.
|
||||
|
||||
@ -488,7 +490,7 @@ def exp(x: ArrayLike) -> Array:
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
exponential.
|
||||
exponential.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.exp2`: elementwise base-2 exponentional: :math:`2^x`.
|
||||
@ -509,7 +511,7 @@ def exp2(x: ArrayLike) -> Array:
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
base-2 exponential.
|
||||
base-2 exponential.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
|
||||
@ -520,6 +522,7 @@ def exp2(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return exp2_p.bind(x)
|
||||
|
||||
@export
|
||||
def expm1(x: ArrayLike) -> Array:
|
||||
r"""Elementwise :math:`e^{x} - 1`.
|
||||
|
||||
@ -532,7 +535,7 @@ def expm1(x: ArrayLike) -> Array:
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
exponential minus 1.
|
||||
exponential minus 1.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
|
||||
@ -542,6 +545,7 @@ def expm1(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return expm1_p.bind(x)
|
||||
|
||||
@export
|
||||
def log(x: ArrayLike) -> Array:
|
||||
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
|
||||
|
||||
@ -552,7 +556,7 @@ def log(x: ArrayLike) -> Array:
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
natural logarithm.
|
||||
natural logarithm.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
|
||||
@ -561,8 +565,9 @@ def log(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return log_p.bind(x)
|
||||
|
||||
@export
|
||||
def log1p(x: ArrayLike) -> Array:
|
||||
r"""Elementwise :math:`\mathrm{log}(1 + x)`..
|
||||
r"""Elementwise :math:`\mathrm{log}(1 + x)`.
|
||||
|
||||
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
|
||||
Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
|
||||
@ -573,7 +578,7 @@ def log1p(x: ArrayLike) -> Array:
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
natural logarithm of ``x + 1``.
|
||||
natural logarithm of ``x + 1``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
|
||||
@ -591,17 +596,76 @@ def logistic(x: ArrayLike) -> Array:
|
||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
|
||||
return logistic_p.bind(x)
|
||||
|
||||
@export
|
||||
def sin(x: ArrayLike) -> Array:
|
||||
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
|
||||
r"""Elementwise sine: :math:`\mathrm{sin}(x)`.
|
||||
|
||||
For floating-point inputs, this function lowers directly to the
|
||||
`stablehlo.sine`_ operation. For complex inputs, it lowers to a
|
||||
sequence of HLO operations implementing the complex sine.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
sine.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.cos`: elementwise cosine.
|
||||
- :func:`jax.lax.tan`: elementwise tangent.
|
||||
- :func:`jax.lax.asin`: elementwise arc sine.
|
||||
|
||||
.. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
|
||||
"""
|
||||
return sin_p.bind(x)
|
||||
|
||||
@export
|
||||
def cos(x: ArrayLike) -> Array:
|
||||
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
|
||||
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.
|
||||
|
||||
For floating-point inputs, this function lowers directly to the
|
||||
`stablehlo.cosine`_ operation. For complex inputs, it lowers to a
|
||||
sequence of HLO operations implementing the complex cosine.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
cosine.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.sin`: elementwise sine.
|
||||
- :func:`jax.lax.tan`: elementwise tangent.
|
||||
- :func:`jax.lax.acos`: elementwise arc cosine.
|
||||
|
||||
.. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
|
||||
"""
|
||||
return cos_p.bind(x)
|
||||
|
||||
@export
|
||||
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise arc tangent of two variables:
|
||||
:math:`\mathrm{atan}({x \over y})`."""
|
||||
r"""Elementwise two-term arc tangent: :math:`\mathrm{atan}({x \over y})`.
|
||||
|
||||
This function lowers directly to the `stablehlo.atan2`_ operation.
|
||||
|
||||
Args:
|
||||
x, y: input arrays. Must have a matching floating-point or complex dtypes. If
|
||||
neither is a scalar, the two arrays must have the same number of dimensions
|
||||
and be broadcast-compatible.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` and ``y`` containing the element-wise
|
||||
arc tangent of :math:`x \over y`, respecting the quadrant indicated by the sign
|
||||
of each input.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.tan`: elementwise tangent.
|
||||
- :func:`jax.lax.atan`: elementwise one-term arc tangent.
|
||||
|
||||
.. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
|
||||
"""
|
||||
return atan2_p.bind(x, y)
|
||||
|
||||
def real(x: ArrayLike) -> Array:
|
||||
@ -2473,20 +2537,88 @@ def reciprocal(x: ArrayLike) -> Array:
|
||||
r"""Elementwise reciprocal: :math:`1 \over x`."""
|
||||
return integer_pow(x, -1)
|
||||
|
||||
@export
|
||||
def tan(x: ArrayLike) -> Array:
|
||||
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
|
||||
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.
|
||||
|
||||
This function lowers directly to the `stablehlo.tangent`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
tangent.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.cos`: elementwise cosine.
|
||||
- :func:`jax.lax.sin`: elementwise sine.
|
||||
- :func:`jax.lax.atan`: elementwise arc tangent.
|
||||
- :func:`jax.lax.atan2`: elementwise 2-term arc tangent.
|
||||
|
||||
.. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
|
||||
"""
|
||||
return tan_p.bind(x)
|
||||
|
||||
@export
|
||||
def asin(x: ArrayLike) -> Array:
|
||||
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
|
||||
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.
|
||||
|
||||
This function lowers directly to the ``chlo.asin`` operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the
|
||||
element-wise arc sine.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.sin`: elementwise sine.
|
||||
- :func:`jax.lax.acos`: elementwise arc cosine.
|
||||
- :func:`jax.lax.atan`: elementwise arc tangent.
|
||||
"""
|
||||
return asin_p.bind(x)
|
||||
|
||||
@export
|
||||
def acos(x: ArrayLike) -> Array:
|
||||
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
|
||||
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.
|
||||
|
||||
This function lowers directly to the ``chlo.acos`` operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the
|
||||
element-wise arc cosine.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.cos`: elementwise cosine.
|
||||
- :func:`jax.lax.asin`: elementwise arc sine.
|
||||
- :func:`jax.lax.atan`: elementwise arc tangent.
|
||||
"""
|
||||
return acos_p.bind(x)
|
||||
|
||||
@export
|
||||
def atan(x: ArrayLike) -> Array:
|
||||
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
|
||||
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.
|
||||
|
||||
This function lowers directly to the ``chlo.atan`` operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating-point or complex type.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the
|
||||
element-wise arc tangent.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.tan`: elementwise tangent.
|
||||
- :func:`jax.lax.acos`: elementwise arc cosine.
|
||||
- :func:`jax.lax.asin`: elementwise arc sine.
|
||||
- :func:`jax.lax.atan2`: elementwise 2-term arc tangent.
|
||||
"""
|
||||
return atan_p.bind(x)
|
||||
|
||||
def sinh(x: ArrayLike) -> Array:
|
||||
|
Loading…
x
Reference in New Issue
Block a user