mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #9422 from yotarok:signal_stft
PiperOrigin-RevId: 429377655
This commit is contained in:
commit
54a6e4dad3
@ -13,15 +13,21 @@
|
||||
# limitations under the License.
|
||||
|
||||
import scipy.signal as osp_signal
|
||||
import operator
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy.fft
|
||||
from jax import lax
|
||||
from jax._src.numpy.lax_numpy import _check_arraylike
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.third_party.scipy import signal_helper
|
||||
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
|
||||
|
||||
|
||||
# Note: we do not re-use the code from jax.numpy.convolve here, because the handling
|
||||
@ -146,3 +152,343 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
|
||||
coef, *_ = linalg.lstsq(A, data[sl])
|
||||
data = data.at[sl].add(-jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
|
||||
return jnp.moveaxis(data.reshape(shape), 0, axis)
|
||||
|
||||
|
||||
def _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides):
|
||||
"""Calculate windowed FFT in the same way the original SciPy does.
|
||||
"""
|
||||
if x.dtype.kind == 'i':
|
||||
x = x.astype(win.dtype)
|
||||
|
||||
# Created strided array of data segments
|
||||
if nperseg == 1 and noverlap == 0:
|
||||
result = x[..., np.newaxis]
|
||||
else:
|
||||
step = nperseg - noverlap
|
||||
*batch_shape, signal_length = x.shape
|
||||
batch_shape = tuple(batch_shape)
|
||||
x = x.reshape((int(np.prod(batch_shape)), signal_length))[..., np.newaxis]
|
||||
result = jax.lax.conv_general_dilated_patches(
|
||||
x, (nperseg,), (step,),
|
||||
'VALID',
|
||||
dimension_numbers=('NTC', 'OIT', 'NTC'))
|
||||
result = result.reshape(batch_shape + result.shape[-2:])
|
||||
|
||||
# Detrend each data segment individually
|
||||
result = detrend_func(result)
|
||||
|
||||
# Apply window by multiplication
|
||||
result = win.reshape((1,) * len(batch_shape) + (1, nperseg)) * result
|
||||
|
||||
# Perform the fft on last axis. Zero-pads automatically
|
||||
if sides == 'twosided':
|
||||
return jax.numpy.fft.fft(result, n=nfft)
|
||||
else:
|
||||
return jax.numpy.fft.rfft(result.real, n=nfft)
|
||||
|
||||
|
||||
def odd_ext(x, n, axis=-1):
|
||||
"""Extends `x` along with `axis` by odd-extension.
|
||||
|
||||
This function was previously a part of "scipy.signal.signaltools" but is no
|
||||
longer exposed.
|
||||
|
||||
Args:
|
||||
x : input array
|
||||
n : the number of points to be added to the both end
|
||||
axis: the axis to be extended
|
||||
"""
|
||||
if n < 1:
|
||||
return x
|
||||
if n > x.shape[axis] - 1:
|
||||
raise ValueError(
|
||||
f"The extension length n ({n}) is too big. "
|
||||
f"It must not exceed x.shape[axis]-1, which is {x.shape[axis] - 1}.")
|
||||
left_end = lax.slice_in_dim(x, 0, 1, axis=axis)
|
||||
left_ext = jnp.flip(lax.slice_in_dim(x, 1, n + 1, axis=axis), axis=axis)
|
||||
right_end = lax.slice_in_dim(x, -1, None, axis=axis)
|
||||
right_ext = jnp.flip(lax.slice_in_dim(x, -(n + 1), -1, axis=axis), axis=axis)
|
||||
ext = jnp.concatenate((2 * left_end - left_ext,
|
||||
x,
|
||||
2 * right_end - right_ext),
|
||||
axis=axis)
|
||||
return ext
|
||||
|
||||
|
||||
def _spectral_helper(x, y,
|
||||
fs=1.0, window='hann', nperseg=None, noverlap=None,
|
||||
nfft=None, detrend_type='constant', return_onesided=True,
|
||||
scaling='density', axis=-1, mode='psd', boundary=None,
|
||||
padded=False):
|
||||
"""LAX-backend implementation of `scipy.signal._spectral_helper`.
|
||||
|
||||
Unlike the original helper function, `y` can be None for explicitly
|
||||
indicating auto-spectral (non cross-spectral) computation. In addition to
|
||||
this, `detrend` argument is renamed to `detrend_type` for avoiding internal
|
||||
name overlap.
|
||||
"""
|
||||
if mode not in ('psd', 'stft'):
|
||||
raise ValueError(f"Unknown value for mode {mode}, "
|
||||
"must be one of: ('psd', 'stft')")
|
||||
|
||||
def make_pad(mode, **kwargs):
|
||||
def pad(x, n, axis=-1):
|
||||
pad_width = [(0, 0) for unused_n in range(x.ndim)]
|
||||
pad_width[axis] = (n, n)
|
||||
return jnp.pad(x, pad_width, mode, **kwargs)
|
||||
return pad
|
||||
|
||||
boundary_funcs = {
|
||||
'even': make_pad('reflect'),
|
||||
'odd': odd_ext,
|
||||
'constant': make_pad('edge'),
|
||||
'zeros': make_pad('constant', constant_values=0.0),
|
||||
None: lambda x, *args, **kwargs: x
|
||||
}
|
||||
|
||||
# Check/ normalize inputs
|
||||
if boundary not in boundary_funcs:
|
||||
raise ValueError(
|
||||
f"Unknown boundary option '{boundary}', "
|
||||
f"must be one of: {list(boundary_funcs.keys())}")
|
||||
|
||||
axis = jax.core.concrete_or_error(operator.index, axis,
|
||||
"axis of windowed-FFT")
|
||||
axis = canonicalize_axis(axis, x.ndim)
|
||||
|
||||
if nperseg is not None: # if specified by user
|
||||
nperseg = jax.core.concrete_or_error(int, nperseg,
|
||||
"nperseg of windowed-FFT")
|
||||
if nperseg < 1:
|
||||
raise ValueError('nperseg must be a positive integer')
|
||||
# parse window; if array like, then set nperseg = win.shape
|
||||
win, nperseg = signal_helper._triage_segments(
|
||||
window, nperseg, input_length=x.shape[axis])
|
||||
|
||||
if noverlap is None:
|
||||
noverlap = nperseg // 2
|
||||
else:
|
||||
noverlap = jax.core.concrete_or_error(int, noverlap,
|
||||
"noverlap of windowed-FFT")
|
||||
if nfft is None:
|
||||
nfft = nperseg
|
||||
else:
|
||||
nfft = jax.core.concrete_or_error(int, nfft,
|
||||
"nfft of windowed-FFT")
|
||||
|
||||
_check_arraylike("_spectral_helper", x)
|
||||
x = jnp.asarray(x)
|
||||
|
||||
if y is None:
|
||||
outdtype = jax.dtypes.canonicalize_dtype(np.result_type(x, np.complex64))
|
||||
else:
|
||||
_check_arraylike("_spectral_helper", y)
|
||||
y = jnp.asarray(y)
|
||||
outdtype = jax.dtypes.canonicalize_dtype(
|
||||
np.result_type(x, y, np.complex64))
|
||||
if mode != 'psd':
|
||||
raise ValueError("two-argument mode is available only when mode=='psd'")
|
||||
if x.ndim != y.ndim:
|
||||
raise ValueError(
|
||||
"two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")
|
||||
|
||||
# Check if we can broadcast the outer axes together
|
||||
try:
|
||||
outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
|
||||
tuple_delete(y.shape, axis))
|
||||
except ValueError as e:
|
||||
raise ValueError('x and y cannot be broadcast together.') from e
|
||||
|
||||
# Special cases for size == 0
|
||||
if y is None:
|
||||
if x.size == 0:
|
||||
return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
|
||||
else:
|
||||
if x.size == 0 or y.size == 0:
|
||||
outshape = tuple_insert(
|
||||
outershape, min([x.shape[axis], y.shape[axis]]), axis)
|
||||
emptyout = jnp.zeros(outshape)
|
||||
return emptyout, emptyout, emptyout
|
||||
|
||||
# Move time-axis to the end
|
||||
if x.ndim > 1:
|
||||
if axis != -1:
|
||||
x = jnp.moveaxis(x, axis, -1)
|
||||
if y is not None and y.ndim > 1:
|
||||
y = jnp.moveaxis(y, axis, -1)
|
||||
|
||||
# Check if x and y are the same length, zero-pad if necessary
|
||||
if y is not None:
|
||||
if x.shape[-1] != y.shape[-1]:
|
||||
if x.shape[-1] < y.shape[-1]:
|
||||
pad_shape = list(x.shape)
|
||||
pad_shape[-1] = y.shape[-1] - x.shape[-1]
|
||||
x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
|
||||
else:
|
||||
pad_shape = list(y.shape)
|
||||
pad_shape[-1] = x.shape[-1] - y.shape[-1]
|
||||
y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)
|
||||
|
||||
if nfft < nperseg:
|
||||
raise ValueError('nfft must be greater than or equal to nperseg.')
|
||||
if noverlap >= nperseg:
|
||||
raise ValueError('noverlap must be less than nperseg.')
|
||||
nstep = nperseg - noverlap
|
||||
|
||||
# Apply paddings
|
||||
if boundary is not None:
|
||||
ext_func = boundary_funcs[boundary]
|
||||
x = ext_func(x, nperseg // 2, axis=-1)
|
||||
if y is not None:
|
||||
y = ext_func(y, nperseg // 2, axis=-1)
|
||||
|
||||
if padded:
|
||||
# Pad to integer number of windowed segments
|
||||
# I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
|
||||
nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg
|
||||
zeros_shape = list(x.shape[:-1]) + [nadd]
|
||||
x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
|
||||
if y is not None:
|
||||
zeros_shape = list(y.shape[:-1]) + [nadd]
|
||||
y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)
|
||||
|
||||
# Handle detrending and window functions
|
||||
if not detrend_type:
|
||||
def detrend_func(d):
|
||||
return d
|
||||
elif not hasattr(detrend_type, '__call__'):
|
||||
def detrend_func(d):
|
||||
return detrend(d, type=detrend_type, axis=-1)
|
||||
elif axis != -1:
|
||||
# Wrap this function so that it receives a shape that it could
|
||||
# reasonably expect to receive.
|
||||
def detrend_func(d):
|
||||
d = jnp.moveaxis(d, axis, -1)
|
||||
d = detrend_type(d)
|
||||
return jnp.moveaxis(d, -1, axis)
|
||||
else:
|
||||
detrend_func = detrend_type
|
||||
|
||||
if np.result_type(win, np.complex64) != outdtype:
|
||||
win = win.astype(outdtype)
|
||||
|
||||
# Determine scale
|
||||
if scaling == 'density':
|
||||
scale = 1.0 / (fs * (win * win).sum())
|
||||
elif scaling == 'spectrum':
|
||||
scale = 1.0 / win.sum()**2
|
||||
else:
|
||||
raise ValueError(f'Unknown scaling: {scaling}')
|
||||
if mode == 'stft':
|
||||
scale = jnp.sqrt(scale)
|
||||
|
||||
# Determine onesided/ two-sided
|
||||
if return_onesided:
|
||||
sides = 'onesided'
|
||||
if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
|
||||
sides = 'twosided'
|
||||
warnings.warn('Input data is complex, switching to '
|
||||
'return_onesided=False')
|
||||
else:
|
||||
sides = 'twosided'
|
||||
|
||||
if sides == 'twosided':
|
||||
freqs = jax.numpy.fft.fftfreq(nfft, 1/fs)
|
||||
elif sides == 'onesided':
|
||||
freqs = jax.numpy.fft.rfftfreq(nfft, 1/fs)
|
||||
|
||||
# Perform the windowed FFTs
|
||||
result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)
|
||||
|
||||
if y is not None:
|
||||
# All the same operations on the y data
|
||||
result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
|
||||
sides)
|
||||
result = jnp.conjugate(result) * result_y
|
||||
elif mode == 'psd':
|
||||
result = jnp.conjugate(result) * result
|
||||
|
||||
result *= scale
|
||||
|
||||
if sides == 'onesided' and mode == 'psd':
|
||||
end = None if nfft % 2 else -1
|
||||
result = result.at[..., 1:end].mul(2)
|
||||
|
||||
time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
|
||||
nperseg - noverlap) / fs
|
||||
if boundary is not None:
|
||||
time -= (nperseg / 2) / fs
|
||||
|
||||
result = result.astype(outdtype)
|
||||
|
||||
# All imaginary parts are zero anyways
|
||||
if y is None and mode != 'stft':
|
||||
result = result.real
|
||||
|
||||
# Move frequency axis back to axis where the data came from
|
||||
result = jnp.moveaxis(result, -1, axis)
|
||||
|
||||
return freqs, time, result
|
||||
|
||||
|
||||
@_wraps(osp_signal.stft)
|
||||
def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,
|
||||
detrend=False, return_onesided=True, boundary='zeros', padded=True,
|
||||
axis=-1):
|
||||
freqs, time, Zxx = _spectral_helper(x, None, fs, window, nperseg, noverlap,
|
||||
nfft, detrend, return_onesided,
|
||||
scaling='spectrum', axis=axis,
|
||||
mode='stft', boundary=boundary,
|
||||
padded=padded)
|
||||
|
||||
return freqs, time, Zxx
|
||||
|
||||
|
||||
_csd_description = """
|
||||
The original SciPy function exhibits slightly different behavior between
|
||||
``csd(x, x)``` and ```csd(x, x.copy())```. The LAX-backend version is designed
|
||||
to follow the latter behavior. For using the former behavior, call this
|
||||
function as `csd(x, None)`."""
|
||||
|
||||
|
||||
@_wraps(osp_signal.csd, lax_description=_csd_description)
|
||||
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
|
||||
detrend='constant', return_onesided=True, scaling='density',
|
||||
axis=-1, average='mean'):
|
||||
freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
|
||||
detrend, return_onesided, scaling, axis,
|
||||
mode='psd')
|
||||
if y is not None:
|
||||
Pxy = Pxy + 0j # Ensure complex output when x is not y
|
||||
|
||||
# Average over windows.
|
||||
if Pxy.ndim >= 2 and Pxy.size > 0:
|
||||
if Pxy.shape[-1] > 1:
|
||||
if average == 'median':
|
||||
bias = signal_helper._median_bias(Pxy.shape[-1]).astype(Pxy.dtype)
|
||||
if jnp.iscomplexobj(Pxy):
|
||||
Pxy = (jnp.median(jnp.real(Pxy), axis=-1)
|
||||
+ 1j * jnp.median(jnp.imag(Pxy), axis=-1))
|
||||
else:
|
||||
Pxy = jnp.median(Pxy, axis=-1)
|
||||
Pxy /= bias
|
||||
elif average == 'mean':
|
||||
Pxy = Pxy.mean(axis=-1)
|
||||
else:
|
||||
raise ValueError(f'average must be "median" or "mean", got {average}')
|
||||
else:
|
||||
Pxy = jnp.reshape(Pxy, Pxy.shape[:-1])
|
||||
|
||||
return freqs, Pxy
|
||||
|
||||
|
||||
@_wraps(osp_signal.welch)
|
||||
def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
|
||||
detrend='constant', return_onesided=True, scaling='density',
|
||||
axis=-1, average='mean'):
|
||||
freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg,
|
||||
noverlap=noverlap, nfft=nfft, detrend=detrend,
|
||||
return_onesided=return_onesided, scaling=scaling,
|
||||
axis=axis, average=average)
|
||||
|
||||
return freqs, Pxx.real
|
||||
|
84
jax/_src/third_party/scipy/signal_helper.py
vendored
Normal file
84
jax/_src/third_party/scipy/signal_helper.py
vendored
Normal file
@ -0,0 +1,84 @@
|
||||
"""Utility functions adopted from scipy.signal."""
|
||||
|
||||
import scipy.signal as osp_signal
|
||||
import warnings
|
||||
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
|
||||
|
||||
def _triage_segments(window, nperseg, input_length):
|
||||
"""
|
||||
Parses window and nperseg arguments for spectrogram and _spectral_helper.
|
||||
This is a helper function, not meant to be called externally.
|
||||
Parameters
|
||||
----------
|
||||
window : string, tuple, or ndarray
|
||||
If window is specified by a string or tuple and nperseg is not
|
||||
specified, nperseg is set to the default of 256 and returns a window of
|
||||
that length.
|
||||
If instead the window is array_like and nperseg is not specified, then
|
||||
nperseg is set to the length of the window. A ValueError is raised if
|
||||
the user supplies both an array_like window and a value for nperseg but
|
||||
nperseg does not equal the length of the window.
|
||||
nperseg : int
|
||||
Length of each segment
|
||||
input_length: int
|
||||
Length of input signal, i.e. x.shape[-1]. Used to test for errors.
|
||||
Returns
|
||||
-------
|
||||
win : ndarray
|
||||
window. If function was called with string or tuple than this will hold
|
||||
the actual array used as a window.
|
||||
nperseg : int
|
||||
Length of each segment. If window is str or tuple, nperseg is set to
|
||||
256. If window is array_like, nperseg is set to the length of the
|
||||
6
|
||||
window.
|
||||
"""
|
||||
# parse window; if array like, then set nperseg = win.shape
|
||||
if isinstance(window, (str, tuple)):
|
||||
# if nperseg not specified
|
||||
if nperseg is None:
|
||||
nperseg = 256 # then change to default
|
||||
if nperseg > input_length:
|
||||
warnings.warn(f'nperseg = {nperseg} is greater than input length '
|
||||
f' = {input_length}, using nperseg = {nperseg}')
|
||||
nperseg = input_length
|
||||
win = jnp.array(osp_signal.get_window(window, nperseg))
|
||||
else:
|
||||
win = jnp.asarray(window)
|
||||
if len(win.shape) != 1:
|
||||
raise ValueError('window must be 1-D')
|
||||
if input_length < win.shape[-1]:
|
||||
raise ValueError('window is longer than input signal')
|
||||
if nperseg is None:
|
||||
nperseg = win.shape[0]
|
||||
elif nperseg is not None:
|
||||
if nperseg != win.shape[0]:
|
||||
raise ValueError("value specified for nperseg is different"
|
||||
" from length of window")
|
||||
return win, nperseg
|
||||
|
||||
|
||||
def _median_bias(n):
|
||||
"""
|
||||
Returns the bias of the median of a set of periodograms relative to
|
||||
the mean.
|
||||
See Appendix B from [1]_ for details.
|
||||
Parameters
|
||||
----------
|
||||
n : int
|
||||
Numbers of periodograms being averaged.
|
||||
Returns
|
||||
-------
|
||||
bias : float
|
||||
Calculated bias.
|
||||
References
|
||||
----------
|
||||
.. [1] B. Allen, W.G. Anderson, P.R. Brady, D.A. Brown, J.D.E. Creighton.
|
||||
"FINDCHIRP: an algorithm for detection of gravitational waves from
|
||||
inspiraling compact binaries", Physical Review D 85, 2012,
|
||||
:arxiv:`gr-qc/0509116`
|
||||
"""
|
||||
ii_2 = jnp.arange(2., n, 2)
|
||||
return 1 + jnp.sum(1. / (ii_2 + 1) - 1. / ii_2)
|
@ -20,4 +20,7 @@ from jax._src.scipy.signal import (
|
||||
correlate as correlate,
|
||||
correlate2d as correlate2d,
|
||||
detrend as detrend,
|
||||
csd as csd,
|
||||
stft as stft,
|
||||
welch as welch,
|
||||
)
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
@ -30,9 +31,25 @@ config.parse_flags_with_absl()
|
||||
onedim_shapes = [(1,), (2,), (5,), (10,)]
|
||||
twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)]
|
||||
threedim_shapes = [(2, 2, 2), (3, 3, 2), (4, 4, 2), (5, 5, 2)]
|
||||
stft_test_shapes = [
|
||||
# (input_shape, nperseg, noverlap, axis)
|
||||
((50,), 17, 5, -1),
|
||||
((2, 13), 7, 0, -1),
|
||||
((3, 17, 2), 9, 3, 1),
|
||||
((2, 3, 389, 5), 17, 13, 2),
|
||||
((2, 1, 133, 3), 17, 13, -2),
|
||||
]
|
||||
csd_test_shapes = [
|
||||
# (x_input_shape, y_input_shape, nperseg, noverlap, axis)
|
||||
((50,), (13,), 17, 5, -1),
|
||||
((2, 13), (2, 13), 7, 0, -1),
|
||||
((3, 17, 2), (3, 12, 2), 9, 3, 1),
|
||||
]
|
||||
welch_test_shapes = stft_test_shapes
|
||||
|
||||
|
||||
default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
_TPU_FFT_TOL = 0.15
|
||||
|
||||
|
||||
class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
@ -104,6 +121,255 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
|
||||
f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
|
||||
f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
|
||||
f"_axis={timeaxis}_nfft={nfft}",
|
||||
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
|
||||
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
|
||||
"detrend": detrend, "boundary": boundary, "padded": padded,
|
||||
"timeaxis": timeaxis}
|
||||
for shape, nperseg, noverlap, timeaxis in stft_test_shapes
|
||||
for dtype in default_dtypes
|
||||
for fs in [1.0, 16000.0]
|
||||
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
|
||||
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
|
||||
for detrend in ['constant', 'linear', False]
|
||||
for boundary in [None, 'even', 'odd', 'zeros']
|
||||
for padded in [True, False]))
|
||||
def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
|
||||
noverlap, nfft, detrend, boundary, padded,
|
||||
timeaxis):
|
||||
is_complex = np.dtype(dtype).kind == 'c'
|
||||
if is_complex and detrend is not None:
|
||||
return
|
||||
|
||||
osp_fun = partial(osp_signal.stft,
|
||||
fs=fs, window=window, nfft=nfft, boundary=boundary, padded=padded,
|
||||
detrend=detrend, nperseg=nperseg, noverlap=noverlap, axis=timeaxis,
|
||||
return_onesided=not is_complex)
|
||||
jsp_fun = partial(jsp_signal.stft,
|
||||
fs=fs, window=window, nfft=nfft, boundary=boundary, padded=padded,
|
||||
detrend=detrend, nperseg=nperseg, noverlap=noverlap, axis=timeaxis,
|
||||
return_onesided=not is_complex)
|
||||
tol = {
|
||||
np.float32: 1e-5, np.float64: 1e-12,
|
||||
np.complex64: 1e-5, np.complex128: 1e-12
|
||||
}
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
tol = _TPU_FFT_TOL
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
# Tests with `average == 'median'`` is excluded from `testCsd*`
|
||||
# due to the issue:
|
||||
# https://github.com/scipy/scipy/issues/15601
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
|
||||
f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
|
||||
f"_average={average}_scaling={scaling}_nfft={nfft}"
|
||||
f"_fs={fs}_window={window}_detrend={detrend}"
|
||||
f"_nperseg={nperseg}_noverlap={noverlap}"
|
||||
f"_axis={timeaxis}",
|
||||
"xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs,
|
||||
"window": window, "nperseg": nperseg, "noverlap": noverlap,
|
||||
"nfft": nfft, "detrend": detrend, "scaling": scaling,
|
||||
"timeaxis": timeaxis, "average": average}
|
||||
for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
|
||||
for dtype in default_dtypes
|
||||
for fs in [1.0, 16000.0]
|
||||
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
|
||||
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
|
||||
for detrend in ['constant', 'linear', False]
|
||||
for scaling in ['density', 'spectrum']
|
||||
for average in ['mean']))
|
||||
def testCsdAgainstNumpy(
|
||||
self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
|
||||
detrend, scaling, timeaxis, average):
|
||||
is_complex = np.dtype(dtype).kind == 'c'
|
||||
if is_complex and detrend is not None:
|
||||
raise unittest.SkipTest(
|
||||
"Complex signal is not supported in lax-backed `signal.detrend`.")
|
||||
|
||||
osp_fun = partial(osp_signal.csd,
|
||||
fs=fs, window=window,
|
||||
nperseg=nperseg, noverlap=noverlap, nfft=nfft,
|
||||
detrend=detrend, return_onesided=not is_complex,
|
||||
scaling=scaling, axis=timeaxis, average=average)
|
||||
jsp_fun = partial(jsp_signal.csd,
|
||||
fs=fs, window=window,
|
||||
nperseg=nperseg, noverlap=noverlap, nfft=nfft,
|
||||
detrend=detrend, return_onesided=not is_complex,
|
||||
scaling=scaling, axis=timeaxis, average=average)
|
||||
tol = {
|
||||
np.float32: 1e-5, np.float64: 1e-12,
|
||||
np.complex64: 1e-5, np.complex128: 1e-12
|
||||
}
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
tol = _TPU_FFT_TOL
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
|
||||
f"_average={average}_scaling={scaling}_nfft={nfft}"
|
||||
f"_fs={fs}_window={window}_detrend={detrend}"
|
||||
f"_nperseg={nperseg}_noverlap={noverlap}"
|
||||
f"_axis={timeaxis}",
|
||||
"shape": shape, "dtype": dtype, "fs": fs,
|
||||
"window": window, "nperseg": nperseg, "noverlap": noverlap,
|
||||
"nfft": nfft, "detrend": detrend, "scaling": scaling,
|
||||
"timeaxis": timeaxis, "average": average}
|
||||
for shape, unused_yshape, nperseg, noverlap, timeaxis in csd_test_shapes
|
||||
for dtype in default_dtypes
|
||||
for fs in [1.0, 16000.0]
|
||||
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
|
||||
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
|
||||
for detrend in ['constant', 'linear', False]
|
||||
for scaling in ['density', 'spectrum']
|
||||
for average in ['mean']))
|
||||
def testCsdWithSameParamAgainstNumpy(
|
||||
self, *, shape, dtype, fs, window, nperseg, noverlap, nfft,
|
||||
detrend, scaling, timeaxis, average):
|
||||
is_complex = np.dtype(dtype).kind == 'c'
|
||||
if is_complex and detrend is not None:
|
||||
raise unittest.SkipTest(
|
||||
"Complex signal is not supported in lax-backed `signal.detrend`.")
|
||||
|
||||
def osp_fun(x, y):
|
||||
# When the identical parameters are given, jsp-version follows
|
||||
# the behavior with copied parameters.
|
||||
freqs, Pxy = osp_signal.csd(
|
||||
x, y.copy(),
|
||||
fs=fs, window=window,
|
||||
nperseg=nperseg, noverlap=noverlap, nfft=nfft,
|
||||
detrend=detrend, return_onesided=not is_complex,
|
||||
scaling=scaling, axis=timeaxis, average=average)
|
||||
return freqs, Pxy
|
||||
jsp_fun = partial(jsp_signal.csd,
|
||||
fs=fs, window=window,
|
||||
nperseg=nperseg, noverlap=noverlap, nfft=nfft,
|
||||
detrend=detrend, return_onesided=not is_complex,
|
||||
scaling=scaling, axis=timeaxis, average=average)
|
||||
|
||||
tol = {
|
||||
np.float32: 1e-5, np.float64: 1e-12,
|
||||
np.complex64: 1e-5, np.complex128: 1e-12
|
||||
}
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
tol = _TPU_FFT_TOL
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)] * 2
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
|
||||
f"_fs={fs}_window={window}"
|
||||
f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
|
||||
f"_detrend={detrend}_return_onesided={return_onesided}"
|
||||
f"_scaling={scaling}_axis={timeaxis}_average={average}",
|
||||
"shape": shape, "dtype": dtype, "fs": fs, "window": window,
|
||||
"nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
|
||||
"detrend": detrend, "return_onesided": return_onesided,
|
||||
"scaling": scaling, "timeaxis": timeaxis, "average": average}
|
||||
for shape, nperseg, noverlap, timeaxis in welch_test_shapes
|
||||
for dtype in default_dtypes
|
||||
for fs in [1.0, 16000.0]
|
||||
for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
|
||||
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
|
||||
for detrend in ['constant', 'linear', False]
|
||||
for return_onesided in [True, False]
|
||||
for scaling in ['density', 'spectrum']
|
||||
for average in ['mean', 'median']))
|
||||
def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
|
||||
noverlap, nfft, detrend, return_onesided,
|
||||
scaling, timeaxis, average):
|
||||
if np.dtype(dtype).kind == 'c':
|
||||
return_onesided = False
|
||||
if detrend is not None:
|
||||
raise unittest.SkipTest(
|
||||
"Complex signal is not supported in lax-backed `signal.detrend`.")
|
||||
|
||||
osp_fun = partial(osp_signal.welch,
|
||||
fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft,
|
||||
detrend=detrend, return_onesided=return_onesided, scaling=scaling,
|
||||
axis=timeaxis, average=average)
|
||||
jsp_fun = partial(jsp_signal.welch,
|
||||
fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft,
|
||||
detrend=detrend, return_onesided=return_onesided, scaling=scaling,
|
||||
axis=timeaxis, average=average)
|
||||
tol = {
|
||||
np.float32: 1e-5, np.float64: 1e-12,
|
||||
np.complex64: 1e-5, np.complex128: 1e-12
|
||||
}
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
tol = _TPU_FFT_TOL
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
|
||||
f"_nperseg={nperseg}_noverlap={noverlap}"
|
||||
f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
|
||||
f"_axis={timeaxis}",
|
||||
"shape": shape, "dtype": dtype,
|
||||
"nperseg": nperseg, "noverlap": noverlap,
|
||||
"use_nperseg": use_nperseg, "use_noverlap": use_noverlap,
|
||||
"timeaxis": timeaxis}
|
||||
for shape, nperseg, noverlap, timeaxis in welch_test_shapes
|
||||
for use_nperseg in [False, True]
|
||||
for use_noverlap in [False, True]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
|
||||
def testWelchWithDefaultStepArgsAgainstNumpy(
|
||||
self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
|
||||
timeaxis):
|
||||
kwargs = {
|
||||
'axis': timeaxis
|
||||
}
|
||||
|
||||
if use_nperseg:
|
||||
kwargs['nperseg'] = nperseg
|
||||
else:
|
||||
kwargs['window'] = osp_signal.get_window('hann', nperseg)
|
||||
if use_noverlap:
|
||||
kwargs['noverlap'] = noverlap
|
||||
|
||||
osp_fun = partial(osp_signal.welch, **kwargs)
|
||||
jsp_fun = partial(jsp_signal.welch, **kwargs)
|
||||
tol = {
|
||||
np.float32: 1e-5, np.float64: 1e-12,
|
||||
np.complex64: 1e-5, np.complex128: 1e-12
|
||||
}
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
tol = _TPU_FFT_TOL
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user