From 4beee13ba03f617dd282f88ce005980a41217360 Mon Sep 17 00:00:00 2001 From: Jean-Eric Campagne Date: Fri, 7 Apr 2023 17:19:08 +0200 Subject: [PATCH] Add implementation of jax.scipy.fftconvolve --- docs/jax.scipy.rst | 1 + jax/_src/scipy/signal.py | 91 +++++++++++++++++++++++++++++++++++++- jax/scipy/signal.py | 1 + tests/scipy_signal_test.py | 23 ++++++++++ 4 files changed, 115 insertions(+), 1 deletion(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index cad15d09f..a9a616276 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -77,6 +77,7 @@ jax.scipy.signal .. autosummary:: :toctree: _autosummary + fftconvolve convolve convolve2d correlate diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 0e5f7aa31..0e4f3e6e7 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -14,11 +14,12 @@ from functools import partial import operator -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, Sequence import warnings import numpy as np import scipy.signal as osp_signal +from scipy.fft import next_fast_len as osp_fft_next_fast_len import jax import jax.numpy.fft @@ -34,6 +35,94 @@ from jax._src.typing import Array, ArrayLike from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert + +@_wraps(osp_signal.fftconvolve) +def fftconvolve(in1: ArrayLike ,in2: ArrayLike, mode: str = "full", + axes: Optional[Sequence[int]] = None) -> Array: + # sanity checks + check_arraylike('fftconvolve', in1, in2) + in1, in2 = promote_dtypes_inexact(in1, in2) + if in1.ndim == in2.ndim == 0: # scalar inputs + return in1 * in2 + elif in1.ndim != in2.ndim: + raise ValueError("in1 and in2 should have the same dimensionality") + elif in1.size == 0 or in2.size == 0: # empty arrays + return jnp.array([], dtype=in1.dtype) + # warn current limitations + if mode not in ["same", "full", "valid"]: + raise ValueError("mode must be one of ['same', 'full', 'valid']") + + if axes is None: + axes = range(in1.ndim) + else: + try: + axes = tuple(axes) + except TypeError: + raise ValueError("axes must be a tuple of ints or (single-int,)") + + axes = [a + in1.ndim if a < 0 else a for a in axes] + if any(a >= in1.ndim or a < 0 for a in axes): + raise ValueError("axes exceeds dimensionality of input") + if len(set(axes)) != len(axes): + raise ValueError("all axes must be unique") + + # necessary for mode=valid + axes = [a for a in axes if in1.shape[a] != 1 and in2.shape[a] != 1] + + # see if one should swap inputs + if mode == "valid": + ok1 = all(in1.shape[i] >= in2.shape[i] for i in axes) + ok2 = all(in2.shape[i] >= in1.shape[i] for i in axes) + if not (ok1 or ok2): + raise ValueError("For 'valid' mode, one must be at least " + "as large as the other in every dimension") + if not ok1: + in1, in2 = in2, in1 + + s1 = in1.shape + s2 = in2.shape + + shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 + for i in range(in1.ndim)] + + def _centered(arr, newshape): + # Return the center newshape portion of the array. + newshape = np.asarray(newshape) + currshape = np.array(arr.shape) + startind = (currshape - newshape) // 2 + endind = startind + newshape + myslice = [slice(startind[k], endind[k]) for k in range(len(endind))] + return arr[tuple(myslice)] + + def _finalize(res, s1, mode, axes): + if mode == "full": + return res + elif mode == "same": + return _centered(conv, s1) + else: + shape_valid = [res.shape[a] if a not in axes else s1[a] - s2[a] + 1 + for a in range(res.ndim)] + return _centered(res, shape_valid) + + if len(axes) == 0: + conv = in1 * in2 + return _finalize(conv, s1, mode, axes) + + # compute the optimized FFT size (use of original scipy code as fshape is static) + fshape = [osp_fft_next_fast_len(shape[a]) for a in axes] + + if in1.dtype.kind == 'c': + fft, ifft = jnp.fft.fftn, jnp.fft.ifftn + else: + fft, ifft = jnp.fft.rfftn, jnp.fft.irfftn + + sp1 = fft(in1, fshape, axes=axes) + sp2 = fft(in2, fshape, axes=axes) + conv = ifft(sp1 * sp2, fshape, axes=axes) + conv = conv[tuple(map(slice, shape))] + return _finalize(conv, s1, mode, axes) + + # Note: we do not re-use the code from jax.numpy.convolve here, because the handling # of padding differs slightly between the two implementations (particularly for # mode='same'). diff --git a/jax/scipy/signal.py b/jax/scipy/signal.py index a0fe5987c..7e39da3f9 100644 --- a/jax/scipy/signal.py +++ b/jax/scipy/signal.py @@ -16,6 +16,7 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.scipy.signal import ( + fftconvolve as fftconvolve, convolve as convolve, convolve2d as convolve2d, correlate as correlate, diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 622e36ac4..48df72c12 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -91,6 +91,29 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) + # Test fftconvolve: present JAX code implement only mode="same" or "full" + @jtu.sample_product( + [dict(xshape=xshape, yshape=yshape) + for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes] + for xshape in shapeset + for yshape in shapeset + ], + mode=['full', 'same', 'valid'], + dtype=default_dtypes, + ) + def testFFTConvolution(self, xshape, yshape, dtype, mode): + jsp_op = jsp_signal.fftconvolve + osp_op = osp_signal.fftconvolve + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] + osp_fun = partial(osp_op, mode=mode) + jsp_fun = partial(jsp_op, mode=mode) + tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12, np.complex64: 1e-2, np.complex128: 1e-12} + self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol) + + @jtu.sample_product( mode=['full', 'same', 'valid'], op=['convolve2d', 'correlate2d'],