mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add support for N-D FFTs with D>3.
This commit is contained in:
parent
7680532512
commit
c6131ee527
@ -30,6 +30,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
are now deprecated, having been replaced by symbols of the same name
|
are now deprecated, having been replaced by symbols of the same name
|
||||||
in {mod}`jax.core`.
|
in {mod}`jax.core`.
|
||||||
|
|
||||||
|
* New Features
|
||||||
|
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
|
||||||
|
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
|
||||||
|
transforms in more than 3 dimensions, which was previously the limit. See
|
||||||
|
{jax-issue}`#25606` for more details.
|
||||||
|
|
||||||
## jax 0.4.38 (Dec 17, 2024)
|
## jax 0.4.38 (Dec 17, 2024)
|
||||||
|
|
||||||
* Changes:
|
* Changes:
|
||||||
|
@ -72,12 +72,6 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{full_name} does not support repeated axes. Got axes {axes}.")
|
f"{full_name} does not support repeated axes. Got axes {axes}.")
|
||||||
|
|
||||||
if len(axes) > 3:
|
|
||||||
# XLA does not support FFTs over more than 3 dimensions
|
|
||||||
raise ValueError(
|
|
||||||
"%s only supports 1D, 2D, and 3D FFTs. "
|
|
||||||
"Got axes %s with input rank %s." % (full_name, orig_axes, arr.ndim))
|
|
||||||
|
|
||||||
# XLA only supports FFTs over the innermost axes, so rearrange if necessary.
|
# XLA only supports FFTs over the innermost axes, so rearrange if necessary.
|
||||||
if orig_axes is not None:
|
if orig_axes is not None:
|
||||||
axes = tuple(range(arr.ndim - len(axes), arr.ndim))
|
axes = tuple(range(arr.ndim - len(axes), arr.ndim))
|
||||||
@ -100,7 +94,7 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
|||||||
s += [max(0, 2 * (arr.shape[axes[-1]] - 1))]
|
s += [max(0, 2 * (arr.shape[axes[-1]] - 1))]
|
||||||
else:
|
else:
|
||||||
s = [arr.shape[axis] for axis in axes]
|
s = [arr.shape[axis] for axis in axes]
|
||||||
transformed = lax.fft(arr, fft_type, tuple(s))
|
transformed = _fft_core_nd(arr, fft_type, s)
|
||||||
if norm is not None:
|
if norm is not None:
|
||||||
transformed *= _fft_norm(
|
transformed *= _fft_norm(
|
||||||
jnp.array(s, dtype=transformed.dtype), func_name, norm)
|
jnp.array(s, dtype=transformed.dtype), func_name, norm)
|
||||||
@ -110,6 +104,31 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike,
|
|||||||
return transformed
|
return transformed
|
||||||
|
|
||||||
|
|
||||||
|
def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array:
|
||||||
|
# XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly.
|
||||||
|
if len(s) <= 3:
|
||||||
|
return lax.fft(arr, fft_type, tuple(s))
|
||||||
|
|
||||||
|
# For larger N, we repeatedly apply N<=3 transforms until we reach the
|
||||||
|
# requested dimension. We special case N=4 to use two 2-D transforms instead
|
||||||
|
# of one 3-D and one 1-D, since we typically expect better accelerator
|
||||||
|
# performance when N>1.
|
||||||
|
n = 2 if len(s) == 4 else 3
|
||||||
|
src = tuple(range(arr.ndim - len(s), arr.ndim - n))
|
||||||
|
dst = tuple(range(arr.ndim - len(s) + n, arr.ndim))
|
||||||
|
if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}:
|
||||||
|
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
|
||||||
|
arr = jnp.moveaxis(arr, src, dst)
|
||||||
|
arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n])
|
||||||
|
arr = jnp.moveaxis(arr, dst, src)
|
||||||
|
else:
|
||||||
|
arr = jnp.moveaxis(arr, src, dst)
|
||||||
|
arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n])
|
||||||
|
arr = jnp.moveaxis(arr, dst, src)
|
||||||
|
arr = lax.fft(arr, fft_type, tuple(s)[-n:])
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def fftn(a: ArrayLike, s: Shape | None = None,
|
def fftn(a: ArrayLike, s: Shape | None = None,
|
||||||
axes: Sequence[int] | None = None,
|
axes: Sequence[int] | None = None,
|
||||||
norm: str | None = None) -> Array:
|
norm: str | None = None) -> Array:
|
||||||
|
@ -41,12 +41,9 @@ all_dtypes = real_dtypes + jtu.dtypes.complex
|
|||||||
|
|
||||||
|
|
||||||
def _get_fftn_test_axes(shape):
|
def _get_fftn_test_axes(shape):
|
||||||
axes = [[]]
|
axes = [[], None]
|
||||||
ndims = len(shape)
|
ndims = len(shape)
|
||||||
# XLA's FFT op only supports up to 3 innermost dimensions.
|
for naxes in range(1, ndims + 1):
|
||||||
if ndims <= 3:
|
|
||||||
axes.append(None)
|
|
||||||
for naxes in range(1, min(ndims, 3) + 1):
|
|
||||||
axes.extend(itertools.combinations(range(ndims), naxes))
|
axes.extend(itertools.combinations(range(ndims), naxes))
|
||||||
for index in range(1, ndims + 1):
|
for index in range(1, ndims + 1):
|
||||||
axes.append((-index,))
|
axes.append((-index,))
|
||||||
@ -145,7 +142,7 @@ class FftTest(jtu.JaxTestCase):
|
|||||||
for dtype in (real_dtypes if real and not inverse else all_dtypes)
|
for dtype in (real_dtypes if real and not inverse else all_dtypes)
|
||||||
],
|
],
|
||||||
[dict(shape=shape, axes=axes, s=s)
|
[dict(shape=shape, axes=axes, s=s)
|
||||||
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
|
for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5), (2, 3, 4, 5, 6)]
|
||||||
for axes in _get_fftn_test_axes(shape)
|
for axes in _get_fftn_test_axes(shape)
|
||||||
for s in _get_fftn_test_s(shape, axes)
|
for s in _get_fftn_test_s(shape, axes)
|
||||||
],
|
],
|
||||||
@ -203,11 +200,6 @@ class FftTest(jtu.JaxTestCase):
|
|||||||
if inverse:
|
if inverse:
|
||||||
name = 'i' + name
|
name = 'i' + name
|
||||||
func = _get_fftn_func(jnp.fft, inverse, real)
|
func = _get_fftn_func(jnp.fft, inverse, real)
|
||||||
self.assertRaisesRegex(
|
|
||||||
ValueError,
|
|
||||||
"jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
|
|
||||||
"Got axes None with input rank 4.".format(name),
|
|
||||||
lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
|
|
||||||
self.assertRaisesRegex(
|
self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].",
|
f"jax.numpy.fft.{name} does not support repeated axes. Got axes \\[1, 1\\].",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user