mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add implementation of jax.scipy.fftconvolve
This commit is contained in:
parent
b15ebb1bc5
commit
4beee13ba0
@ -77,6 +77,7 @@ jax.scipy.signal
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
fftconvolve
|
||||
convolve
|
||||
convolve2d
|
||||
correlate
|
||||
|
@ -14,11 +14,12 @@
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple, Union, Sequence
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal as osp_signal
|
||||
from scipy.fft import next_fast_len as osp_fft_next_fast_len
|
||||
|
||||
import jax
|
||||
import jax.numpy.fft
|
||||
@ -34,6 +35,94 @@ from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
|
||||
|
||||
|
||||
|
||||
@_wraps(osp_signal.fftconvolve)
|
||||
def fftconvolve(in1: ArrayLike ,in2: ArrayLike, mode: str = "full",
|
||||
axes: Optional[Sequence[int]] = None) -> Array:
|
||||
# sanity checks
|
||||
check_arraylike('fftconvolve', in1, in2)
|
||||
in1, in2 = promote_dtypes_inexact(in1, in2)
|
||||
if in1.ndim == in2.ndim == 0: # scalar inputs
|
||||
return in1 * in2
|
||||
elif in1.ndim != in2.ndim:
|
||||
raise ValueError("in1 and in2 should have the same dimensionality")
|
||||
elif in1.size == 0 or in2.size == 0: # empty arrays
|
||||
return jnp.array([], dtype=in1.dtype)
|
||||
# warn current limitations
|
||||
if mode not in ["same", "full", "valid"]:
|
||||
raise ValueError("mode must be one of ['same', 'full', 'valid']")
|
||||
|
||||
if axes is None:
|
||||
axes = range(in1.ndim)
|
||||
else:
|
||||
try:
|
||||
axes = tuple(axes)
|
||||
except TypeError:
|
||||
raise ValueError("axes must be a tuple of ints or (single-int,)")
|
||||
|
||||
axes = [a + in1.ndim if a < 0 else a for a in axes]
|
||||
if any(a >= in1.ndim or a < 0 for a in axes):
|
||||
raise ValueError("axes exceeds dimensionality of input")
|
||||
if len(set(axes)) != len(axes):
|
||||
raise ValueError("all axes must be unique")
|
||||
|
||||
# necessary for mode=valid
|
||||
axes = [a for a in axes if in1.shape[a] != 1 and in2.shape[a] != 1]
|
||||
|
||||
# see if one should swap inputs
|
||||
if mode == "valid":
|
||||
ok1 = all(in1.shape[i] >= in2.shape[i] for i in axes)
|
||||
ok2 = all(in2.shape[i] >= in1.shape[i] for i in axes)
|
||||
if not (ok1 or ok2):
|
||||
raise ValueError("For 'valid' mode, one must be at least "
|
||||
"as large as the other in every dimension")
|
||||
if not ok1:
|
||||
in1, in2 = in2, in1
|
||||
|
||||
s1 = in1.shape
|
||||
s2 = in2.shape
|
||||
|
||||
shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1
|
||||
for i in range(in1.ndim)]
|
||||
|
||||
def _centered(arr, newshape):
|
||||
# Return the center newshape portion of the array.
|
||||
newshape = np.asarray(newshape)
|
||||
currshape = np.array(arr.shape)
|
||||
startind = (currshape - newshape) // 2
|
||||
endind = startind + newshape
|
||||
myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
|
||||
return arr[tuple(myslice)]
|
||||
|
||||
def _finalize(res, s1, mode, axes):
|
||||
if mode == "full":
|
||||
return res
|
||||
elif mode == "same":
|
||||
return _centered(conv, s1)
|
||||
else:
|
||||
shape_valid = [res.shape[a] if a not in axes else s1[a] - s2[a] + 1
|
||||
for a in range(res.ndim)]
|
||||
return _centered(res, shape_valid)
|
||||
|
||||
if len(axes) == 0:
|
||||
conv = in1 * in2
|
||||
return _finalize(conv, s1, mode, axes)
|
||||
|
||||
# compute the optimized FFT size (use of original scipy code as fshape is static)
|
||||
fshape = [osp_fft_next_fast_len(shape[a]) for a in axes]
|
||||
|
||||
if in1.dtype.kind == 'c':
|
||||
fft, ifft = jnp.fft.fftn, jnp.fft.ifftn
|
||||
else:
|
||||
fft, ifft = jnp.fft.rfftn, jnp.fft.irfftn
|
||||
|
||||
sp1 = fft(in1, fshape, axes=axes)
|
||||
sp2 = fft(in2, fshape, axes=axes)
|
||||
conv = ifft(sp1 * sp2, fshape, axes=axes)
|
||||
conv = conv[tuple(map(slice, shape))]
|
||||
return _finalize(conv, s1, mode, axes)
|
||||
|
||||
|
||||
# Note: we do not re-use the code from jax.numpy.convolve here, because the handling
|
||||
# of padding differs slightly between the two implementations (particularly for
|
||||
# mode='same').
|
||||
|
@ -16,6 +16,7 @@
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.signal import (
|
||||
fftconvolve as fftconvolve,
|
||||
convolve as convolve,
|
||||
convolve2d as convolve2d,
|
||||
correlate as correlate,
|
||||
|
@ -91,6 +91,29 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
# Test fftconvolve: present JAX code implement only mode="same" or "full"
|
||||
@jtu.sample_product(
|
||||
[dict(xshape=xshape, yshape=yshape)
|
||||
for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
|
||||
for xshape in shapeset
|
||||
for yshape in shapeset
|
||||
],
|
||||
mode=['full', 'same', 'valid'],
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testFFTConvolution(self, xshape, yshape, dtype, mode):
|
||||
jsp_op = jsp_signal.fftconvolve
|
||||
osp_op = osp_signal.fftconvolve
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
|
||||
osp_fun = partial(osp_op, mode=mode)
|
||||
jsp_fun = partial(jsp_op, mode=mode)
|
||||
tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12, np.complex64: 1e-2, np.complex128: 1e-12}
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
mode=['full', 'same', 'valid'],
|
||||
op=['convolve2d', 'correlate2d'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user