rocm_jax/jax/_src/scipy/signal.py
Yotaro Kubo e085370ec4 Add some functions for spectral analysis.
This commit adds "stft", "csd", and "welch" functions in scipy.signal.
2022-02-17 15:59:24 +09:00

495 lines
17 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.signal as osp_signal
import operator
import warnings
import numpy as np
import jax
import jax.numpy.fft
from jax import lax
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import linalg
from jax._src.numpy.lax_numpy import _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.third_party.scipy import signal_helper
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
# Note: we do not re-use the code from jax.numpy.convolve here, because the handling
# of padding differs slightly between the two implementations (particularly for
# mode='same').
def _convolve_nd(in1, in2, mode, *, precision):
  """Convolve two arrays of equal rank using `lax.conv_general_dilated`.

  Args:
    in1: first input array.
    in2: second input array; must have the same number of dimensions as `in1`.
    mode: one of 'full', 'same' or 'valid'; selects the padding applied below.
    precision: forwarded unchanged to `lax.conv_general_dilated`.

  Returns:
    The convolution result with the two leading singleton dimensions (added
    for the call to `lax.conv_general_dilated`) stripped off again.

  Raises:
    ValueError: if `mode` is unknown, the ranks differ, either input is empty,
      or neither input is at least as large as the other in every dimension.
  """
  if mode not in ("full", "same", "valid"):
    raise ValueError("mode must be one of ['full', 'same', 'valid']")
  if in1.ndim != in2.ndim:
    raise ValueError("in1 and in2 must have the same number of dimensions")
  if 0 in (in1.size, in2.size):
    raise ValueError(f"zero-size arrays not supported in convolutions, got shapes {in1.shape} and {in2.shape}.")
  in1, in2 = _promote_dtypes_inexact(in1, in2)

  dim_pairs = list(zip(in1.shape, in2.shape))
  in1_covers_in2 = all(a >= b for a, b in dim_pairs)
  in2_covers_in1 = all(a <= b for a, b in dim_pairs)
  if not in1_covers_in2 and not in2_covers_in1:
    raise ValueError("One input must be smaller than the other in every dimension.")

  # The 'same' padding is derived from the shape of in2 *as passed*, so record
  # it before the operands are possibly exchanged.
  passed_in2_shape = in2.shape
  # The operand that is no larger in every dimension acts as the kernel; when
  # the shapes are equal the operands are exchanged as well.
  signal, kernel = (in2, in1) if in2_covers_in1 else (in1, in2)
  kernel_shape = kernel.shape
  kernel = jnp.flip(kernel)

  if mode == 'valid':
    padding = [(0, 0)] * len(kernel_shape)
  elif mode == 'same':
    padding = []
    for k, k_passed in zip(kernel_shape, passed_in2_shape):
      half = (k_passed - 1) // 2
      padding.append((k - 1 - half, k - k_passed + half))
  else:  # mode == 'full'
    padding = [(k - 1, k - 1) for k in kernel_shape]

  unit_strides = (1,) * len(kernel_shape)
  # Add singleton batch/feature dimensions for the lax call, then drop them.
  out = lax.conv_general_dilated(signal[None, None], kernel[None, None],
                                 unit_strides, padding, precision=precision)
  return out[0, 0]
@_wraps(osp_signal.convolve)
def convolve(in1, in2, mode='full', method='auto',
             precision=None):
  # N-dimensional convolution; delegates all the work to `_convolve_nd`.
  # `method` has no effect on the computation: any value other than 'auto'
  # only triggers a warning. `precision` is forwarded to the helper.
  method_is_ignored = method != 'auto'
  if method_is_ignored:
    warnings.warn("convolve() ignores method argument")
  result = _convolve_nd(in1, in2, mode, precision=precision)
  return result
@_wraps(osp_signal.convolve2d)
def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
               precision=None):
  # Two-dimensional convolution built on `_convolve_nd`. Only zero-filled
  # boundaries are implemented, so every other boundary setting is rejected.
  unsupported = "convolve2d() only supports boundary='fill', fillvalue=0"
  if boundary != 'fill':
    raise NotImplementedError(unsupported)
  if fillvalue != 0:
    raise NotImplementedError(unsupported)
  if any(jnp.ndim(operand) != 2 for operand in (in1, in2)):
    raise ValueError("convolve2d() only supports 2-dimensional inputs.")
  return _convolve_nd(in1, in2, mode, precision=precision)
