Merge pull request #26470 from jakevdp:lax-docs

PiperOrigin-RevId: 725804083
This commit is contained in:
jax authors 2025-02-11 15:58:52 -08:00
commit d0b6c677b0

View File

@ -685,34 +685,119 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array:
"""
return atan2_p.bind(x, y)
@export
def real(x: ArrayLike) -> Array:
r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.
Returns the real part of a complex number.
This function lowers directly to the `stablehlo.real`_ operation.
Args:
x: input array. Must have complex dtype.
Returns:
Array of the same shape as ``x`` containing its real part. Will have dtype
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.
See also:
- :func:`jax.lax.complex`: elementwise construct complex number.
- :func:`jax.lax.imag`: elementwise extract imaginary part.
- :func:`jax.lax.conj`: elementwise complex conjugate.
.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
"""
return real_p.bind(x)
@export
def imag(x: ArrayLike) -> Array:
r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.
Returns the imaginary part of a complex number.
This function lowers directly to the `stablehlo.imag`_ operation.
Args:
x: input array. Must have complex dtype.
Returns:
Array of the same shape as ``x`` containing its imaginary part. Will have dtype
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.
See also:
- :func:`jax.lax.complex`: elementwise construct complex number.
- :func:`jax.lax.real`: elementwise extract real part.
- :func:`jax.lax.conj`: elementwise complex conjugate.
.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
"""
return imag_p.bind(x)
@export
def complex(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise make complex number: :math:`x + jy`.
Builds a complex number from real and imaginary parts.
This function lowers directly to the `stablehlo.complex`_ operation.
Args:
x, y: input arrays. Must have matching floating-point dtypes. If
neither is a scalar, the two arrays must have the same number
of dimensions and be broadcast-compatible.
Returns:
The complex array with the real part given by ``x``, and the imaginary
part given by ``y``. For inputs of dtype float32 or float64, the result
will have dtype complex64 or complex128 respectively.
See also:
- :func:`jax.lax.real`: elementwise extract real part.
- :func:`jax.lax.imag`: elementwise extract imaginary part.
- :func:`jax.lax.conj`: elementwise complex conjugate.
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
"""
return complex_p.bind(x, y)
@export
def conj(x: ArrayLike) -> Array:
r"""Elementwise complex conjugate function: :math:`\overline{x}`."""
r"""Elementwise complex conjugate function: :math:`\overline{x}`.
This function lowers to a combination of `stablehlo.real`_, `stablehlo.imag`_,
and `stablehlo.complex`_.
Args:
x: input array. Must have complex dtype.
Returns:
Array of the same shape and dtype as ``x`` containing its complex conjugate.
See also:
- :func:`jax.lax.complex`: elementwise construct complex number.
- :func:`jax.lax.real`: elementwise extract real part.
- :func:`jax.lax.imag`: elementwise extract imaginary part.
- :func:`jax.lax.abs`: elementwise absolute value / complex magnitude.
.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
"""
# TODO(mattjj): remove input_dtype, not needed anymore
return conj_p.bind(x, input_dtype=_dtype(x))
@export
def abs(x: ArrayLike) -> Array:
r"""Elementwise absolute value: :math:`|x|`."""
r"""Elementwise absolute value: :math:`|x|`.
This function lowers directly to the `stablehlo.abs`_ operation.
Args:
x: Input array. Must have signed integer, floating, or complex dtype.
Returns:
An array of the same shape as ``x`` containing the elementwise absolute value.
For complex valued input, :math:`a + ib`, ``abs(x)`` returns :math:`\sqrt{a^2+b^2}`.
See also:
- :func:`jax.numpy.abs`: a more flexible NumPy-style ``abs`` implementation.
.. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs
"""
return abs_p.bind(x)
def pow(x: ArrayLike, y: ArrayLike) -> Array: