Add support for N-D FFTs with D>3.

This commit is contained in:
Dan Foreman-Mackey 2024-12-19 12:02:36 +00:00
parent 7680532512
commit c6131ee527
3 changed files with 35 additions and 18 deletions

View File

@ -30,6 +30,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
## jax 0.4.38 (Dec 17, 2024)
* Changes:

View File

@ -72,12 +72,6 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
raise ValueError(
f"{full_name} does not support repeated axes. Got axes {axes}.")
if len(axes) > 3:
# XLA does not support FFTs over more than 3 dimensions
raise ValueError(
"%s only supports 1D, 2D, and 3D FFTs. "
"Got axes %s with input rank %s." % (full_name, orig_axes, arr.ndim))
# XLA only supports FFTs over the innermost axes, so rearrange if necessary.
if orig_axes is not None:
axes = tuple(range(arr.ndim - len(axes), arr.ndim))
@ -100,7 +94,7 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
s += [max(0, 2 * (arr.shape[axes[-1]] - 1))]
else:
s = [arr.shape[axis] for axis in axes]
transformed = lax.fft(arr, fft_type, tuple(s))
transformed = _fft_core_nd(arr, fft_type, s)
if norm is not None:
transformed *= _fft_norm(
jnp.array(s, dtype=transformed.dtype), func_name, norm)
@ -110,6 +104,31 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
return transformed
def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array:
# XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly.
if len(s) <= 3:
return lax.fft(arr, fft_type, tuple(s))
# For larger N, we repeatedly apply N<=3 transforms until we reach the
# requested dimension. We special case N=4 to use two 2-D transforms instead
# of one 3-D and one 1-D, since we typically expect better accelerator
# performance when N>1.
n = 2 if len(s) == 4 else 3
src = tuple(range(arr.ndim - len(s), arr.ndim - n))
dst = tuple(range(arr.ndim - len(s) + n, arr.ndim))
if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}:
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
arr = jnp.moveaxis(arr, src, dst)
arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n])
arr = jnp.moveaxis(arr, dst, src)
else:
arr = jnp.moveaxis(arr, src, dst)
arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n])
arr = jnp.moveaxis(arr, dst, src)
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
return arr
def fftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:

View File

@ -41,12 +41,9 @@ all_dtypes = real_dtypes + jtu.dtypes.complex
def _get_fftn_test_axes(shape):
axes = [[]]
axes = [[], None]
ndims = len(shape)
# XLA's FFT op only supports up to 3 innermost dimensions.
if ndims <= 3:
axes.append(None)
for naxes in range(1, min(ndims, 3) + 1):
for naxes in range(1, ndims + 1):
axes.extend(itertools.combinations(range(ndims), naxes))
for index in range(1, ndims + 1):
axes.append((-index,))
@ -145,7 +142,7 @@ class FftTest(jtu.JaxTestCase):
for dtype in (real_dtypes if real and not inverse else all_dtypes)
],
[dict(shape=shape, axes=axes, s=s)
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5), (2, 3, 4, 5, 6)]
for axes in _get_fftn_test_axes(shape)
for s in _get_fftn_test_s(shape, axes)
],
@ -203,11 +200,6 @@ class FftTest(jtu.JaxTestCase):
if inverse:
name = 'i' + name
func = _get_fftn_func(jnp.fft, inverse, real)
self.assertRaisesRegex(
ValueError,
"jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
"Got axes None with input rank 4.".format(name),
lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
self.assertRaisesRegex(
ValueError,
f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].",