@_wraps(osp_signal.correlate)
def correlate(in1, in2, mode='full', method='auto',
              precision=None):
  # Cross-correlation expressed as a convolution of `in1` with the reversed,
  # complex-conjugated `in2`. `method` is accepted but has no effect beyond a
  # warning when it is not 'auto'.
  method_is_ignored = method != 'auto'
  if method_is_ignored:
    warnings.warn("correlate() ignores method argument")
  reversed_conj_in2 = jnp.flip(in2.conj())
  return _convolve_nd(in1, reversed_conj_in2, mode, precision=precision)
@_wraps(osp_signal.correlate2d)
def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
                precision=None):
  # Two-dimensional cross-correlation built on `_convolve_nd`. Only
  # zero-filled boundaries are implemented.
  unsupported = "correlate2d() only supports boundary='fill', fillvalue=0"
  if boundary != 'fill':
    raise NotImplementedError(unsupported)
  if fillvalue != 0:
    raise NotImplementedError(unsupported)
  if any(jnp.ndim(operand) != 2 for operand in (in1, in2)):
    raise ValueError("correlate2d() only supports 2-dimensional inputs.")

  shape_pairs = tuple(zip(in1.shape, in2.shape))
  in2_covers_in1 = all(a <= b for a, b in shape_pairs)
  shapes_match = all(a == b for a, b in shape_pairs)

  # Decide whether the operands exchange roles in the underlying convolution:
  #   'same'     -> never
  #   'valid'    -> only when in2 is at least as large everywhere and the
  #                 shapes are not identical
  #   otherwise  -> whenever in2 is at least as large everywhere
  if mode == "same":
    exchange_operands = False
  elif mode == "valid":
    exchange_operands = in2_covers_in1 and not shapes_match
  else:
    exchange_operands = in2_covers_in1

  if not exchange_operands:
    # Convolve the reversed in1 with conj(in2), then reverse the result.
    return jnp.flip(
        _convolve_nd(jnp.flip(in1), in2.conj(), mode, precision=precision))

  # Exchanged roles: convolve the reversed in2 with conj(in1).
  exchanged = _convolve_nd(jnp.flip(in2), in1.conj(), mode,
                           precision=precision)
  # The 'valid' result is returned as is; the remaining mode conjugates it.
  return exchanged if mode == "valid" else exchanged.conj()
@_wraps(osp_signal.detrend)
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
  # Removes either the mean ('constant') or a piecewise-linear least-squares
  # fit ('linear') along `axis`. `bp` lists the breakpoints between the linear
  # pieces. `overwrite_data` is not supported: anything but None is rejected.
  if overwrite_data is not None:
    raise NotImplementedError("overwrite_data argument not implemented.")
  if type not in ('constant', 'linear'):
    raise ValueError("Trend type must be 'linear' or 'constant'.")
  data, = _promote_dtypes_inexact(jnp.asarray(data))

  if type == 'constant':
    return data - data.mean(axis, keepdims=True)

  num_samples = data.shape[axis]
  # bp is static, so we use np operations to avoid pushing to device.
  breakpoints = np.sort(np.unique(np.r_[0, bp, num_samples]))
  if breakpoints[0] < 0 or breakpoints[-1] > num_samples:
    raise ValueError("Breakpoints must be non-negative and less than length of data along given axis.")

  # Work on a 2-D view with the detrended axis leading; the original layout is
  # restored just before returning.
  data = jnp.moveaxis(data, axis, 0)
  moved_shape = data.shape
  data = data.reshape(num_samples, -1)

  for start, stop in zip(breakpoints[:-1], breakpoints[1:]):
    seg_len = stop - start
    # Design matrix with a constant column and a ramp column (1/len .. 1).
    design = jnp.vstack([
        jnp.ones(seg_len, dtype=data.dtype),
        jnp.arange(1, seg_len + 1, dtype=data.dtype) / seg_len
    ]).T
    segment = slice(start, stop)
    coef, *_ = linalg.lstsq(design, data[segment])
    fitted = jnp.matmul(design, coef, precision=lax.Precision.HIGHEST)
    data = data.at[segment].add(-fitted)

  return jnp.moveaxis(data.reshape(moved_shape), 0, axis)
def _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides):
  """Calculate windowed FFT in the same way the original SciPy does.

  Args:
    x: input array whose last axis is split into segments.
    win: window array with `nperseg` elements; it is reshaped so that it
      broadcasts against the last axis of the segmented data.
    detrend_func: callable applied to the segmented data before windowing.
    nperseg: number of samples per segment.
    noverlap: number of samples shared by consecutive segments.
    nfft: FFT length, passed as `n` to the FFT routine.
    sides: 'twosided' selects a full FFT; any other value selects a real FFT
      of the real part of the windowed segments.

  Returns:
    The FFT of every windowed segment, taken over the last axis.
  """
  if x.dtype.kind == 'i':
    x = x.astype(win.dtype)

  # The leading (batch) dimensions are needed on *every* path below, because
  # the window is reshaped against them. They used to be computed only inside
  # the `else` branch, which made the `nperseg == 1 and noverlap == 0` case
  # fail with an unbound-variable error.
  *batch_shape, signal_length = x.shape
  batch_shape = tuple(batch_shape)

  # Created strided array of data segments
  if nperseg == 1 and noverlap == 0:
    # Every sample is its own segment.
    result = x[..., np.newaxis]
  else:
    step = nperseg - noverlap
    # Collapse the batch dimensions and add a trailing channel dimension to
    # match the 'NTC' layout requested from the patch extraction.
    x = x.reshape((int(np.prod(batch_shape)), signal_length))[..., np.newaxis]
    result = jax.lax.conv_general_dilated_patches(
        x, (nperseg,), (step,),
        'VALID',
        dimension_numbers=('NTC', 'OIT', 'NTC'))
    # Restore the original batch dimensions in front of the last two axes.
    result = result.reshape(batch_shape + result.shape[-2:])

  # Detrend each data segment individually
  result = detrend_func(result)

  # Apply window by multiplication
  result = win.reshape((1,) * len(batch_shape) + (1, nperseg)) * result

  # Perform the fft on last axis. Zero-pads automatically
  if sides == 'twosided':
    return jax.numpy.fft.fft(result, n=nfft)
  else:
    return jax.numpy.fft.rfft(result.real, n=nfft)
def odd_ext(x, n, axis=-1):
  """Extends `x` along `axis` by odd-extension.

  Each end is extended by reflecting the neighbouring `n` samples through the
  end sample, i.e. the new values are ``2 * end - mirrored_neighbours``.

  This function was previously a part of "scipy.signal.signaltools" but is no
  longer exposed.

  Args:
    x : input array
    n : the number of points to be added to both ends; values below 1 return
      `x` unchanged
    axis: the axis to be extended

  Returns:
    `x` with `n` extra samples on each side of `axis`.

  Raises:
    ValueError: if `n` exceeds ``x.shape[axis] - 1``.
  """
  if n < 1:
    return x
  max_ext = x.shape[axis] - 1
  if n > max_ext:
    raise ValueError(
        f"The extension length n ({n}) is too big. "
        f"It must not exceed x.shape[axis]-1, which is {max_ext}.")

  first = lax.slice_in_dim(x, 0, 1, axis=axis)
  last = lax.slice_in_dim(x, -1, None, axis=axis)
  # The n samples next to each end, in reversed order.
  mirrored_head = jnp.flip(lax.slice_in_dim(x, 1, n + 1, axis=axis), axis=axis)
  mirrored_tail = jnp.flip(lax.slice_in_dim(x, -(n + 1), -1, axis=axis),
                           axis=axis)
  pieces = (2 * first - mirrored_head, x, 2 * last - mirrored_tail)
  return jnp.concatenate(pieces, axis=axis)
