mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improved docs for jnp.fft.rfftn and jnp.fft.irfftn
This commit is contained in:
parent
e57a7e3f05
commit
ff1f199d09
@ -252,17 +252,171 @@ def ifftn(a: ArrayLike, s: Shape | None = None,
|
||||
return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)
|
||||
|
||||
|
||||
@implements(np.fft.rfftn)
|
||||
def rfftn(a: ArrayLike, s: Shape | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
"""Compute a multidimensional discrete Fourier transform of a real-valued array.
|
||||
|
||||
JAX implementation of :func:`numpy.fft.rfftn`.
|
||||
|
||||
Args:
|
||||
a: real-valued input array.
|
||||
s: optional sequence of integers. Controls the effective size of the input
|
||||
along each specified axis. If not specified, it will default to the
|
||||
dimension of input along ``axes``.
|
||||
axes: optional sequence of integers, default=None. Specifies the axes along
|
||||
which the transform is computed. If not specified, the transform is computed
|
||||
along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified,
|
||||
the transform is computed along all the axes.
|
||||
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
||||
and "forward" are supported.
|
||||
|
||||
Returns:
|
||||
An array containing the multidimensional discrete Fourier transform of ``a``
|
||||
having size specified in ``s`` along the axes ``axes`` except along the axis
|
||||
``axes[-1]``. The size of the output along the axis ``axes[-1]`` is
|
||||
``s[-1]//2+1``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier
|
||||
transform of real-valued array.
|
||||
- :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier
|
||||
transform of real-valued array.
|
||||
- :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse
|
||||
discrete Fourier transform.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([[[1, 3, 5],
|
||||
... [2, 4, 6]],
|
||||
... [[7, 9, 11],
|
||||
... [8, 10, 12]]])
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jnp.fft.rfftn(x)
|
||||
Array([[[ 78.+0.j , -12.+6.93j],
|
||||
[ -6.+0.j , 0.+0.j ]],
|
||||
<BLANKLINE>
|
||||
[[-36.+0.j , 0.+0.j ],
|
||||
[ 0.+0.j , 0.+0.j ]]], dtype=complex64)
|
||||
|
||||
When ``s=[3, 3, 4]``, size of the transform along ``axes (-3, -2)`` will
|
||||
be (3, 3), and along ``axis -1`` will be ``4//2+1 = 3`` and size along
|
||||
other axes will be the same as that of input.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jnp.fft.rfftn(x, s=[3, 3, 4])
|
||||
Array([[[ 78. +0.j , -16. -26.j , 26. +0.j ],
|
||||
[ 15. -36.37j, -16.12 +1.93j, 5. -12.12j],
|
||||
[ 15. +36.37j, 8.12-11.93j, 5. +12.12j]],
|
||||
<BLANKLINE>
|
||||
[[ -7.5 -49.36j, -20.45 +9.43j, -2.5 -16.45j],
|
||||
[-25.5 -7.79j, -0.6 +11.96j, -8.5 -2.6j ],
|
||||
[ 19.5 -12.99j, -8.33 -6.5j , 6.5 -4.33j]],
|
||||
<BLANKLINE>
|
||||
[[ -7.5 +49.36j, 12.45 -4.43j, -2.5 +16.45j],
|
||||
[ 19.5 +12.99j, 0.33 -6.5j , 6.5 +4.33j],
|
||||
[-25.5 +7.79j, 4.6 +5.04j, -8.5 +2.6j ]]], dtype=complex64)
|
||||
|
||||
When ``s=[3, 5]`` and ``axes=(0, 1)``, size of the transform along ``axis 0``
|
||||
will be ``3``, along ``axis 1`` will be ``5//2+1 = 3`` and dimension along
|
||||
other axes will be same as that of input.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jnp.fft.rfftn(x, s=[3, 5], axes=[0, 1])
|
||||
Array([[[ 18. +0.j , 26. +0.j , 34. +0.j ],
|
||||
[ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j],
|
||||
[ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]],
|
||||
<BLANKLINE>
|
||||
[[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j],
|
||||
[ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j],
|
||||
[ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]],
|
||||
<BLANKLINE>
|
||||
[[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j],
|
||||
[ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j],
|
||||
[ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64)
|
||||
|
||||
For 1-D input:
|
||||
|
||||
>>> x1 = jnp.array([1, 2, 3, 4])
|
||||
>>> jnp.fft.rfftn(x1)
|
||||
Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64)
|
||||
"""
|
||||
return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)
|
||||
|
||||
|
||||
@implements(np.fft.irfftn)
|
||||
def irfftn(a: ArrayLike, s: Shape | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
"""Compute a real-valued multidimensional inverse discrete Fourier transform.
|
||||
|
||||
JAX implementation of :func:`numpy.fft.irfftn`.
|
||||
|
||||
Args:
|
||||
a: input array.
|
||||
s: optional sequence of integers. Specifies the size of the output in each
|
||||
specified axis. If not specified, the dimension of output along axis
|
||||
``axes[-1]`` is ``2*(m-1)``, ``m`` is the size of input along axis ``axes[-1]``
|
||||
and the dimension along other axes will be the same as that of input.
|
||||
axes: optional sequence of integers, default=None. Specifies the axes along
|
||||
which the transform is computed. If not specified, the transform is computed
|
||||
along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified,
|
||||
the transform is computed along all the axes.
|
||||
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
||||
and "forward" are supported.
|
||||
|
||||
Returns:
|
||||
A real-valued array containing the multidimensional inverse discrete Fourier
|
||||
transform of ``a`` with size ``s`` along specified ``axes``, and the same as
|
||||
the input along other axes.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
|
||||
transform of a real-valued array.
|
||||
- :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse
|
||||
discrete Fourier transform.
|
||||
- :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse
|
||||
discrete Fourier transform.
|
||||
|
||||
Examples:
|
||||
``jnp.fft.irfftn`` computes the transform along all the axes by default.
|
||||
|
||||
>>> x = jnp.array([[[1, 3, 5],
|
||||
... [2, 4, 6]],
|
||||
... [[7, 9, 11],
|
||||
... [8, 10, 12]]])
|
||||
>>> jnp.fft.irfftn(x)
|
||||
Array([[[ 6.5, -1. , 0. , -1. ],
|
||||
[-0.5, 0. , 0. , 0. ]],
|
||||
<BLANKLINE>
|
||||
[[-3. , 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. , 0. ]]], dtype=float32)
|
||||
|
||||
When ``s=[3, 4]``, size of the transform along ``axes (-2, -1)`` will be
|
||||
``(3, 4)`` and size along other axes will be the same as that of input.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jnp.fft.irfftn(x, s=[3, 4])
|
||||
Array([[[ 2.33, -0.67, 0. , -0.67],
|
||||
[ 0.33, -0.74, 0. , 0.41],
|
||||
[ 0.33, 0.41, 0. , -0.74]],
|
||||
<BLANKLINE>
|
||||
[[ 6.33, -0.67, 0. , -0.67],
|
||||
[ 1.33, -1.61, 0. , 1.28],
|
||||
[ 1.33, 1.28, 0. , -1.61]]], dtype=float32)
|
||||
|
||||
When ``s=[3]`` and ``axes=[0]``, size of the transform along ``axes 0`` will
|
||||
be ``3`` and dimension along other axes will be same as that of input.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... jnp.fft.irfftn(x, s=[3], axes=[0])
|
||||
Array([[[ 5., 7., 9.],
|
||||
[ 6., 8., 10.]],
|
||||
<BLANKLINE>
|
||||
[[-2., -2., -2.],
|
||||
[-2., -2., -2.]],
|
||||
<BLANKLINE>
|
||||
[[-2., -2., -2.],
|
||||
[-2., -2., -2.]]], dtype=float32)
|
||||
"""
|
||||
return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm)
|
||||
|
||||
|
||||
@ -937,7 +1091,7 @@ def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
||||
discrete Fourier transform.
|
||||
|
||||
Examples:
|
||||
``jnp.fft.ifft2`` computes the transform along the last two axes by default.
|
||||
``jnp.fft.irfft2`` computes the transform along the last two axes by default.
|
||||
|
||||
>>> x = jnp.array([[[1, 3, 5],
|
||||
... [2, 4, 6]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user