Improved docs for jnp.fft.rfft and irfft

This commit is contained in:
rajasekharporeddy 2024-07-23 11:33:03 +05:30
parent a18872aa13
commit 1650d1e8aa

View File

@ -400,15 +400,124 @@ def ifft(a: ArrayLike, n: int | None = None,
return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis,
norm=norm)
@implements(np.fft.rfft)
def rfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
r"""Compute a one-dimensional discrete Fourier transform of a real-valued array.
JAX implementation of :func:`numpy.fft.rfft`.
Args:
a: real-valued input array.
n: int. Specifies the effective dimension of the input along ``axis``. If not
specified, it will default to the dimension of input along ``axis``.
axis: int, default=-1. Specifies the axis along which the transform is computed.
If not specified, the transform is computed along axis -1.
norm: string. The normalization mode. "backward", "ortho" and "forward" are
supported.
Returns:
An array containing the one-dimensional discrete Fourier transform of ``a``.
The dimension of the array along ``axis`` is ``(n/2)+1``, if ``n`` is even and
``(n+1)/2``, if ``n`` is odd.
See also:
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
transform.
- :func:`jax.numpy.fft.irfft`: Computes a one-dimensional inverse discrete
Fourier transform for real input.
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
transform for real input.
- :func:`jax.numpy.fft.irfftn`: Computes a multidimensional inverse discrete
Fourier transform for real input.
Examples:
``jnp.fft.rfft`` computes the transform along ``axis -1`` by default.
>>> x = jnp.array([[1, 3, 5],
... [2, 4, 6]])
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.rfft(x)
Array([[ 9.+0.j , -3.+1.73j],
[12.+0.j , -3.+1.73j]], dtype=complex64)
When ``n=5``, dimension of the transform along axis -1 will be ``(5+1)/2 =3``
and dimension along other axes will be the same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.rfft(x, n=5)
Array([[ 9. +0.j , -2.12-5.79j, 0.12+2.99j],
[12. +0.j , -1.62-7.33j, 0.62+3.36j]], dtype=complex64)
When ``n=4`` and ``axis=0``, dimension of the transform along ``axis 0`` will
be ``(4/2)+1 =3`` and dimension along other axes will be same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.rfft(x, n=4, axis=0)
Array([[ 3.+0.j, 7.+0.j, 11.+0.j],
[ 1.-2.j, 3.-4.j, 5.-6.j],
[-1.+0.j, -1.+0.j, -1.+0.j]], dtype=complex64)
"""
return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis,
norm=norm)
@implements(np.fft.irfft)
def irfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
r"""Compute a one-dimensional inverse discrete Fourier transform for real input.
JAX implementation of :func:`numpy.fft.irfft`.
Args:
a: real-valued input array.
n: int. Specifies the dimension of the result along ``axis``. If not specified,
``n = 2*(m-1)``, where ``m`` is the dimension of ``a`` along ``axis``.
axis: int, default=-1. Specifies the axis along which the transform is computed.
If not specified, the transform is computed along axis -1.
norm: string. The normalization mode. "backward", "ortho" and "forward" are
supported.
Returns:
An array containing the one-dimensional inverse discrete Fourier transform
of ``a``, with a dimension of ``n`` along ``axis``.
See also:
- :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete
Fourier transform.
- :func:`jax.numpy.fft.irfft`: Computes a one-dimensional inverse discrete
Fourier transform for real input.
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
transform for real input.
- :func:`jax.numpy.fft.irfftn`: Computes a multidimensional inverse discrete
Fourier transform for real input.
Examples:
``jnp.fft.rfft`` computes the transform along ``axis -1`` by default.
>>> x = jnp.array([[1, 3, 5],
... [2, 4, 6]])
>>> jnp.fft.irfft(x)
Array([[ 3., -1., 0., -1.],
[ 4., -1., 0., -1.]], dtype=float32)
When ``n=3``, dimension of the transform along axis -1 will be ``3`` and
dimension along other axes will be the same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.irfft(x, n=3)
Array([[ 2.33, -0.67, -0.67],
[ 3.33, -0.67, -0.67]], dtype=float32)
When ``n=4`` and ``axis=0``, dimension of the transform along ``axis 0`` will
be ``4`` and dimension along other axes will be same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.irfft(x, n=4, axis=0)
Array([[ 1.25, 2.75, 4.25],
[ 0.25, 0.75, 1.25],
[-0.75, -1.25, -1.75],
[ 0.25, 0.75, 1.25]], dtype=float32)
"""
return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis,
norm=norm)