def _spectral_helper(x, y,
                     fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend_type='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False):
  """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
  indicating auto-spectral (non cross-spectral) computation. In addition to
  this, `detrend` argument is renamed to `detrend_type` for avoiding internal
  name overlap.

  Args:
    x: array-like input signal.
    y: second array-like signal for cross-spectral computation, or None for
      the auto-spectral case. Only allowed when ``mode == 'psd'``.
    fs: sampling frequency; scales the returned frequencies and times and the
      'density' scaling factor.
    window: window specification, resolved by
      `signal_helper._triage_segments`.
    nperseg: segment length, or None to let `_triage_segments` decide.
    noverlap: overlap between segments; defaults to ``nperseg // 2``.
    nfft: FFT length; defaults to `nperseg` and must not be smaller than it.
    detrend_type: a falsy value disables detrending, a non-callable value is
      forwarded as `type` to `detrend`, and a callable is applied to the
      segmented data directly.
    return_onesided: request a one-sided spectrum; complex input forces a
      two-sided one (with a warning).
    scaling: 'density' or 'spectrum'.
    axis: axis of `x` (and `y`) along which the spectrum is computed.
    mode: 'psd' or 'stft'.
    boundary: None, 'even', 'odd', 'constant' or 'zeros'; selects how the
      signal is extended by ``nperseg // 2`` samples at both ends.
    padded: if true, zero-pad the end of the signal so that it is covered by
      a whole number of segments.

  Returns:
    A tuple ``(freqs, time, result)``.

  Raises:
    ValueError: on an unknown `mode`, `boundary` or `scaling`, on invalid
      `nperseg`/`noverlap`/`nfft`, or on incompatible `x` and `y`.
  """
  if mode not in ('psd', 'stft'):
    raise ValueError(f"Unknown value for mode {mode}, "
                     "must be one of: ('psd', 'stft')")

  def make_pad(pad_mode, **kwargs):
    # Build an extension function that pads n samples on both sides of `axis`.
    def pad(x, n, axis=-1):
      pad_width = [(0, 0)] * x.ndim
      pad_width[axis] = (n, n)
      return jnp.pad(x, pad_width, pad_mode, **kwargs)
    return pad

  boundary_funcs = {
      'even': make_pad('reflect'),
      'odd': odd_ext,
      'constant': make_pad('edge'),
      'zeros': make_pad('constant', constant_values=0.0),
      None: lambda x, *args, **kwargs: x
  }

  # Check/ normalize inputs
  if boundary not in boundary_funcs:
    raise ValueError(
        f"Unknown boundary option '{boundary}', "
        f"must be one of: {list(boundary_funcs.keys())}")

  # Validate and convert x *before* its `ndim`/`shape` are read, so that a
  # non-array argument is reported by the array-like check rather than by an
  # attribute error further down.
  _check_arraylike("_spectral_helper", x)
  x = jnp.asarray(x)

  axis = jax.core.concrete_or_error(operator.index, axis,
                                    "axis of windowed-FFT")
  axis = canonicalize_axis(axis, x.ndim)
  # `axis` is non-negative from here on, so "is it the last axis?" has to be
  # asked against this index; comparing with -1 would always be true.
  last_axis = x.ndim - 1

  if nperseg is not None:  # if specified by user
    nperseg = jax.core.concrete_or_error(int, nperseg,
                                         "nperseg of windowed-FFT")
    if nperseg < 1:
      raise ValueError('nperseg must be a positive integer')
  # parse window; if array like, then set nperseg = win.shape
  win, nperseg = signal_helper._triage_segments(
      window, nperseg, input_length=x.shape[axis])

  if noverlap is None:
    noverlap = nperseg // 2
  else:
    noverlap = jax.core.concrete_or_error(int, noverlap,
                                          "noverlap of windowed-FFT")
  if nfft is None:
    nfft = nperseg
  else:
    nfft = jax.core.concrete_or_error(int, nfft,
                                      "nfft of windowed-FFT")

  if y is None:
    outdtype = jax.dtypes.canonicalize_dtype(np.result_type(x, np.complex64))
  else:
    _check_arraylike("_spectral_helper", y)
    y = jnp.asarray(y)
    outdtype = jax.dtypes.canonicalize_dtype(
        np.result_type(x, y, np.complex64))
    if mode != 'psd':
      raise ValueError("two-argument mode is available only when mode=='psd'")
    if x.ndim != y.ndim:
      raise ValueError(
          f"two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")

    # Check if we can broadcast the outer axes together
    try:
      outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                        tuple_delete(y.shape, axis))
    except ValueError as e:
      raise ValueError('x and y cannot be broadcast together.') from e

  # Special cases for size == 0
  if y is None:
    if x.size == 0:
      return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
  else:
    if x.size == 0 or y.size == 0:
      outshape = tuple_insert(
          outershape, min([x.shape[axis], y.shape[axis]]), axis)
      emptyout = jnp.zeros(outshape)
      return emptyout, emptyout, emptyout

  # Move time-axis to the end (x and y have the same rank at this point)
  if axis != last_axis:
    x = jnp.moveaxis(x, axis, -1)
    if y is not None:
      y = jnp.moveaxis(y, axis, -1)

  # Check if x and y are the same length, zero-pad if necessary
  if y is not None:
    if x.shape[-1] != y.shape[-1]:
      if x.shape[-1] < y.shape[-1]:
        pad_shape = list(x.shape)
        pad_shape[-1] = y.shape[-1] - x.shape[-1]
        x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
      else:
        pad_shape = list(y.shape)
        pad_shape[-1] = x.shape[-1] - y.shape[-1]
        y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

  if nfft < nperseg:
    raise ValueError('nfft must be greater than or equal to nperseg.')
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Apply paddings
  if boundary is not None:
    ext_func = boundary_funcs[boundary]
    x = ext_func(x, nperseg // 2, axis=-1)
    if y is not None:
      y = ext_func(y, nperseg // 2, axis=-1)

  if padded:
    # Pad to integer number of windowed segments
    # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
    nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg
    zeros_shape = list(x.shape[:-1]) + [nadd]
    x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
    if y is not None:
      zeros_shape = list(y.shape[:-1]) + [nadd]
      y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

  # Handle detrending and window functions
  if not detrend_type:
    def detrend_func(d):
      return d
  elif not hasattr(detrend_type, '__call__'):
    def detrend_func(d):
      return detrend(d, type=detrend_type, axis=-1)
  elif axis != last_axis:
    # Wrap this function so that it receives a shape that it could
    # reasonably expect to receive: the per-segment samples sit on the last
    # axis of the segmented data, so move them to the position of the
    # caller's time axis for the call and back again afterwards.
    def detrend_func(d):
      d = jnp.moveaxis(d, -1, axis)
      d = detrend_type(d)
      return jnp.moveaxis(d, axis, -1)
  else:
    detrend_func = detrend_type

  if np.result_type(win, np.complex64) != outdtype:
    win = win.astype(outdtype)

  # Determine scale
  if scaling == 'density':
    scale = 1.0 / (fs * (win * win).sum())
  elif scaling == 'spectrum':
    scale = 1.0 / win.sum()**2
  else:
    raise ValueError(f'Unknown scaling: {scaling}')
  if mode == 'stft':
    scale = jnp.sqrt(scale)

  # Determine onesided/ two-sided
  if return_onesided:
    sides = 'onesided'
    # y may legitimately be None here; only inspect it when it was given.
    if jnp.iscomplexobj(x) or (y is not None and jnp.iscomplexobj(y)):
      sides = 'twosided'
      warnings.warn('Input data is complex, switching to '
                    'return_onesided=False')
  else:
    sides = 'twosided'

  if sides == 'twosided':
    freqs = jax.numpy.fft.fftfreq(nfft, 1/fs)
  elif sides == 'onesided':
    freqs = jax.numpy.fft.rfftfreq(nfft, 1/fs)

  # Perform the windowed FFTs
  result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

  if y is not None:
    # All the same operations on the y data
    result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                           sides)
    result = jnp.conjugate(result) * result_y
  elif mode == 'psd':
    result = jnp.conjugate(result) * result

  result *= scale

  if sides == 'onesided' and mode == 'psd':
    # Double every bin except the first, and except the last one as well when
    # nfft is even.
    end = None if nfft % 2 else -1
    result = result.at[..., 1:end].mul(2)

  time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                    nperseg - noverlap) / fs
  if boundary is not None:
    # Compensate for the nperseg // 2 samples added in front of the signal.
    time -= (nperseg / 2) / fs

  result = result.astype(outdtype)

  # All imaginary parts are zero anyways
  if y is None and mode != 'stft':
    result = result.real

  # Move frequency axis back to axis where the data came from
  result = jnp.moveaxis(result, -1, axis)

  return freqs, time, result
