mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #26470 from jakevdp:lax-docs
PiperOrigin-RevId: 725804083
This commit is contained in:
commit
d0b6c677b0
@ -685,34 +685,119 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return atan2_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def real(x: ArrayLike) -> Array:
|
||||
r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.
|
||||
|
||||
Returns the real part of a complex number.
|
||||
This function lowers directly to the `stablehlo.real`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have complex dtype.
|
||||
|
||||
Returns:
|
||||
Array of the same shape as ``x`` containing its real part. Will have dtype
|
||||
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.complex`: elementwise construct complex number.
|
||||
- :func:`jax.lax.imag`: elementwise extract imaginary part.
|
||||
- :func:`jax.lax.conj`: elementwise complex conjugate.
|
||||
|
||||
.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
|
||||
"""
|
||||
return real_p.bind(x)
|
||||
|
||||
@export
|
||||
def imag(x: ArrayLike) -> Array:
|
||||
r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.
|
||||
|
||||
Returns the imaginary part of a complex number.
|
||||
This function lowers directly to the `stablehlo.imag`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must have complex dtype.
|
||||
|
||||
Returns:
|
||||
Array of the same shape as ``x`` containing its imaginary part. Will have dtype
|
||||
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.complex`: elementwise construct complex number.
|
||||
- :func:`jax.lax.real`: elementwise extract real part.
|
||||
- :func:`jax.lax.conj`: elementwise complex conjugate.
|
||||
|
||||
.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
|
||||
"""
|
||||
return imag_p.bind(x)
|
||||
|
||||
@export
|
||||
def complex(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise make complex number: :math:`x + jy`.
|
||||
|
||||
Builds a complex number from real and imaginary parts.
|
||||
This function lowers directly to the `stablehlo.complex`_ operation.
|
||||
|
||||
Args:
|
||||
x, y: input arrays. Must have matching floating-point dtypes. If
|
||||
neither is a scalar, the two arrays must have the same number
|
||||
of dimensions and be broadcast-compatible.
|
||||
|
||||
Returns:
|
||||
The complex array with the real part given by ``x``, and the imaginary
|
||||
part given by ``y``. For inputs of dtype float32 or float64, the result
|
||||
will have dtype complex64 or complex128 respectively.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.real`: elementwise extract real part.
|
||||
- :func:`jax.lax.imag`: elementwise extract imaginary part.
|
||||
- :func:`jax.lax.conj`: elementwise complex conjugate.
|
||||
|
||||
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
|
||||
"""
|
||||
return complex_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def conj(x: ArrayLike) -> Array:
|
||||
r"""Elementwise complex conjugate function: :math:`\overline{x}`."""
|
||||
r"""Elementwise complex conjugate function: :math:`\overline{x}`.
|
||||
|
||||
This function lowers to a combination of `stablehlo.real`_, `stablehlo.imag`_,
|
||||
and `stablehlo.complex`_.
|
||||
|
||||
Args:
|
||||
x: input array. Must have complex dtype.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing its complex conjugate.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.complex`: elementwise construct complex number.
|
||||
- :func:`jax.lax.real`: elementwise extract real part.
|
||||
- :func:`jax.lax.imag`: elementwise extract imaginary part.
|
||||
- :func:`jax.lax.abs`: elementwise absolute value / complex magnitude.
|
||||
|
||||
.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
|
||||
.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
|
||||
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
|
||||
"""
|
||||
# TODO(mattjj): remove input_dtype, not needed anymore
|
||||
return conj_p.bind(x, input_dtype=_dtype(x))
|
||||
|
||||
@export
|
||||
def abs(x: ArrayLike) -> Array:
|
||||
r"""Elementwise absolute value: :math:`|x|`."""
|
||||
r"""Elementwise absolute value: :math:`|x|`.
|
||||
|
||||
This function lowers directly to the `stablehlo.abs`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have signed integer, floating, or complex dtype.
|
||||
|
||||
Returns:
|
||||
An array of the same shape as ``x`` containing the elementwise absolute value.
|
||||
For complex valued input, :math:`a + ib`, ``abs(x)`` returns :math:`\sqrt{a^2+b^2}`.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.abs`: a more flexible NumPy-style ``abs`` implementation.
|
||||
|
||||
.. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs
|
||||
"""
|
||||
return abs_p.bind(x)
|
||||
|
||||
def pow(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
|
Loading…
x
Reference in New Issue
Block a user