Add implementation of jax.scipy.fftconvolve

This commit is contained in:
Jean-Eric Campagne 2023-04-07 17:19:08 +02:00
parent b15ebb1bc5
commit 4beee13ba0
4 changed files with 115 additions and 1 deletions

View File

@ -77,6 +77,7 @@ jax.scipy.signal
.. autosummary::
:toctree: _autosummary
fftconvolve
convolve
convolve2d
correlate

View File

@ -14,11 +14,12 @@
from functools import partial
import operator
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union, Sequence
import warnings
import numpy as np
import scipy.signal as osp_signal
from scipy.fft import next_fast_len as osp_fft_next_fast_len
import jax
import jax.numpy.fft
@ -34,6 +35,94 @@ from jax._src.typing import Array, ArrayLike
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
@_wraps(osp_signal.fftconvolve)
def fftconvolve(in1: ArrayLike ,in2: ArrayLike, mode: str = "full",
axes: Optional[Sequence[int]] = None) -> Array:
# sanity checks
check_arraylike('fftconvolve', in1, in2)
in1, in2 = promote_dtypes_inexact(in1, in2)
if in1.ndim == in2.ndim == 0: # scalar inputs
return in1 * in2
elif in1.ndim != in2.ndim:
raise ValueError("in1 and in2 should have the same dimensionality")
elif in1.size == 0 or in2.size == 0: # empty arrays
return jnp.array([], dtype=in1.dtype)
# warn current limitations
if mode not in ["same", "full", "valid"]:
raise ValueError("mode must be one of ['same', 'full', 'valid']")
if axes is None:
axes = range(in1.ndim)
else:
try:
axes = tuple(axes)
except TypeError:
raise ValueError("axes must be a tuple of ints or (single-int,)")
axes = [a + in1.ndim if a < 0 else a for a in axes]
if any(a >= in1.ndim or a < 0 for a in axes):
raise ValueError("axes exceeds dimensionality of input")
if len(set(axes)) != len(axes):
raise ValueError("all axes must be unique")
# necessary for mode=valid
axes = [a for a in axes if in1.shape[a] != 1 and in2.shape[a] != 1]
# see if one should swap inputs
if mode == "valid":
ok1 = all(in1.shape[i] >= in2.shape[i] for i in axes)
ok2 = all(in2.shape[i] >= in1.shape[i] for i in axes)
if not (ok1 or ok2):
raise ValueError("For 'valid' mode, one must be at least "
"as large as the other in every dimension")
if not ok1:
in1, in2 = in2, in1
s1 = in1.shape
s2 = in2.shape
shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1
for i in range(in1.ndim)]
def _centered(arr, newshape):
# Return the center newshape portion of the array.
newshape = np.asarray(newshape)
currshape = np.array(arr.shape)
startind = (currshape - newshape) // 2
endind = startind + newshape
myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
return arr[tuple(myslice)]
def _finalize(res, s1, mode, axes):
if mode == "full":
return res
elif mode == "same":
return _centered(conv, s1)
else:
shape_valid = [res.shape[a] if a not in axes else s1[a] - s2[a] + 1
for a in range(res.ndim)]
return _centered(res, shape_valid)
if len(axes) == 0:
conv = in1 * in2
return _finalize(conv, s1, mode, axes)
# compute the optimized FFT size (use of original scipy code as fshape is static)
fshape = [osp_fft_next_fast_len(shape[a]) for a in axes]
if in1.dtype.kind == 'c':
fft, ifft = jnp.fft.fftn, jnp.fft.ifftn
else:
fft, ifft = jnp.fft.rfftn, jnp.fft.irfftn
sp1 = fft(in1, fshape, axes=axes)
sp2 = fft(in2, fshape, axes=axes)
conv = ifft(sp1 * sp2, fshape, axes=axes)
conv = conv[tuple(map(slice, shape))]
return _finalize(conv, s1, mode, axes)
# Note: we do not re-use the code from jax.numpy.convolve here, because the handling
# of padding differs slightly between the two implementations (particularly for
# mode='same').

View File

@ -16,6 +16,7 @@
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.signal import (
fftconvolve as fftconvolve,
convolve as convolve,
convolve2d as convolve2d,
correlate as correlate,

View File

@ -91,6 +91,29 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
# Test fftconvolve: present JAX code implement only mode="same" or "full"
@jtu.sample_product(
[dict(xshape=xshape, yshape=yshape)
for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
for xshape in shapeset
for yshape in shapeset
],
mode=['full', 'same', 'valid'],
dtype=default_dtypes,
)
def testFFTConvolution(self, xshape, yshape, dtype, mode):
jsp_op = jsp_signal.fftconvolve
osp_op = osp_signal.fftconvolve
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
osp_fun = partial(osp_op, mode=mode)
jsp_fun = partial(jsp_op, mode=mode)
tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12, np.complex64: 1e-2, np.complex128: 1e-12}
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve2d', 'correlate2d'],