@_wraps(osp_signal.stft)
def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,
         detrend=False, return_onesided=True, boundary='zeros', padded=True,
         axis=-1):
  # Short-time Fourier transform: the auto-spectral (y=None) 'stft' mode of
  # `_spectral_helper` with 'spectrum' scaling. `detrend` is handed over under
  # the helper's name for it, `detrend_type`.
  freqs, time, Zxx = _spectral_helper(
      x, None,
      fs=fs,
      window=window,
      nperseg=nperseg,
      noverlap=noverlap,
      nfft=nfft,
      detrend_type=detrend,
      return_onesided=return_onesided,
      scaling='spectrum',
      axis=axis,
      mode='stft',
      boundary=boundary,
      padded=padded)
  return freqs, time, Zxx
# Extra note attached to the documentation of `csd` (passed to `_wraps` as
# `lax_description`). Code samples use balanced double-backtick inline
# literals so that the text renders correctly.
_csd_description = """
The original SciPy function exhibits slightly different behavior between
``csd(x, x)`` and ``csd(x, x.copy())``. The LAX-backend version is designed
to follow the latter behavior. For using the former behavior, call this
function as ``csd(x, None)``."""
@_wraps(osp_signal.csd, lax_description=_csd_description)
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density',
        axis=-1, average='mean'):
  # Cross spectral density: the per-segment spectra come from
  # `_spectral_helper` in 'psd' mode and are then averaged over the last axis
  # (one entry per segment) with either the mean or a bias-corrected median.
  freqs, _, Pxy = _spectral_helper(
      x, y,
      fs=fs,
      window=window,
      nperseg=nperseg,
      noverlap=noverlap,
      nfft=nfft,
      detrend_type=detrend,
      return_onesided=return_onesided,
      scaling=scaling,
      axis=axis,
      mode='psd')
  if y is not None:
    Pxy = Pxy + 0j  # Ensure complex output when x is not y

  # Nothing to average over.
  if Pxy.ndim < 2 or Pxy.size <= 0:
    return freqs, Pxy

  num_segments = Pxy.shape[-1]
  if num_segments <= 1:
    # A single segment: just drop the trailing segment axis.
    return freqs, jnp.reshape(Pxy, Pxy.shape[:-1])

  # Average over windows.
  if average == 'median':
    bias = signal_helper._median_bias(num_segments).astype(Pxy.dtype)
    if jnp.iscomplexobj(Pxy):
      # Take the median of the real and imaginary parts separately.
      median_real = jnp.median(jnp.real(Pxy), axis=-1)
      median_imag = jnp.median(jnp.imag(Pxy), axis=-1)
      Pxy = median_real + 1j * median_imag
    else:
      Pxy = jnp.median(Pxy, axis=-1)
    Pxy /= bias
  elif average == 'mean':
    Pxy = Pxy.mean(axis=-1)
  else:
    raise ValueError(f'average must be "median" or "mean", got {average}')

  return freqs, Pxy
@_wraps(osp_signal.welch)
def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          detrend='constant', return_onesided=True, scaling='density',
          axis=-1, average='mean'):
  # Power spectral density: the auto-spectral case of `csd` (y=None), reduced
  # to its real part.
  csd_options = dict(
      fs=fs,
      window=window,
      nperseg=nperseg,
      noverlap=noverlap,
      nfft=nfft,
      detrend=detrend,
      return_onesided=return_onesided,
      scaling=scaling,
      axis=axis,
      average=average,
  )
  freqs, auto_spectrum = csd(x, None, **csd_options)
  return freqs, auto_spectrum.real