2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2020-10-16 18:08:20 -04:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-07-21 14:20:39 -04:00
|
|
|
from collections.abc import Sequence
|
2021-02-15 20:43:55 +01:00
|
|
|
import operator
|
2020-10-16 18:08:20 -04:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from jax import lax
|
2025-02-06 14:48:10 -08:00
|
|
|
from jax._src import dtypes
|
2021-09-23 06:33:25 -07:00
|
|
|
from jax._src.lib import xla_client
|
2021-02-15 20:43:55 +01:00
|
|
|
from jax._src.util import safe_zip
|
2025-01-23 08:48:13 -08:00
|
|
|
from jax._src.numpy.util import ensure_arraylike, promote_dtypes_inexact
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax._src.numpy import lax_numpy as jnp
|
2023-03-08 10:29:04 -08:00
|
|
|
from jax._src.numpy import ufuncs, reductions
|
2024-07-29 12:40:28 -07:00
|
|
|
from jax._src.sharding import Sharding
|
|
|
|
from jax._src.typing import Array, ArrayLike, DTypeLike
|
2020-10-16 18:08:20 -04:00
|
|
|
|
2022-10-04 15:52:52 -07:00
|
|
|
Shape = Sequence[int]
|
2020-10-16 18:08:20 -04:00
|
|
|
|
2022-10-04 15:52:52 -07:00
|
|
|
def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
|
2022-07-07 11:53:58 -07:00
|
|
|
if norm == "backward":
|
2022-10-04 15:52:52 -07:00
|
|
|
return jnp.array(1)
|
2024-02-05 18:01:48 -05:00
|
|
|
|
|
|
|
# Avoid potential integer overflow
|
|
|
|
s, = promote_dtypes_inexact(s)
|
|
|
|
|
|
|
|
if norm == "ortho":
|
2023-03-08 10:29:04 -08:00
|
|
|
return ufuncs.sqrt(reductions.prod(s)) if func_name.startswith('i') else 1/ufuncs.sqrt(reductions.prod(s))
|
2021-10-20 12:46:50 +01:00
|
|
|
elif norm == "forward":
|
2023-03-08 10:29:04 -08:00
|
|
|
return reductions.prod(s) if func_name.startswith('i') else 1/reductions.prod(s)
|
2021-10-20 12:46:50 +01:00
|
|
|
raise ValueError(f'Invalid norm value {norm}; should be "backward",'
|
|
|
|
'"ortho" or "forward".')
|
|
|
|
|
|
|
|
|
2024-10-10 08:06:50 -07:00
|
|
|
def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
2023-12-11 13:59:29 +00:00
|
|
|
s: Shape | None, axes: Sequence[int] | None,
|
|
|
|
norm: str | None) -> Array:
|
2023-12-28 19:38:13 +03:00
|
|
|
full_name = f"jax.numpy.fft.{func_name}"
|
2025-01-23 08:48:13 -08:00
|
|
|
arr = ensure_arraylike(full_name, a)
|
2022-10-04 15:52:52 -07:00
|
|
|
|
2020-10-16 18:08:20 -04:00
|
|
|
if s is not None:
|
2021-02-15 20:43:55 +01:00
|
|
|
s = tuple(map(operator.index, s))
|
|
|
|
if np.any(np.less(s, 0)):
|
|
|
|
raise ValueError("Shape should be non-negative.")
|
2021-10-20 12:46:50 +01:00
|
|
|
|
2020-10-16 18:08:20 -04:00
|
|
|
if s is not None and axes is not None and len(s) != len(axes):
|
|
|
|
# Same error as numpy.
|
|
|
|
raise ValueError("Shape and axes have different lengths.")
|
|
|
|
|
|
|
|
orig_axes = axes
|
|
|
|
if axes is None:
|
|
|
|
if s is None:
|
2022-10-04 15:52:52 -07:00
|
|
|
axes = range(arr.ndim)
|
2020-10-16 18:08:20 -04:00
|
|
|
else:
|
2022-10-04 15:52:52 -07:00
|
|
|
axes = range(arr.ndim - len(s), arr.ndim)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
if len(axes) != len(set(axes)):
|
|
|
|
raise ValueError(
|
2022-05-12 19:13:00 +01:00
|
|
|
f"{full_name} does not support repeated axes. Got axes {axes}.")
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
# XLA only supports FFTs over the innermost axes, so rearrange if necessary.
|
|
|
|
if orig_axes is not None:
|
2022-10-04 15:52:52 -07:00
|
|
|
axes = tuple(range(arr.ndim - len(axes), arr.ndim))
|
|
|
|
arr = jnp.moveaxis(arr, orig_axes, axes)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
2021-02-15 20:43:55 +01:00
|
|
|
if s is not None:
|
2022-10-04 15:52:52 -07:00
|
|
|
in_s = list(arr.shape)
|
2021-02-15 20:43:55 +01:00
|
|
|
for axis, x in safe_zip(axes, s):
|
|
|
|
in_s[axis] = x
|
2024-10-10 08:06:50 -07:00
|
|
|
if fft_type == lax.FftType.IRFFT:
|
2021-02-15 20:43:55 +01:00
|
|
|
in_s[-1] = (in_s[-1] // 2 + 1)
|
|
|
|
# Cropping
|
2022-10-04 15:52:52 -07:00
|
|
|
arr = arr[tuple(map(slice, in_s))]
|
2021-02-15 20:43:55 +01:00
|
|
|
# Padding
|
2022-10-04 15:52:52 -07:00
|
|
|
arr = jnp.pad(arr, [(0, x-y) for x, y in zip(in_s, arr.shape)])
|
2021-02-15 20:43:55 +01:00
|
|
|
else:
|
2024-10-10 08:06:50 -07:00
|
|
|
if fft_type == lax.FftType.IRFFT:
|
2022-10-04 15:52:52 -07:00
|
|
|
s = [arr.shape[axis] for axis in axes[:-1]]
|
2020-10-16 18:08:20 -04:00
|
|
|
if axes:
|
2022-10-04 15:52:52 -07:00
|
|
|
s += [max(0, 2 * (arr.shape[axes[-1]] - 1))]
|
2020-10-16 18:08:20 -04:00
|
|
|
else:
|
2022-10-04 15:52:52 -07:00
|
|
|
s = [arr.shape[axis] for axis in axes]
|
2024-12-19 12:02:36 +00:00
|
|
|
transformed = _fft_core_nd(arr, fft_type, s)
|
2022-07-07 11:53:58 -07:00
|
|
|
if norm is not None:
|
|
|
|
transformed *= _fft_norm(
|
|
|
|
jnp.array(s, dtype=transformed.dtype), func_name, norm)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
if orig_axes is not None:
|
|
|
|
transformed = jnp.moveaxis(transformed, axes, orig_axes)
|
|
|
|
return transformed
|
|
|
|
|
|
|
|
|
2024-12-19 12:02:36 +00:00
|
|
|
def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array:
|
|
|
|
# XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly.
|
|
|
|
if len(s) <= 3:
|
|
|
|
return lax.fft(arr, fft_type, tuple(s))
|
|
|
|
|
|
|
|
# For larger N, we repeatedly apply N<=3 transforms until we reach the
|
|
|
|
# requested dimension. We special case N=4 to use two 2-D transforms instead
|
|
|
|
# of one 3-D and one 1-D, since we typically expect better accelerator
|
|
|
|
# performance when N>1.
|
|
|
|
n = 2 if len(s) == 4 else 3
|
|
|
|
src = tuple(range(arr.ndim - len(s), arr.ndim - n))
|
|
|
|
dst = tuple(range(arr.ndim - len(s) + n, arr.ndim))
|
|
|
|
if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}:
|
|
|
|
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
|
|
|
|
arr = jnp.moveaxis(arr, src, dst)
|
|
|
|
arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n])
|
|
|
|
arr = jnp.moveaxis(arr, dst, src)
|
|
|
|
else:
|
|
|
|
arr = jnp.moveaxis(arr, src, dst)
|
|
|
|
arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n])
|
|
|
|
arr = jnp.moveaxis(arr, dst, src)
|
|
|
|
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
|
|
|
|
return arr
|
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def fftn(a: ArrayLike, s: Shape | None = None,
|
|
|
|
axes: Sequence[int] | None = None,
|
|
|
|
norm: str | None = None) -> Array:
|
2024-07-17 07:28:52 +05:30
|
|
|
r"""Compute a multidimensional discrete Fourier transform along given axes.
|
2024-07-09 12:18:23 +05:30
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.fftn`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array
|
|
|
|
s: sequence of integers. Specifies the shape of the result. If not specified,
|
|
|
|
it will default to the shape of ``a`` along the specified ``axes``.
|
|
|
|
axes: sequence of integers, default=None. Specifies the axes along which the
|
|
|
|
transform is computed.
|
|
|
|
norm: string. The normalization mode. "backward", "ortho" and "forward" are
|
|
|
|
supported.
|
|
|
|
|
|
|
|
Returns:
|
2024-07-17 07:28:52 +05:30
|
|
|
An array containing the multidimensional discrete Fourier transform of ``a``.
|
2024-07-09 12:18:23 +05:30
|
|
|
|
|
|
|
See also:
|
2024-07-17 07:28:52 +05:30
|
|
|
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete
|
|
|
|
Fourier transform.
|
|
|
|
- :func:`jax.numpy.fft.ifftn`: Computes a multidimensional inverse discrete
|
|
|
|
Fourier transform.
|
2024-07-09 12:18:23 +05:30
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.fftn`` computes the transform along all the axes by default when
|
|
|
|
``axes`` argument is ``None``.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[1, 2, 5, 6],
|
|
|
|
... [4, 1, 3, 7],
|
|
|
|
... [5, 9, 2, 1]])
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.fftn(x)
|
|
|
|
Array([[ 46. +0.j , 0. +2.j , -6. +0.j , 0. -2.j ],
|
|
|
|
[ -2. +1.73j, 6.12+6.73j, 0. -1.73j, -18.12-3.27j],
|
|
|
|
[ -2. -1.73j, -18.12+3.27j, 0. +1.73j, 6.12-6.73j]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[2]``, dimension of the transform along ``axis -1`` will be ``2``
|
|
|
|
and dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jax.numpy.fft.fftn(x, s=[2]))
|
|
|
|
[[ 3.+0.j -1.+0.j]
|
|
|
|
[ 5.+0.j 3.+0.j]
|
|
|
|
[14.+0.j -4.+0.j]]
|
|
|
|
|
|
|
|
When ``s=[2]`` and ``axes=[0]``, dimension of the transform along ``axis 0``
|
|
|
|
will be ``2`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jax.numpy.fft.fftn(x, s=[2], axes=[0]))
|
|
|
|
[[ 5.+0.j 3.+0.j 8.+0.j 13.+0.j]
|
|
|
|
[-3.+0.j 1.+0.j 2.+0.j -1.+0.j]]
|
|
|
|
|
|
|
|
When ``s=[2, 3]``, shape of the transform will be ``(2, 3)``.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jax.numpy.fft.fftn(x, s=[2, 3]))
|
|
|
|
[[16. +0.j -0.5+4.33j -0.5-4.33j]
|
|
|
|
[ 0. +0.j -4.5+0.87j -4.5-0.87j]]
|
|
|
|
|
|
|
|
``jnp.fft.ifftn`` can be used to reconstruct ``x`` from the result of
|
|
|
|
``jnp.fft.fftn``.
|
|
|
|
|
|
|
|
>>> x_fftn = jnp.fft.fftn(x)
|
|
|
|
>>> jnp.allclose(x, jnp.fft.ifftn(x_fftn))
|
|
|
|
Array(True, dtype=bool)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core('fftn', lax.FftType.FFT, a, s, axes, norm)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def ifftn(a: ArrayLike, s: Shape | None = None,
|
|
|
|
axes: Sequence[int] | None = None,
|
|
|
|
norm: str | None = None) -> Array:
|
2024-07-17 07:28:52 +05:30
|
|
|
r"""Compute a multidimensional inverse discrete Fourier transform.
|
2024-07-10 20:45:35 +05:30
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.ifftn`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array
|
|
|
|
s: sequence of integers. Specifies the shape of the result. If not specified,
|
|
|
|
it will default to the shape of ``a`` along the specified ``axes``.
|
|
|
|
axes: sequence of integers, default=None. Specifies the axes along which the
|
|
|
|
transform is computed. If None, computes the transform along all the axes.
|
|
|
|
norm: string. The normalization mode. "backward", "ortho" and "forward" are
|
|
|
|
supported.
|
|
|
|
|
|
|
|
Returns:
|
2024-07-17 07:28:52 +05:30
|
|
|
An array containing the multidimensional inverse discrete Fourier transform
|
|
|
|
of ``a``.
|
2024-07-10 20:45:35 +05:30
|
|
|
|
|
|
|
See also:
|
2024-07-17 07:28:52 +05:30
|
|
|
- :func:`jax.numpy.fft.fftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete
|
|
|
|
Fourier transform.
|
2024-07-10 20:45:35 +05:30
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.ifftn`` computes the transform along all the axes by default when
|
|
|
|
``axes`` argument is ``None``.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[1, 2, 5, 3],
|
|
|
|
... [4, 1, 2, 6],
|
|
|
|
... [5, 3, 2, 1]])
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.ifftn(x))
|
|
|
|
[[ 2.92+0.j 0.08-0.33j 0.25+0.j 0.08+0.33j]
|
|
|
|
[-0.08+0.14j -0.04-0.03j 0. -0.29j -1.05-0.11j]
|
|
|
|
[-0.08-0.14j -1.05+0.11j 0. +0.29j -0.04+0.03j]]
|
|
|
|
|
|
|
|
When ``s=[3]``, dimension of the transform along ``axis -1`` will be ``3``
|
|
|
|
and dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.ifftn(x, s=[3]))
|
|
|
|
[[ 2.67+0.j -0.83-0.87j -0.83+0.87j]
|
|
|
|
[ 2.33+0.j 0.83-0.29j 0.83+0.29j]
|
|
|
|
[ 3.33+0.j 0.83+0.29j 0.83-0.29j]]
|
|
|
|
|
|
|
|
When ``s=[2]`` and ``axes=[0]``, dimension of the transform along ``axis 0``
|
|
|
|
will be ``2`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.ifftn(x, s=[2], axes=[0]))
|
|
|
|
[[ 2.5+0.j 1.5+0.j 3.5+0.j 4.5+0.j]
|
|
|
|
[-1.5+0.j 0.5+0.j 1.5+0.j -1.5+0.j]]
|
|
|
|
|
|
|
|
When ``s=[2, 3]``, shape of the transform will be ``(2, 3)``.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.ifftn(x, s=[2, 3]))
|
|
|
|
[[ 2.5 +0.j 0. -0.58j 0. +0.58j]
|
|
|
|
[ 0.17+0.j -0.83-0.29j -0.83+0.29j]]
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core('ifftn', lax.FftType.IFFT, a, s, axes, norm)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def rfftn(a: ArrayLike, s: Shape | None = None,
|
|
|
|
axes: Sequence[int] | None = None,
|
|
|
|
norm: str | None = None) -> Array:
|
2024-08-09 23:07:17 +05:30
|
|
|
"""Compute a multidimensional discrete Fourier transform of a real-valued array.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.rfftn`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: real-valued input array.
|
|
|
|
s: optional sequence of integers. Controls the effective size of the input
|
|
|
|
along each specified axis. If not specified, it will default to the
|
|
|
|
dimension of input along ``axes``.
|
|
|
|
axes: optional sequence of integers, default=None. Specifies the axes along
|
|
|
|
which the transform is computed. If not specified, the transform is computed
|
|
|
|
along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified,
|
|
|
|
the transform is computed along all the axes.
|
|
|
|
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
|
|
|
and "forward" are supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing the multidimensional discrete Fourier transform of ``a``
|
|
|
|
having size specified in ``s`` along the axes ``axes`` except along the axis
|
|
|
|
``axes[-1]``. The size of the output along the axis ``axes[-1]`` is
|
|
|
|
``s[-1]//2+1``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform of real-valued array.
|
|
|
|
- :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier
|
|
|
|
transform of real-valued array.
|
|
|
|
- :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse
|
|
|
|
discrete Fourier transform.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
>>> x = jnp.array([[[1, 3, 5],
|
|
|
|
... [2, 4, 6]],
|
|
|
|
... [[7, 9, 11],
|
|
|
|
... [8, 10, 12]]])
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfftn(x)
|
|
|
|
Array([[[ 78.+0.j , -12.+6.93j],
|
|
|
|
[ -6.+0.j , 0.+0.j ]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[-36.+0.j , 0.+0.j ],
|
|
|
|
[ 0.+0.j , 0.+0.j ]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[3, 3, 4]``, size of the transform along ``axes (-3, -2)`` will
|
|
|
|
be (3, 3), and along ``axis -1`` will be ``4//2+1 = 3`` and size along
|
|
|
|
other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfftn(x, s=[3, 3, 4])
|
|
|
|
Array([[[ 78. +0.j , -16. -26.j , 26. +0.j ],
|
|
|
|
[ 15. -36.37j, -16.12 +1.93j, 5. -12.12j],
|
|
|
|
[ 15. +36.37j, 8.12-11.93j, 5. +12.12j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ -7.5 -49.36j, -20.45 +9.43j, -2.5 -16.45j],
|
|
|
|
[-25.5 -7.79j, -0.6 +11.96j, -8.5 -2.6j ],
|
|
|
|
[ 19.5 -12.99j, -8.33 -6.5j , 6.5 -4.33j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ -7.5 +49.36j, 12.45 -4.43j, -2.5 +16.45j],
|
|
|
|
[ 19.5 +12.99j, 0.33 -6.5j , 6.5 +4.33j],
|
|
|
|
[-25.5 +7.79j, 4.6 +5.04j, -8.5 +2.6j ]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[3, 5]`` and ``axes=(0, 1)``, size of the transform along ``axis 0``
|
|
|
|
will be ``3``, along ``axis 1`` will be ``5//2+1 = 3`` and dimension along
|
|
|
|
other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfftn(x, s=[3, 5], axes=[0, 1])
|
|
|
|
Array([[[ 18. +0.j , 26. +0.j , 34. +0.j ],
|
|
|
|
[ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j],
|
|
|
|
[ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j],
|
|
|
|
[ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j],
|
|
|
|
[ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j],
|
|
|
|
[ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j],
|
|
|
|
[ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64)
|
|
|
|
|
|
|
|
For 1-D input:
|
|
|
|
|
|
|
|
>>> x1 = jnp.array([1, 2, 3, 4])
|
|
|
|
>>> jnp.fft.rfftn(x1)
|
|
|
|
Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core('rfftn', lax.FftType.RFFT, a, s, axes, norm)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def irfftn(a: ArrayLike, s: Shape | None = None,
|
|
|
|
axes: Sequence[int] | None = None,
|
|
|
|
norm: str | None = None) -> Array:
|
2024-08-09 23:07:17 +05:30
|
|
|
"""Compute a real-valued multidimensional inverse discrete Fourier transform.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.irfftn`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array.
|
|
|
|
s: optional sequence of integers. Specifies the size of the output in each
|
|
|
|
specified axis. If not specified, the dimension of output along axis
|
|
|
|
``axes[-1]`` is ``2*(m-1)``, ``m`` is the size of input along axis ``axes[-1]``
|
|
|
|
and the dimension along other axes will be the same as that of input.
|
|
|
|
axes: optional sequence of integers, default=None. Specifies the axes along
|
|
|
|
which the transform is computed. If not specified, the transform is computed
|
|
|
|
along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified,
|
|
|
|
the transform is computed along all the axes.
|
|
|
|
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
|
|
|
and "forward" are supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A real-valued array containing the multidimensional inverse discrete Fourier
|
|
|
|
transform of ``a`` with size ``s`` along specified ``axes``, and the same as
|
|
|
|
the input along other axes.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform of a real-valued array.
|
|
|
|
- :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse
|
|
|
|
discrete Fourier transform.
|
|
|
|
- :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse
|
|
|
|
discrete Fourier transform.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.irfftn`` computes the transform along all the axes by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[[1, 3, 5],
|
|
|
|
... [2, 4, 6]],
|
|
|
|
... [[7, 9, 11],
|
|
|
|
... [8, 10, 12]]])
|
|
|
|
>>> jnp.fft.irfftn(x)
|
|
|
|
Array([[[ 6.5, -1. , 0. , -1. ],
|
|
|
|
[-0.5, 0. , 0. , 0. ]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[-3. , 0. , 0. , 0. ],
|
|
|
|
[ 0. , 0. , 0. , 0. ]]], dtype=float32)
|
|
|
|
|
|
|
|
When ``s=[3, 4]``, size of the transform along ``axes (-2, -1)`` will be
|
|
|
|
``(3, 4)`` and size along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.irfftn(x, s=[3, 4])
|
|
|
|
Array([[[ 2.33, -0.67, 0. , -0.67],
|
|
|
|
[ 0.33, -0.74, 0. , 0.41],
|
|
|
|
[ 0.33, 0.41, 0. , -0.74]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ 6.33, -0.67, 0. , -0.67],
|
|
|
|
[ 1.33, -1.61, 0. , 1.28],
|
|
|
|
[ 1.33, 1.28, 0. , -1.61]]], dtype=float32)
|
|
|
|
|
|
|
|
When ``s=[3]`` and ``axes=[0]``, size of the transform along ``axes 0`` will
|
|
|
|
be ``3`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.irfftn(x, s=[3], axes=[0])
|
|
|
|
Array([[[ 5., 7., 9.],
|
|
|
|
[ 6., 8., 10.]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[-2., -2., -2.],
|
|
|
|
[-2., -2., -2.]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[-2., -2., -2.],
|
|
|
|
[-2., -2., -2.]]], dtype=float32)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core('irfftn', lax.FftType.IRFFT, a, s, axes, norm)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def _axis_check_1d(func_name: str, axis: int | None):
|
2023-12-28 19:38:13 +03:00
|
|
|
full_name = f"jax.numpy.fft.{func_name}"
|
2020-10-16 18:08:20 -04:00
|
|
|
if isinstance(axis, (list, tuple)):
|
|
|
|
raise ValueError(
|
|
|
|
"%s does not support multiple axes. Please use %sn. "
|
|
|
|
"Got axis = %r." % (full_name, full_name, axis)
|
|
|
|
)
|
|
|
|
|
2024-10-10 08:06:50 -07:00
|
|
|
def _fft_core_1d(func_name: str, fft_type: lax.FftType,
|
2023-12-11 13:59:29 +00:00
|
|
|
a: ArrayLike, n: int | None, axis: int | None,
|
|
|
|
norm: str | None) -> Array:
|
2020-10-16 18:08:20 -04:00
|
|
|
_axis_check_1d(func_name, axis)
|
|
|
|
axes = None if axis is None else [axis]
|
2021-02-15 20:43:55 +01:00
|
|
|
s = None if n is None else [n]
|
2020-10-16 18:08:20 -04:00
|
|
|
return _fft_core(func_name, fft_type, a, s, axes, norm)
|
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def fft(a: ArrayLike, n: int | None = None,
|
|
|
|
axis: int = -1, norm: str | None = None) -> Array:
|
2024-07-17 07:28:52 +05:30
|
|
|
r"""Compute a one-dimensional discrete Fourier transform along a given axis.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.fft`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array
|
|
|
|
n: int. Specifies the dimension of the result along ``axis``. If not specified,
|
|
|
|
it will default to the dimension of ``a`` along ``axis``.
|
|
|
|
axis: int, default=-1. Specifies the axis along which the transform is computed.
|
|
|
|
If not specified, the transform is computed along axis -1.
|
|
|
|
norm: string. The normalization mode. "backward", "ortho" and "forward" are
|
|
|
|
supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing the one-dimensional discrete Fourier transform of ``a``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete
|
|
|
|
Fourier transform.
|
|
|
|
- :func:`jax.numpy.fft.fftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.ifftn`: Computes a multidimensional inverse discrete
|
|
|
|
Fourier transform.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.fft`` computes the transform along ``axis -1`` by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[1, 2, 4, 7],
|
|
|
|
... [5, 3, 1, 9]])
|
|
|
|
>>> jnp.fft.fft(x)
|
|
|
|
Array([[14.+0.j, -3.+5.j, -4.+0.j, -3.-5.j],
|
|
|
|
[18.+0.j, 4.+6.j, -6.+0.j, 4.-6.j]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``n=3``, dimension of the transform along axis -1 will be ``3`` and
|
|
|
|
dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.fft(x, n=3))
|
|
|
|
[[ 7.+0.j -2.+1.73j -2.-1.73j]
|
|
|
|
[ 9.+0.j 3.-1.73j 3.+1.73j]]
|
|
|
|
|
|
|
|
When ``n=3`` and ``axis=0``, dimension of the transform along ``axis 0`` will
|
|
|
|
be ``3`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.fft(x, n=3, axis=0))
|
|
|
|
[[ 6. +0.j 5. +0.j 5. +0.j 16. +0.j ]
|
|
|
|
[-1.5-4.33j 0.5-2.6j 3.5-0.87j 2.5-7.79j]
|
|
|
|
[-1.5+4.33j 0.5+2.6j 3.5+0.87j 2.5+7.79j]]
|
|
|
|
|
|
|
|
``jnp.fft.ifft`` can be used to reconstruct ``x`` from the result of
|
|
|
|
``jnp.fft.fft``.
|
|
|
|
|
|
|
|
>>> x_fft = jnp.fft.fft(x)
|
|
|
|
>>> jnp.allclose(x, jnp.fft.ifft(x_fft))
|
|
|
|
Array(True, dtype=bool)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_1d('fft', lax.FftType.FFT, a, n=n, axis=axis,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
2024-07-17 07:28:52 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def ifft(a: ArrayLike, n: int | None = None,
|
|
|
|
axis: int = -1, norm: str | None = None) -> Array:
|
2024-07-17 07:28:52 +05:30
|
|
|
r"""Compute a one-dimensional inverse discrete Fourier transform.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.ifft`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array
|
|
|
|
n: int. Specifies the dimension of the result along ``axis``. If not specified,
|
|
|
|
it will default to the dimension of ``a`` along ``axis``.
|
|
|
|
axis: int, default=-1. Specifies the axis along which the transform is computed.
|
|
|
|
If not specified, the transform is computed along axis -1.
|
|
|
|
norm: string. The normalization mode. "backward", "ortho" and "forward" are
|
|
|
|
supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing the one-dimensional discrete Fourier transform of ``a``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.fftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.ifftn`: Computes a multidimensional inverse of discrete
|
|
|
|
Fourier transform.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.ifft`` computes the transform along ``axis -1`` by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[3, 1, 4, 6],
|
|
|
|
... [2, 5, 7, 1]])
|
|
|
|
>>> jnp.fft.ifft(x)
|
|
|
|
Array([[ 3.5 +0.j , -0.25-1.25j, 0. +0.j , -0.25+1.25j],
|
|
|
|
[ 3.75+0.j , -1.25+1.j , 0.75+0.j , -1.25-1.j ]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``n=5``, dimension of the transform along axis -1 will be ``5`` and
|
|
|
|
dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.ifft(x, n=5))
|
|
|
|
[[ 2.8 +0.j -0.96-0.04j 1.06+0.5j 1.06-0.5j -0.96+0.04j]
|
|
|
|
[ 3. +0.j -0.59+1.66j 0.09-0.55j 0.09+0.55j -0.59-1.66j]]
|
|
|
|
|
|
|
|
When ``n=3`` and ``axis=0``, dimension of the transform along ``axis 0`` will
|
|
|
|
be ``3`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.ifft(x, n=3, axis=0))
|
|
|
|
[[ 1.67+0.j 2. +0.j 3.67+0.j 2.33+0.j ]
|
|
|
|
[ 0.67+0.58j -0.5 +1.44j 0.17+2.02j 1.83+0.29j]
|
|
|
|
[ 0.67-0.58j -0.5 -1.44j 0.17-2.02j 1.83-0.29j]]
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_1d('ifft', lax.FftType.IFFT, a, n=n, axis=axis,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
2024-07-23 11:33:03 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def rfft(a: ArrayLike, n: int | None = None,
|
|
|
|
axis: int = -1, norm: str | None = None) -> Array:
|
2024-07-23 11:33:03 +05:30
|
|
|
r"""Compute a one-dimensional discrete Fourier transform of a real-valued array.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.rfft`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: real-valued input array.
|
|
|
|
n: int. Specifies the effective dimension of the input along ``axis``. If not
|
|
|
|
specified, it will default to the dimension of input along ``axis``.
|
|
|
|
axis: int, default=-1. Specifies the axis along which the transform is computed.
|
|
|
|
If not specified, the transform is computed along axis -1.
|
|
|
|
norm: string. The normalization mode. "backward", "ortho" and "forward" are
|
|
|
|
supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing the one-dimensional discrete Fourier transform of ``a``.
|
|
|
|
The dimension of the array along ``axis`` is ``(n/2)+1``, if ``n`` is even and
|
|
|
|
``(n+1)/2``, if ``n`` is odd.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.irfft`: Computes a one-dimensional inverse discrete
|
|
|
|
Fourier transform for real input.
|
|
|
|
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform for real input.
|
|
|
|
- :func:`jax.numpy.fft.irfftn`: Computes a multidimensional inverse discrete
|
|
|
|
Fourier transform for real input.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.rfft`` computes the transform along ``axis -1`` by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[1, 3, 5],
|
|
|
|
... [2, 4, 6]])
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfft(x)
|
|
|
|
Array([[ 9.+0.j , -3.+1.73j],
|
|
|
|
[12.+0.j , -3.+1.73j]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``n=5``, dimension of the transform along axis -1 will be ``(5+1)/2 =3``
|
|
|
|
and dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfft(x, n=5)
|
|
|
|
Array([[ 9. +0.j , -2.12-5.79j, 0.12+2.99j],
|
|
|
|
[12. +0.j , -1.62-7.33j, 0.62+3.36j]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``n=4`` and ``axis=0``, dimension of the transform along ``axis 0`` will
|
|
|
|
be ``(4/2)+1 =3`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfft(x, n=4, axis=0)
|
|
|
|
Array([[ 3.+0.j, 7.+0.j, 11.+0.j],
|
|
|
|
[ 1.-2.j, 3.-4.j, 5.-6.j],
|
|
|
|
[-1.+0.j, -1.+0.j, -1.+0.j]], dtype=complex64)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_1d('rfft', lax.FftType.RFFT, a, n=n, axis=axis,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
2024-07-23 11:33:03 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def irfft(a: ArrayLike, n: int | None = None,
|
|
|
|
axis: int = -1, norm: str | None = None) -> Array:
|
2024-08-07 17:59:53 +05:30
|
|
|
"""Compute a real-valued one-dimensional inverse discrete Fourier transform.
|
2024-07-23 11:33:03 +05:30
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.irfft`.
|
|
|
|
|
|
|
|
Args:
|
2024-08-07 17:59:53 +05:30
|
|
|
a: input array.
|
2024-07-23 11:33:03 +05:30
|
|
|
n: int. Specifies the dimension of the result along ``axis``. If not specified,
|
|
|
|
``n = 2*(m-1)``, where ``m`` is the dimension of ``a`` along ``axis``.
|
|
|
|
axis: int, default=-1. Specifies the axis along which the transform is computed.
|
|
|
|
If not specified, the transform is computed along axis -1.
|
|
|
|
norm: string. The normalization mode. "backward", "ortho" and "forward" are
|
|
|
|
supported.
|
|
|
|
|
|
|
|
Returns:
|
2024-08-07 17:59:53 +05:30
|
|
|
A real-valued array containing the one-dimensional inverse discrete Fourier
|
|
|
|
transform of ``a``, with a dimension of ``n`` along ``axis``.
|
2024-07-23 11:33:03 +05:30
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete
|
|
|
|
Fourier transform.
|
|
|
|
- :func:`jax.numpy.fft.irfft`: Computes a one-dimensional inverse discrete
|
|
|
|
Fourier transform for real input.
|
|
|
|
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform for real input.
|
|
|
|
- :func:`jax.numpy.fft.irfftn`: Computes a multidimensional inverse discrete
|
|
|
|
Fourier transform for real input.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.rfft`` computes the transform along ``axis -1`` by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[1, 3, 5],
|
|
|
|
... [2, 4, 6]])
|
|
|
|
>>> jnp.fft.irfft(x)
|
|
|
|
Array([[ 3., -1., 0., -1.],
|
|
|
|
[ 4., -1., 0., -1.]], dtype=float32)
|
|
|
|
|
|
|
|
When ``n=3``, dimension of the transform along axis -1 will be ``3`` and
|
|
|
|
dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.irfft(x, n=3)
|
|
|
|
Array([[ 2.33, -0.67, -0.67],
|
|
|
|
[ 3.33, -0.67, -0.67]], dtype=float32)
|
|
|
|
|
|
|
|
When ``n=4`` and ``axis=0``, dimension of the transform along ``axis 0`` will
|
|
|
|
be ``4`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.irfft(x, n=4, axis=0)
|
|
|
|
Array([[ 1.25, 2.75, 4.25],
|
|
|
|
[ 0.25, 0.75, 1.25],
|
|
|
|
[-0.75, -1.25, -1.75],
|
|
|
|
[ 0.25, 0.75, 1.25]], dtype=float32)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_1d('irfft', lax.FftType.IRFFT, a, n=n, axis=axis,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
2024-08-05 21:53:29 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def hfft(a: ArrayLike, n: int | None = None,
|
|
|
|
axis: int = -1, norm: str | None = None) -> Array:
|
2024-08-05 21:53:29 +05:30
|
|
|
"""Compute a 1-D FFT of an array whose spectrum has Hermitian symmetry.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.hfft`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array.
|
|
|
|
n: optional, int. Specifies the dimension of the result along ``axis``. If
|
|
|
|
not specified, ``n = 2*(m-1)``, where ``m`` is the dimension of ``a``
|
|
|
|
along ``axis``.
|
|
|
|
axis: optional, int, default=-1. Specifies the axis along which the transform
|
|
|
|
is computed. If not specified, the transform is computed along axis -1.
|
|
|
|
norm: optional, string. The normalization mode. "backward", "ortho" and "forward"
|
|
|
|
are supported. Default is "backward".
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A real-valued array containing the one-dimensional discret Fourier transform
|
|
|
|
of ``a`` by exploiting its inherent Hermitian-symmetry, having a dimension of
|
|
|
|
``n`` along ``axis``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.ihfft`: Computes a one-dimensional inverse FFT of an
|
|
|
|
array whose spectrum has Hermitian symmetry.
|
|
|
|
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform of a real-valued input.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
>>> x = jnp.array([[1, 3, 5, 7],
|
|
|
|
... [2, 4, 6, 8]])
|
|
|
|
>>> jnp.fft.hfft(x)
|
|
|
|
Array([[24., -8., 0., -2., 0., -8.],
|
|
|
|
[30., -8., 0., -2., 0., -8.]], dtype=float32)
|
|
|
|
|
|
|
|
This value is equal to the real component of the discrete Fourier transform
|
|
|
|
of the following array ``x1`` computed using ``jnp.fft.fft``.
|
|
|
|
|
|
|
|
>>> x1 = jnp.array([[1, 3, 5, 7, 5, 3],
|
|
|
|
... [2, 4, 6, 8, 6, 4]])
|
|
|
|
>>> jnp.fft.fft(x1)
|
|
|
|
Array([[24.+0.j, -8.+0.j, 0.+0.j, -2.+0.j, 0.+0.j, -8.+0.j],
|
|
|
|
[30.+0.j, -8.+0.j, 0.+0.j, -2.+0.j, 0.+0.j, -8.+0.j]], dtype=complex64)
|
|
|
|
>>> jnp.allclose(jnp.fft.hfft(x), jnp.fft.fft(x1))
|
|
|
|
Array(True, dtype=bool)
|
|
|
|
|
|
|
|
To obtain an odd-length output from ``jnp.fft.hfft``, ``n`` must be specified
|
|
|
|
with an odd value, as the default behavior produces an even-length result
|
|
|
|
along the specified ``axis``.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... print(jnp.fft.hfft(x, n=5))
|
|
|
|
[[17. -5.24 -0.76 -0.76 -5.24]
|
|
|
|
[22. -5.24 -0.76 -0.76 -5.24]]
|
|
|
|
|
|
|
|
When ``n=3`` and ``axis=0``, dimension of the transform along ``axis 0`` will
|
|
|
|
be ``3`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> jnp.fft.hfft(x, n=3, axis=0)
|
|
|
|
Array([[ 5., 11., 17., 23.],
|
|
|
|
[-1., -1., -1., -1.],
|
|
|
|
[-1., -1., -1., -1.]], dtype=float32)
|
|
|
|
|
|
|
|
``x`` can be reconstructed (but of complex datatype) using ``jnp.fft.ihfft``
|
|
|
|
from the result of ``jnp.fft.hfft``, only when ``n`` is specified as ``2*(m-1)``
|
|
|
|
if `m` is even or ``2*m-1`` if ``m`` is odd, where ``m`` is the dimension of
|
|
|
|
input along ``axis``.
|
|
|
|
|
|
|
|
>>> jnp.fft.ihfft(jnp.fft.hfft(x, 2*(x.shape[-1]-1)))
|
|
|
|
Array([[1.+0.j, 3.+0.j, 5.+0.j, 7.+0.j],
|
|
|
|
[2.+0.j, 4.+0.j, 6.+0.j, 8.+0.j]], dtype=complex64)
|
|
|
|
>>> jnp.allclose(x, jnp.fft.ihfft(jnp.fft.hfft(x, 2*(x.shape[-1]-1))))
|
|
|
|
Array(True, dtype=bool)
|
|
|
|
|
|
|
|
For complex-valued inputs:
|
|
|
|
|
|
|
|
>>> x2 = jnp.array([[1+2j, 3-4j, 5+6j],
|
|
|
|
... [2-3j, 4+5j, 6-7j]])
|
|
|
|
>>> jnp.fft.hfft(x2)
|
|
|
|
Array([[ 12., -12., 0., 4.],
|
|
|
|
[ 16., 6., 0., -14.]], dtype=float32)
|
|
|
|
"""
|
2023-03-08 10:29:04 -08:00
|
|
|
conj_a = ufuncs.conj(a)
|
2020-10-16 18:08:20 -04:00
|
|
|
_axis_check_1d('hfft', axis)
|
2022-10-04 15:52:52 -07:00
|
|
|
nn = (conj_a.shape[axis] - 1) * 2 if n is None else n
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_1d('hfft', lax.FftType.IRFFT, conj_a, n=n, axis=axis,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm) * nn
|
|
|
|
|
2024-08-05 21:53:29 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def ihfft(a: ArrayLike, n: int | None = None,
|
|
|
|
axis: int = -1, norm: str | None = None) -> Array:
|
2024-08-05 21:53:29 +05:30
|
|
|
r"""Compute a 1-D inverse FFT of an array whose spectrum has Hermitian-symmetry.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.ihfft`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array.
|
|
|
|
n: optional, int. Specifies the effective dimension of the input along ``axis``.
|
|
|
|
If not specified, it will default to the dimension of input along ``axis``.
|
|
|
|
axis: optional, int, default=-1. Specifies the axis along which the transform
|
|
|
|
is computed. If not specified, the transform is computed along axis -1.
|
|
|
|
norm: optional, string. The normalization mode. "backward", "ortho" and "forward"
|
|
|
|
are supported. Default is "backward".
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing one-dimensional discrete Fourier transform of ``a`` by
|
|
|
|
exploiting its inherent Hermitian symmetry. The dimension of the array along
|
|
|
|
``axis`` is ``(n/2)+1``, if ``n`` is even and ``(n+1)/2``, if ``n`` is odd.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.hfft`: Computes a one-dimensional FFT of an array
|
|
|
|
whose spectrum has Hermitian symmetry.
|
|
|
|
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform of a real-valued input.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
>>> x = jnp.array([[1, 3, 5, 7],
|
|
|
|
... [2, 4, 6, 8]])
|
|
|
|
>>> jnp.fft.ihfft(x)
|
|
|
|
Array([[ 4.+0.j, -1.-1.j, -1.-0.j],
|
|
|
|
[ 5.+0.j, -1.-1.j, -1.-0.j]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``n=4`` and ``axis=0``, dimension of the transform along ``axis 0`` will
|
|
|
|
be ``(4/2)+1 =3`` and dimension along other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> jnp.fft.ihfft(x, n=4, axis=0)
|
|
|
|
Array([[ 0.75+0.j , 1.75+0.j , 2.75+0.j , 3.75+0.j ],
|
|
|
|
[ 0.25+0.5j, 0.75+1.j , 1.25+1.5j, 1.75+2.j ],
|
|
|
|
[-0.25-0.j , -0.25-0.j , -0.25-0.j , -0.25-0.j ]], dtype=complex64)
|
|
|
|
"""
|
2020-10-16 18:08:20 -04:00
|
|
|
_axis_check_1d('ihfft', axis)
|
2022-10-04 15:52:52 -07:00
|
|
|
arr = jnp.asarray(a)
|
|
|
|
nn = arr.shape[axis] if n is None else n
|
2024-10-10 08:06:50 -07:00
|
|
|
output = _fft_core_1d('ihfft', lax.FftType.RFFT, arr, n=n, axis=axis,
|
2022-10-04 15:52:52 -07:00
|
|
|
norm=norm)
|
2023-03-08 10:29:04 -08:00
|
|
|
return ufuncs.conj(output) * (1 / nn)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
|
2024-10-10 08:06:50 -07:00
|
|
|
def _fft_core_2d(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
2023-12-11 13:59:29 +00:00
|
|
|
s: Shape | None, axes: Sequence[int],
|
|
|
|
norm: str | None) -> Array:
|
2023-12-28 19:38:13 +03:00
|
|
|
full_name = f"jax.numpy.fft.{func_name}"
|
2020-10-16 18:08:20 -04:00
|
|
|
if len(axes) != 2:
|
|
|
|
raise ValueError(
|
|
|
|
"%s only supports 2 axes. Got axes = %r."
|
|
|
|
% (full_name, axes)
|
|
|
|
)
|
|
|
|
return _fft_core(func_name, fft_type, a, s, axes, norm)
|
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
|
|
|
norm: str | None = None) -> Array:
|
2024-07-23 23:51:34 +05:30
|
|
|
"""Compute a two-dimensional discrete Fourier transform along given axes.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.fft2`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array. Must have ``a.ndim >= 2``.
|
|
|
|
s: optional length-2 sequence of integers. Specifies the size of the output
|
|
|
|
along each specified axis. If not specified, it will default to the size
|
|
|
|
of ``a`` along the specified ``axes``.
|
|
|
|
axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the
|
|
|
|
axes along which the transform is computed.
|
|
|
|
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
|
|
|
and "forward" are supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing the two-dimensional discrete Fourier transform of ``a``
|
|
|
|
along given ``axes``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.fft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.fftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
- :func:`jax.numpy.fft.ifft2`: Computes a two-dimensional inverse discrete
|
|
|
|
Fourier transform.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.fft2`` computes the transform along the last two axes by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[[1, 3],
|
|
|
|
... [2, 4]],
|
|
|
|
... [[5, 7],
|
|
|
|
... [6, 8]]])
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.fft2(x)
|
|
|
|
Array([[[10.+0.j, -4.+0.j],
|
|
|
|
[-2.+0.j, 0.+0.j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[26.+0.j, -4.+0.j],
|
|
|
|
[-2.+0.j, 0.+0.j]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[2, 3]``, dimension of the transform along ``axes (-2, -1)`` will be
|
|
|
|
``(2, 3)`` and dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.fft2(x, s=[2, 3])
|
|
|
|
Array([[[10. +0.j , -0.5 -6.06j, -0.5 +6.06j],
|
|
|
|
[-2. +0.j , -0.5 +0.87j, -0.5 -0.87j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[26. +0.j , 3.5-12.99j, 3.5+12.99j],
|
|
|
|
[-2. +0.j , -0.5 +0.87j, -0.5 -0.87j]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[2, 3]`` and ``axes=(0, 1)``, shape of the transform along
|
|
|
|
``axes (0, 1)`` will be ``(2, 3)`` and dimension along other axes will be
|
|
|
|
same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.fft2(x, s=[2, 3], axes=(0, 1))
|
|
|
|
Array([[[14. +0.j , 22. +0.j ],
|
|
|
|
[ 2. -6.93j, 4.-10.39j],
|
|
|
|
[ 2. +6.93j, 4.+10.39j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[-8. +0.j , -8. +0.j ],
|
|
|
|
[-2. +3.46j, -2. +3.46j],
|
|
|
|
[-2. -3.46j, -2. -3.46j]]], dtype=complex64)
|
|
|
|
|
|
|
|
``jnp.fft.ifft2`` can be used to reconstruct ``x`` from the result of
|
|
|
|
``jnp.fft.fft2``.
|
|
|
|
|
|
|
|
>>> x_fft2 = jnp.fft.fft2(x)
|
|
|
|
>>> jnp.allclose(x, jnp.fft.ifft2(x_fft2))
|
|
|
|
Array(True, dtype=bool)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_2d('fft2', lax.FftType.FFT, a, s=s, axes=axes,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
2024-07-23 23:51:34 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
|
|
|
norm: str | None = None) -> Array:
|
2024-07-23 23:51:34 +05:30
|
|
|
"""Compute a two-dimensional inverse discrete Fourier transform.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.ifft2`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array. Must have ``a.ndim >= 2``.
|
|
|
|
s: optional length-2 sequence of integers. Specifies the size of the output
|
|
|
|
in each specified axis. If not specified, it will default to the size of
|
|
|
|
``a`` along the specified ``axes``.
|
|
|
|
axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the
|
|
|
|
axes along which the transform is computed.
|
|
|
|
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
|
|
|
and "forward" are supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing the two-dimensional inverse discrete Fourier transform
|
|
|
|
of ``a`` along given ``axes``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete
|
|
|
|
Fourier transform.
|
|
|
|
- :func:`jax.numpy.fft.ifftn`: Computes a multidimensional inverse discrete
|
|
|
|
Fourier transform.
|
|
|
|
- :func:`jax.numpy.fft.fft2`: Computes a two-dimensional discrete Fourier
|
|
|
|
transform.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.ifft2`` computes the transform along the last two axes by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[[1, 3],
|
|
|
|
... [2, 4]],
|
|
|
|
... [[5, 7],
|
|
|
|
... [6, 8]]])
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.ifft2(x)
|
|
|
|
Array([[[ 2.5+0.j, -1. +0.j],
|
|
|
|
[-0.5+0.j, 0. +0.j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ 6.5+0.j, -1. +0.j],
|
|
|
|
[-0.5+0.j, 0. +0.j]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[2, 3]``, dimension of the transform along ``axes (-2, -1)`` will be
|
|
|
|
``(2, 3)`` and dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.ifft2(x, s=[2, 3])
|
|
|
|
Array([[[ 1.67+0.j , -0.08+1.01j, -0.08-1.01j],
|
|
|
|
[-0.33+0.j , -0.08-0.14j, -0.08+0.14j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ 4.33+0.j , 0.58+2.17j, 0.58-2.17j],
|
|
|
|
[-0.33+0.j , -0.08-0.14j, -0.08+0.14j]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[2, 3]`` and ``axes=(0, 1)``, shape of the transform along
|
|
|
|
``axes (0, 1)`` will be ``(2, 3)`` and dimension along other axes will be
|
|
|
|
same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.ifft2(x, s=[2, 3], axes=(0, 1))
|
|
|
|
Array([[[ 2.33+0.j , 3.67+0.j ],
|
|
|
|
[ 0.33+1.15j, 0.67+1.73j],
|
|
|
|
[ 0.33-1.15j, 0.67-1.73j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[-1.33+0.j , -1.33+0.j ],
|
|
|
|
[-0.33-0.58j, -0.33-0.58j],
|
|
|
|
[-0.33+0.58j, -0.33+0.58j]]], dtype=complex64)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_2d('ifft2', lax.FftType.IFFT, a, s=s, axes=axes,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
2024-08-07 17:59:53 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
|
|
|
norm: str | None = None) -> Array:
|
2024-08-07 17:59:53 +05:30
|
|
|
"""Compute a two-dimensional discrete Fourier transform of a real-valued array.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.rfft2`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: real-valued input array. Must have ``a.ndim >= 2``.
|
|
|
|
s: optional length-2 sequence of integers. Specifies the effective size of the
|
|
|
|
output along each specified axis. If not specified, it will default to the
|
|
|
|
dimension of input along ``axes``.
|
|
|
|
axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the
|
|
|
|
axes along which the transform is computed.
|
|
|
|
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
|
|
|
and "forward" are supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
An array containing the two-dimensional discrete Fourier transform of ``a``.
|
|
|
|
The size of the output along the axis ``axes[1]`` is ``(s[1]/2)+1``, if ``s[1]``
|
|
|
|
is even and ``(s[1]+1)/2``, if ``s[1]`` is odd. The size of the output along
|
|
|
|
the axis ``axes[0]`` is ``s[0]``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier
|
|
|
|
transform of real-valued array.
|
|
|
|
- :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier
|
|
|
|
transform of real-valued array.
|
|
|
|
- :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse
|
|
|
|
discrete Fourier transform.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
``jnp.fft.rfft2`` computes the transform along the last two axes by default.
|
|
|
|
|
|
|
|
>>> x = jnp.array([[[1, 3, 5],
|
|
|
|
... [2, 4, 6]],
|
|
|
|
... [[7, 9, 11],
|
|
|
|
... [8, 10, 12]]])
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfft2(x)
|
|
|
|
Array([[[21.+0.j , -6.+3.46j],
|
|
|
|
[-3.+0.j , 0.+0.j ]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[57.+0.j , -6.+3.46j],
|
|
|
|
[-3.+0.j , 0.+0.j ]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[2, 4]``, dimension of the transform along ``axis -2`` will be
|
|
|
|
``2``, along ``axis -1`` will be ``(4/2)+1) = 3`` and dimension along other
|
|
|
|
axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfft2(x, s=[2, 4])
|
|
|
|
Array([[[21. +0.j, -8. -7.j, 7. +0.j],
|
|
|
|
[-3. +0.j, 0. +1.j, -1. +0.j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[57. +0.j, -8.-19.j, 19. +0.j],
|
|
|
|
[-3. +0.j, 0. +1.j, -1. +0.j]]], dtype=complex64)
|
|
|
|
|
|
|
|
When ``s=[3, 5]`` and ``axes=(0, 1)``, shape of the transform along ``axis 0``
|
|
|
|
will be ``3``, along ``axis 1`` will be ``(5+1)/2 = 3`` and dimension along
|
|
|
|
other axes will be same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.rfft2(x, s=[3, 5], axes=(0, 1))
|
|
|
|
Array([[[ 18. +0.j , 26. +0.j , 34. +0.j ],
|
|
|
|
[ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j],
|
|
|
|
[ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j],
|
|
|
|
[ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j],
|
|
|
|
[ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j],
|
|
|
|
[ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j],
|
|
|
|
[ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_2d('rfft2', lax.FftType.RFFT, a, s=s, axes=axes,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
2024-08-07 17:59:53 +05:30
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
|
|
|
|
norm: str | None = None) -> Array:
|
2024-08-07 17:59:53 +05:30
|
|
|
"""Compute a real-valued two-dimensional inverse discrete Fourier transform.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.irfft2`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
a: input array. Must have ``a.ndim >= 2``.
|
|
|
|
s: optional length-2 sequence of integers. Specifies the size of the output
|
|
|
|
in each specified axis. If not specified, the dimension of output along
|
|
|
|
axis ``axes[1]`` is ``2*(m-1)``, ``m`` is the size of input along axis
|
|
|
|
``axes[1]`` and the dimension along other axes will be the same as that of
|
|
|
|
input.
|
|
|
|
axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the
|
|
|
|
axes along which the transform is computed.
|
|
|
|
norm: string, default="backward". The normalization mode. "backward", "ortho"
|
|
|
|
and "forward" are supported.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A real-valued array containing the two-dimensional inverse discrete Fourier
|
|
|
|
transform of ``a``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier
|
|
|
|
transform of a real-valued array.
|
|
|
|
- :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse
|
|
|
|
discrete Fourier transform.
|
|
|
|
- :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse
|
|
|
|
discrete Fourier transform.
|
|
|
|
|
|
|
|
Examples:
|
2024-08-09 23:07:17 +05:30
|
|
|
``jnp.fft.irfft2`` computes the transform along the last two axes by default.
|
2024-08-07 17:59:53 +05:30
|
|
|
|
|
|
|
>>> x = jnp.array([[[1, 3, 5],
|
|
|
|
... [2, 4, 6]],
|
|
|
|
... [[7, 9, 11],
|
|
|
|
... [8, 10, 12]]])
|
|
|
|
>>> jnp.fft.irfft2(x)
|
|
|
|
Array([[[ 3.5, -1. , 0. , -1. ],
|
|
|
|
[-0.5, 0. , 0. , 0. ]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ 9.5, -1. , 0. , -1. ],
|
|
|
|
[-0.5, 0. , 0. , 0. ]]], dtype=float32)
|
|
|
|
|
|
|
|
When ``s=[3, 3]``, dimension of the transform along ``axes (-2, -1)`` will be
|
|
|
|
``(3, 3)`` and dimension along other axes will be the same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.irfft2(x, s=[3, 3])
|
|
|
|
Array([[[ 1.89, -0.44, -0.44],
|
|
|
|
[ 0.22, -0.78, 0.56],
|
|
|
|
[ 0.22, 0.56, -0.78]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[ 5.89, -0.44, -0.44],
|
|
|
|
[ 1.22, -1.78, 1.56],
|
|
|
|
[ 1.22, 1.56, -1.78]]], dtype=float32)
|
|
|
|
|
|
|
|
When ``s=[2, 3]`` and ``axes=(0, 1)``, shape of the transform along
|
|
|
|
``axes (0, 1)`` will be ``(2, 3)`` and dimension along other axes will be
|
|
|
|
same as that of input.
|
|
|
|
|
|
|
|
>>> with jnp.printoptions(precision=2, suppress=True):
|
|
|
|
... jnp.fft.irfft2(x, s=[2, 3], axes=(0, 1))
|
|
|
|
Array([[[ 4.67, 6.67, 8.67],
|
|
|
|
[-0.33, -0.33, -0.33],
|
|
|
|
[-0.33, -0.33, -0.33]],
|
|
|
|
<BLANKLINE>
|
|
|
|
[[-3. , -3. , -3. ],
|
|
|
|
[ 0. , 0. , 0. ],
|
|
|
|
[ 0. , 0. , 0. ]]], dtype=float32)
|
|
|
|
"""
|
2024-10-10 08:06:50 -07:00
|
|
|
return _fft_core_2d('irfft2', lax.FftType.IRFFT, a, s=s, axes=axes,
|
2020-10-16 18:08:20 -04:00
|
|
|
norm=norm)
|
|
|
|
|
|
|
|
|
2024-07-29 12:40:28 -07:00
|
|
|
def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
|
|
|
|
device: xla_client.Device | Sharding | None = None) -> Array:
|
|
|
|
"""Return sample frequencies for the discrete Fourier transform.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
|
2024-08-09 14:23:56 +05:30
|
|
|
for use with the outputs of :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`.
|
2024-07-29 12:40:28 -07:00
|
|
|
|
|
|
|
Args:
|
|
|
|
n: length of the FFT window
|
|
|
|
d: optional scalar sample spacing (default: 1.0)
|
|
|
|
dtype: optional dtype of returned frequencies. If not specified, JAX's default
|
|
|
|
floating point dtype will be used.
|
|
|
|
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
|
|
to which the created array will be committed.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Array of sample frequencies, length ``n``.
|
|
|
|
|
|
|
|
See also:
|
2024-08-09 14:23:56 +05:30
|
|
|
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with
|
|
|
|
:func:`~jax.numpy.fft.rfft` and :func:`~jax.numpy.fft.irfft`.
|
2024-07-29 12:40:28 -07:00
|
|
|
"""
|
2025-02-06 14:48:10 -08:00
|
|
|
dtype = dtype or dtypes.canonicalize_dtype(dtypes.float_)
|
2020-10-16 18:08:20 -04:00
|
|
|
if isinstance(n, (list, tuple)):
|
|
|
|
raise ValueError(
|
|
|
|
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
|
|
|
|
"Got n = %s." % list(n))
|
|
|
|
|
|
|
|
elif isinstance(d, (list, tuple)):
|
|
|
|
raise ValueError(
|
|
|
|
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
|
|
|
|
"Got d = %s." % list(d))
|
|
|
|
|
2024-06-25 14:05:21 +02:00
|
|
|
k = jnp.zeros(n, dtype=dtype, device=device)
|
2020-10-16 18:08:20 -04:00
|
|
|
if n % 2 == 0:
|
|
|
|
# k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
|
2024-06-25 14:05:21 +02:00
|
|
|
k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype))
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
# k[n // 2:] = jnp.arange(-n // 2, -1)
|
2024-06-25 14:05:21 +02:00
|
|
|
k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype))
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
else:
|
|
|
|
# k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
|
2022-06-15 10:10:44 -07:00
|
|
|
k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype))
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
|
2022-06-15 10:10:44 -07:00
|
|
|
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype))
|
2020-10-16 18:08:20 -04:00
|
|
|
|
2024-06-25 14:05:21 +02:00
|
|
|
return k / jnp.array(d * n, dtype=dtype, device=device)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
|
2024-07-29 12:40:28 -07:00
|
|
|
def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
|
|
|
|
device: xla_client.Device | Sharding | None = None) -> Array:
|
|
|
|
"""Return sample frequencies for the discrete Fourier transform.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
|
2024-08-09 14:23:56 +05:30
|
|
|
for use with the outputs of :func:`~jax.numpy.fft.rfft` and
|
|
|
|
:func:`~jax.numpy.fft.irfft`.
|
2024-07-29 12:40:28 -07:00
|
|
|
|
|
|
|
Args:
|
|
|
|
n: length of the FFT window
|
|
|
|
d: optional scalar sample spacing (default: 1.0)
|
|
|
|
dtype: optional dtype of returned frequencies. If not specified, JAX's default
|
|
|
|
floating point dtype will be used.
|
|
|
|
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
|
|
|
|
to which the created array will be committed.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Array of sample frequencies, length ``n // 2 + 1``.
|
|
|
|
|
|
|
|
See also:
|
2024-10-04 06:12:02 -07:00
|
|
|
- :func:`jax.numpy.fft.fftfreq`: frequencies for use with
|
2024-08-09 14:23:56 +05:30
|
|
|
:func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`.
|
2024-07-29 12:40:28 -07:00
|
|
|
"""
|
2025-02-06 14:48:10 -08:00
|
|
|
dtype = dtype or dtypes.canonicalize_dtype(dtypes.float_)
|
2020-10-16 18:08:20 -04:00
|
|
|
if isinstance(n, (list, tuple)):
|
|
|
|
raise ValueError(
|
|
|
|
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
|
|
|
|
"Got n = %s." % list(n))
|
|
|
|
|
|
|
|
elif isinstance(d, (list, tuple)):
|
|
|
|
raise ValueError(
|
|
|
|
"The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
|
|
|
|
"Got d = %s." % list(d))
|
|
|
|
|
|
|
|
if n % 2 == 0:
|
2022-06-15 10:10:44 -07:00
|
|
|
k = jnp.arange(0, n // 2 + 1, dtype=dtype)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
else:
|
2022-06-15 10:10:44 -07:00
|
|
|
k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)
|
2020-10-16 18:08:20 -04:00
|
|
|
|
2024-07-29 12:40:28 -07:00
|
|
|
result = k / jnp.array(d * n, dtype=dtype)
|
|
|
|
|
|
|
|
if device is not None:
|
|
|
|
return result.to_device(device)
|
|
|
|
return result
|
2020-10-16 18:08:20 -04:00
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
|
2024-10-04 06:12:02 -07:00
|
|
|
"""Shift zero-frequency fft component to the center of the spectrum.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.fftshift`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
x: N-dimensional array array of frequencies.
|
|
|
|
axes: optional integer or sequence of integers specifying which axes to
|
|
|
|
shift. If None (default), then shift all axes.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A shifted copy of ``x``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.ifftshift`: inverse of ``fftshift``.
|
|
|
|
- :func:`jax.numpy.fft.fftfreq`: generate FFT frequencies.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
Generate FFT frequencies with :func:`~jax.numpy.fft.fftfreq`:
|
|
|
|
|
|
|
|
>>> freq = jnp.fft.fftfreq(5)
|
|
|
|
>>> freq
|
|
|
|
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
|
|
|
|
|
|
|
|
Use ``fftshift`` to shift the zero-frequency entry to the middle of the array:
|
|
|
|
|
|
|
|
>>> shifted_freq = jnp.fft.fftshift(freq)
|
|
|
|
>>> shifted_freq
|
|
|
|
Array([-0.4, -0.2, 0. , 0.2, 0.4], dtype=float32)
|
|
|
|
|
|
|
|
Unshift with :func:`~jax.numpy.fft.ifftshift` to recover the original frequencies:
|
|
|
|
|
|
|
|
>>> jnp.fft.ifftshift(shifted_freq)
|
|
|
|
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
|
|
|
|
"""
|
2025-01-23 08:48:13 -08:00
|
|
|
x = ensure_arraylike("fftshift", x)
|
2023-12-11 13:59:29 +00:00
|
|
|
shift: int | Sequence[int]
|
2020-10-16 18:08:20 -04:00
|
|
|
if axes is None:
|
|
|
|
axes = tuple(range(x.ndim))
|
|
|
|
shift = [dim // 2 for dim in x.shape]
|
|
|
|
elif isinstance(axes, int):
|
|
|
|
shift = x.shape[axes] // 2
|
|
|
|
else:
|
|
|
|
shift = [x.shape[ax] // 2 for ax in axes]
|
|
|
|
|
|
|
|
return jnp.roll(x, shift, axes)
|
|
|
|
|
|
|
|
|
2023-12-11 13:59:29 +00:00
|
|
|
def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
|
2024-10-04 06:12:02 -07:00
|
|
|
"""The inverse of :func:`jax.numpy.fft.fftshift`.
|
|
|
|
|
|
|
|
JAX implementation of :func:`numpy.fft.ifftshift`.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
x: N-dimensional array array of frequencies.
|
|
|
|
axes: optional integer or sequence of integers specifying which axes to
|
|
|
|
shift. If None (default), then shift all axes.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A shifted copy of ``x``.
|
|
|
|
|
|
|
|
See also:
|
|
|
|
- :func:`jax.numpy.fft.fftshift`: inverse of ``ifftshift``.
|
|
|
|
- :func:`jax.numpy.fft.fftfreq`: generate FFT frequencies.
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
Generate FFT frequencies with :func:`~jax.numpy.fft.fftfreq`:
|
|
|
|
|
|
|
|
>>> freq = jnp.fft.fftfreq(5)
|
|
|
|
>>> freq
|
|
|
|
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
|
|
|
|
|
|
|
|
Use :func:`~jax.numpy.fft.fftshift` to shift the zero-frequency entry
|
|
|
|
to the middle of the array:
|
|
|
|
|
|
|
|
>>> shifted_freq = jnp.fft.fftshift(freq)
|
|
|
|
>>> shifted_freq
|
|
|
|
Array([-0.4, -0.2, 0. , 0.2, 0.4], dtype=float32)
|
|
|
|
|
|
|
|
Unshift with ``ifftshift`` to recover the original frequencies:
|
|
|
|
|
|
|
|
>>> jnp.fft.ifftshift(shifted_freq)
|
|
|
|
Array([ 0. , 0.2, 0.4, -0.4, -0.2], dtype=float32)
|
|
|
|
"""
|
2025-01-23 08:48:13 -08:00
|
|
|
x = ensure_arraylike("ifftshift", x)
|
2023-12-11 13:59:29 +00:00
|
|
|
shift: int | Sequence[int]
|
2020-10-16 18:08:20 -04:00
|
|
|
if axes is None:
|
|
|
|
axes = tuple(range(x.ndim))
|
|
|
|
shift = [-(dim // 2) for dim in x.shape]
|
|
|
|
elif isinstance(axes, int):
|
|
|
|
shift = -(x.shape[axes] // 2)
|
|
|
|
else:
|
|
|
|
shift = [-(x.shape[ax] // 2) for ax in axes]
|
|
|
|
|
|
|
|
return jnp.roll(x, shift, axes)
|