Merge pull request #22914 from rajasekharporeddy:testbranch1

PiperOrigin-RevId: 660421322
This commit is contained in:
jax authors 2024-08-07 09:33:43 -07:00
commit 930c8ca791

View File

@ -465,12 +465,12 @@ def rfft(a: ArrayLike, n: int | None = None,
def irfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
r"""Compute a one-dimensional inverse discrete Fourier transform for real input.
"""Compute a real-valued one-dimensional inverse discrete Fourier transform.
JAX implementation of :func:`numpy.fft.irfft`.
Args:
a: real-valued input array.
a: input array.
n: int. Specifies the dimension of the result along ``axis``. If not specified,
``n = 2*(m-1)``, where ``m`` is the dimension of ``a`` along ``axis``.
axis: int, default=-1. Specifies the axis along which the transform is computed.
@ -479,8 +479,8 @@ def irfft(a: ArrayLike, n: int | None = None,
supported.
Returns:
An array containing the one-dimensional inverse discrete Fourier transform
of ``a``, with a dimension of ``n`` along ``axis``.
A real-valued array containing the one-dimensional inverse discrete Fourier
transform of ``a``, with a dimension of ``n`` along ``axis``.
See also:
- :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete
@ -826,15 +826,157 @@ def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
norm=norm)
@implements(np.fft.rfft2)
def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
"""Compute a two-dimensional discrete Fourier transform of a real-valued array.
JAX implementation of :func:`numpy.fft.rfft2`.
Args:
a: real-valued input array. Must have ``a.ndim >= 2``.
s: optional length-2 sequence of integers. Specifies the effective size of the
output along each specified axis. If not specified, it will default to the
dimension of input along ``axes``.
axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the
axes along which the transform is computed.
norm: string, default="backward". The normalization mode. "backward", "ortho"
and "forward" are supported.
Returns:
An array containing the two-dimensional discrete Fourier transform of ``a``.
The size of the output along the axis ``axes[1]`` is ``(s[1]/2)+1``, if ``s[1]``
is even and ``(s[1]+1)/2``, if ``s[1]`` is odd. The size of the output along
the axis ``axes[0]`` is ``s[0]``.
See also:
- :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier
transform of real-valued array.
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
transform of real-valued array.
- :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse
discrete Fourier transform.
Examples:
``jnp.fft.rfft2`` computes the transform along the last two axes by default.
>>> x = jnp.array([[[1, 3, 5],
... [2, 4, 6]],
... [[7, 9, 11],
... [8, 10, 12]]])
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.rfft2(x)
Array([[[21.+0.j , -6.+3.46j],
[-3.+0.j , 0.+0.j ]],
<BLANKLINE>
[[57.+0.j , -6.+3.46j],
[-3.+0.j , 0.+0.j ]]], dtype=complex64)
When ``s=[2, 4]``, dimension of the transform along ``axis -2`` will be
``2``, along ``axis -1`` will be ``(4/2)+1) = 3`` and dimension along other
axes will be the same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.rfft2(x, s=[2, 4])
Array([[[21. +0.j, -8. -7.j, 7. +0.j],
[-3. +0.j, 0. +1.j, -1. +0.j]],
<BLANKLINE>
[[57. +0.j, -8.-19.j, 19. +0.j],
[-3. +0.j, 0. +1.j, -1. +0.j]]], dtype=complex64)
When ``s=[3, 5]`` and ``axes=(0, 1)``, shape of the transform along ``axis 0``
will be ``3``, along ``axis 1`` will be ``(5+1)/2 = 3`` and dimension along
other axes will be same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.rfft2(x, s=[3, 5], axes=(0, 1))
Array([[[ 18. +0.j , 26. +0.j , 34. +0.j ],
[ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j],
[ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]],
<BLANKLINE>
[[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j],
[ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j],
[ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]],
<BLANKLINE>
[[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j],
[ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j],
[ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64)
"""
return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
norm=norm)
@implements(np.fft.irfft2)
def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
"""Compute a real-valued two-dimensional inverse discrete Fourier transform.
JAX implementation of :func:`numpy.fft.irfft2`.
Args:
a: input array. Must have ``a.ndim >= 2``.
s: optional length-2 sequence of integers. Specifies the size of the output
in each specified axis. If not specified, the dimension of output along
axis ``axes[1]`` is ``2*(m-1)``, ``m`` is the size of input along axis
``axes[1]`` and the dimension along other axes will be the same as that of
input.
axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the
axes along which the transform is computed.
norm: string, default="backward". The normalization mode. "backward", "ortho"
and "forward" are supported.
Returns:
A real-valued array containing the two-dimensional inverse discrete Fourier
transform of ``a``.
See also:
- :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier
transform of a real-valued array.
- :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse
discrete Fourier transform.
- :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse
discrete Fourier transform.
Examples:
``jnp.fft.ifft2`` computes the transform along the last two axes by default.
>>> x = jnp.array([[[1, 3, 5],
... [2, 4, 6]],
... [[7, 9, 11],
... [8, 10, 12]]])
>>> jnp.fft.irfft2(x)
Array([[[ 3.5, -1. , 0. , -1. ],
[-0.5, 0. , 0. , 0. ]],
<BLANKLINE>
[[ 9.5, -1. , 0. , -1. ],
[-0.5, 0. , 0. , 0. ]]], dtype=float32)
When ``s=[3, 3]``, dimension of the transform along ``axes (-2, -1)`` will be
``(3, 3)`` and dimension along other axes will be the same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.irfft2(x, s=[3, 3])
Array([[[ 1.89, -0.44, -0.44],
[ 0.22, -0.78, 0.56],
[ 0.22, 0.56, -0.78]],
<BLANKLINE>
[[ 5.89, -0.44, -0.44],
[ 1.22, -1.78, 1.56],
[ 1.22, 1.56, -1.78]]], dtype=float32)
When ``s=[2, 3]`` and ``axes=(0, 1)``, shape of the transform along
``axes (0, 1)`` will be ``(2, 3)`` and dimension along other axes will be
same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... jnp.fft.irfft2(x, s=[2, 3], axes=(0, 1))
Array([[[ 4.67, 6.67, 8.67],
[-0.33, -0.33, -0.33],
[-0.33, -0.33, -0.33]],
<BLANKLINE>
[[-3. , -3. , -3. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]]], dtype=float32)
"""
return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
norm=norm)