mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add support for N-D FFTs with D>3.
This commit is contained in:
parent
7680532512
commit
c6131ee527
@ -30,6 +30,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
are now deprecated, having been replaced by symbols of the same name
|
||||
in {mod}`jax.core`.
|
||||
|
||||
* New Features
|
||||
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
|
||||
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
|
||||
transforms in more than 3 dimensions, which was previously the limit. See
|
||||
{jax-issue}`#25606` for more details.
|
||||
|
||||
## jax 0.4.38 (Dec 17, 2024)
|
||||
|
||||
* Changes:
|
||||
|
@ -72,12 +72,6 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
||||
raise ValueError(
|
||||
f"{full_name} does not support repeated axes. Got axes {axes}.")
|
||||
|
||||
if len(axes) > 3:
|
||||
# XLA does not support FFTs over more than 3 dimensions
|
||||
raise ValueError(
|
||||
"%s only supports 1D, 2D, and 3D FFTs. "
|
||||
"Got axes %s with input rank %s." % (full_name, orig_axes, arr.ndim))
|
||||
|
||||
# XLA only supports FFTs over the innermost axes, so rearrange if necessary.
|
||||
if orig_axes is not None:
|
||||
axes = tuple(range(arr.ndim - len(axes), arr.ndim))
|
||||
@ -100,7 +94,7 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
||||
s += [max(0, 2 * (arr.shape[axes[-1]] - 1))]
|
||||
else:
|
||||
s = [arr.shape[axis] for axis in axes]
|
||||
transformed = lax.fft(arr, fft_type, tuple(s))
|
||||
transformed = _fft_core_nd(arr, fft_type, s)
|
||||
if norm is not None:
|
||||
transformed *= _fft_norm(
|
||||
jnp.array(s, dtype=transformed.dtype), func_name, norm)
|
||||
@ -110,6 +104,31 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
||||
return transformed
|
||||
|
||||
|
||||
def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array:
|
||||
# XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly.
|
||||
if len(s) <= 3:
|
||||
return lax.fft(arr, fft_type, tuple(s))
|
||||
|
||||
# For larger N, we repeatedly apply N<=3 transforms until we reach the
|
||||
# requested dimension. We special case N=4 to use two 2-D transforms instead
|
||||
# of one 3-D and one 1-D, since we typically expect better accelerator
|
||||
# performance when N>1.
|
||||
n = 2 if len(s) == 4 else 3
|
||||
src = tuple(range(arr.ndim - len(s), arr.ndim - n))
|
||||
dst = tuple(range(arr.ndim - len(s) + n, arr.ndim))
|
||||
if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}:
|
||||
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
|
||||
arr = jnp.moveaxis(arr, src, dst)
|
||||
arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n])
|
||||
arr = jnp.moveaxis(arr, dst, src)
|
||||
else:
|
||||
arr = jnp.moveaxis(arr, src, dst)
|
||||
arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n])
|
||||
arr = jnp.moveaxis(arr, dst, src)
|
||||
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
|
||||
return arr
|
||||
|
||||
|
||||
def fftn(a: ArrayLike, s: Shape | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
|
@ -41,12 +41,9 @@ all_dtypes = real_dtypes + jtu.dtypes.complex
|
||||
|
||||
|
||||
def _get_fftn_test_axes(shape):
|
||||
axes = [[]]
|
||||
axes = [[], None]
|
||||
ndims = len(shape)
|
||||
# XLA's FFT op only supports up to 3 innermost dimensions.
|
||||
if ndims <= 3:
|
||||
axes.append(None)
|
||||
for naxes in range(1, min(ndims, 3) + 1):
|
||||
for naxes in range(1, ndims + 1):
|
||||
axes.extend(itertools.combinations(range(ndims), naxes))
|
||||
for index in range(1, ndims + 1):
|
||||
axes.append((-index,))
|
||||
@ -145,7 +142,7 @@ class FftTest(jtu.JaxTestCase):
|
||||
for dtype in (real_dtypes if real and not inverse else all_dtypes)
|
||||
],
|
||||
[dict(shape=shape, axes=axes, s=s)
|
||||
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
|
||||
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5), (2, 3, 4, 5, 6)]
|
||||
for axes in _get_fftn_test_axes(shape)
|
||||
for s in _get_fftn_test_s(shape, axes)
|
||||
],
|
||||
@ -203,11 +200,6 @@ class FftTest(jtu.JaxTestCase):
|
||||
if inverse:
|
||||
name = 'i' + name
|
||||
func = _get_fftn_func(jnp.fft, inverse, real)
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
|
||||
"Got axes None with input rank 4.".format(name),
|
||||
lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].",
|
||||
|
Loading…
x
Reference in New Issue
Block a user