From c6131ee527ec4cd7b320efa151fffccaf0ae4023 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 19 Dec 2024 12:02:36 +0000 Subject: [PATCH] Add support for N-D FFTs with D>3. --- CHANGELOG.md | 6 ++++++ jax/_src/numpy/fft.py | 33 ++++++++++++++++++++++++++------- tests/fft_test.py | 14 +++----------- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18be5ec9d..e86dece51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. are now deprecated, having been replaced by symbols of the same name in {mod}`jax.core`. +* New Features + * {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`, + {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support + transforms in more than 3 dimensions, which was previously the limit. See + {jax-issue}`#25606` for more details. + ## jax 0.4.38 (Dec 17, 2024) * Changes: diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 4a9ec23fd..c0707ea1c 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -72,12 +72,6 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, raise ValueError( f"{full_name} does not support repeated axes. Got axes {axes}.") - if len(axes) > 3: - # XLA does not support FFTs over more than 3 dimensions - raise ValueError( - "%s only supports 1D, 2D, and 3D FFTs. " - "Got axes %s with input rank %s." % (full_name, orig_axes, arr.ndim)) - # XLA only supports FFTs over the innermost axes, so rearrange if necessary. if orig_axes is not None: axes = tuple(range(arr.ndim - len(axes), arr.ndim)) @@ -100,7 +94,7 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, s += [max(0, 2 * (arr.shape[axes[-1]] - 1))] else: s = [arr.shape[axis] for axis in axes] - transformed = lax.fft(arr, fft_type, tuple(s)) + transformed = _fft_core_nd(arr, fft_type, s) if norm is not None: transformed *= _fft_norm( jnp.array(s, dtype=transformed.dtype), func_name, norm) @@ -110,6 +104,31 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, return transformed +def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array: + # XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly. + if len(s) <= 3: + return lax.fft(arr, fft_type, tuple(s)) + + # For larger N, we repeatedly apply N<=3 transforms until we reach the + # requested dimension. We special case N=4 to use two 2-D transforms instead + # of one 3-D and one 1-D, since we typically expect better accelerator + # performance when N>1. + n = 2 if len(s) == 4 else 3 + src = tuple(range(arr.ndim - len(s), arr.ndim - n)) + dst = tuple(range(arr.ndim - len(s) + n, arr.ndim)) + if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}: + arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + arr = jnp.moveaxis(arr, src, dst) + arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n]) + arr = jnp.moveaxis(arr, dst, src) + else: + arr = jnp.moveaxis(arr, src, dst) + arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n]) + arr = jnp.moveaxis(arr, dst, src) + arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + return arr + + def fftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: diff --git a/tests/fft_test.py b/tests/fft_test.py index c0fc07ce6..3668f534d 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -41,12 +41,9 @@ all_dtypes = real_dtypes + jtu.dtypes.complex def _get_fftn_test_axes(shape): - axes = [[]] + axes = [[], None] ndims = len(shape) - # XLA's FFT op only supports up to 3 innermost dimensions. - if ndims <= 3: - axes.append(None) - for naxes in range(1, min(ndims, 3) + 1): + for naxes in range(1, ndims + 1): axes.extend(itertools.combinations(range(ndims), naxes)) for index in range(1, ndims + 1): axes.append((-index,)) @@ -145,7 +142,7 @@ class FftTest(jtu.JaxTestCase): for dtype in (real_dtypes if real and not inverse else all_dtypes) ], [dict(shape=shape, axes=axes, s=s) - for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)] + for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5), (2, 3, 4, 5, 6)] for axes in _get_fftn_test_axes(shape) for s in _get_fftn_test_s(shape, axes) ], @@ -203,11 +200,6 @@ class FftTest(jtu.JaxTestCase): if inverse: name = 'i' + name func = _get_fftn_func(jnp.fft, inverse, real) - self.assertRaisesRegex( - ValueError, - "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. " - "Got axes None with input rank 4.".format(name), - lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None)) self.assertRaisesRegex( ValueError, f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].",