DOC: Improve docstrings for jax.scipy.signal

This commit is contained in:
Jake VanderPlas 2024-05-01 14:26:00 -07:00
parent e691c19bb2
commit 18703d9385
2 changed files with 409 additions and 18 deletions

View File

@ -93,6 +93,7 @@ jax.scipy.signal
correlate
correlate2d
csd
detrend
istft
stft
welch

View File

@ -34,15 +34,66 @@ from jax._src import dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.numpy import linalg
from jax._src.numpy.util import (
check_arraylike, implements, promote_dtypes_inexact, promote_dtypes_complex)
check_arraylike, promote_dtypes_inexact, promote_dtypes_complex)
from jax._src.third_party.scipy import signal_helper
from jax._src.typing import Array, ArrayLike
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
@implements(osp_signal.fftconvolve)
def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full",
axes: Sequence[int] | None = None) -> Array:
"""
Convolve two N-dimensional arrays using Fast Fourier Transform (FFT).
JAX implementation of :func:`scipy.signal.fftconvolve`.
Args:
in1: left-hand input to the convolution.
in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``.
mode: controls the size of the output. Available operations are:
* ``"full"``: (default) output the full convolution of the inputs.
* ``"same"``: return a centered portion of the ``"full"`` output which
is the same size as ``in1``.
* ``"valid"``: return the portion of the ``"full"`` output which do not
depend on padding at the array edges.
axes: optional sequence of axes along which to apply the convolution.
Returns:
Array containing the convolved result.
See Also:
- :func:`jax.numpy.convolve`: 1D convolution
- :func:`jax.scipy.signal.convolve`: direct convolution
Examples:
A few 1D convolution examples. Because FFT-based convolution is approximate,
We use :func:`jax.numpy.printoptions` below to adjust the printing precision:
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([1, 1, 1])
Full convolution uses implicit zero-padding at the edges:
>>> with jax.numpy.printoptions(precision=3):
... print(jax.scipy.signal.fftconvolve(x, y, mode='full'))
[1. 3. 6. 7. 6. 3. 1.]
Specifying ``mode = 'same'`` returns a centered convolution the same size
as the first input:
>>> with jax.numpy.printoptions(precision=3):
... print(jax.scipy.signal.fftconvolve(x, y, mode='same'))
[3. 6. 7. 6. 3.]
Specifying ``mode = 'valid'`` returns only the portion where the two arrays
fully overlap:
>>> with jax.numpy.printoptions(precision=3):
... print(jax.scipy.signal.fftconvolve(x, y, mode='valid'))
[6. 7. 6.]
"""
check_arraylike('fftconvolve', in1, in2)
in1, in2 = promote_dtypes_inexact(in1, in2)
if in1.ndim != in2.ndim:
@ -133,9 +184,63 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike)
return result[0, 0]
@implements(osp_signal.convolve)
def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
precision: PrecisionLike = None) -> Array:
"""Convolution of two N-dimensional arrays.
JAX implementation of :func:`jax.scipy.signal.convolve`.
Args:
in1: left-hand input to the convolution.
in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``.
mode: controls the size of the output. Available operations are:
* ``"full"``: (default) output the full convolution of the inputs.
* ``"same"``: return a centered portion of the ``"full"`` output which
is the same size as ``in1``.
* ``"valid"``: return the portion of the ``"full"`` output which do not
depend on padding at the array edges.
method: controls the computation method. Options are
* ``"auto"``: (default) always uses the ``"direct"`` method.
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
* ``"fft"``: compute the result via a fast Fourier transform.
precision: Specify the precision of the computation. Refer to
:class:`jax.lax.Precision` for a description of available values.
Returns:
Array containing the convolved result.
See Also:
- :func:`jax.numpy.convolve`: 1D convolution
- :func:`jax.scipy.signal.convolve2d`: 2D convolution
- :func:`jax.scipy.signal.correlate`: ND correlation
Examples:
A few 1D convolution examples:
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([1, 1, 1])
Full convolution uses implicit zero-padding at the edges:
>>> jax.scipy.signal.convolve(x, y, mode='full')
Array([1., 3., 6., 7., 6., 3., 1.], dtype=float32)
Specifying ``mode = 'same'`` returns a centered convolution the same size
as the first input:
>>> jax.scipy.signal.convolve(x, y, mode='same')
Array([3., 6., 7., 6., 3.], dtype=float32)
Specifying ``mode = 'valid'`` returns only the portion where the two arrays
fully overlap:
>>> jax.scipy.signal.convolve(x, y, mode='valid')
Array([6., 7., 6.], dtype=float32)
"""
if method == 'fft':
return fftconvolve(in1, in2, mode=mode)
elif method in ['direct', 'auto']:
@ -144,9 +249,42 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
raise ValueError(f"Got {method=}; expected 'auto', 'fft', or 'direct'.")
@implements(osp_signal.convolve2d)
def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
"""Convolution of two 2-dimensional arrays.
JAX implementation of :func:`jax.scipy.signal.convolve2d`.
Args:
in1: left-hand input to the convolution. Must have ``in1.ndim == 2``.
in2: right-hand input to the convolution. Must have ``in2.ndim == 2``.
mode: controls the size of the output. Available operations are:
* ``"full"``: (default) output the full convolution of the inputs.
* ``"same"``: return a centered portion of the ``"full"`` output which
is the same size as ``in1``.
* ``"valid"``: return the portion of the ``"full"`` output which do not
depend on padding at the array edges.
boundary: only ``"fill"`` is supported.
fillvalue: only ``0`` is supported.
method: controls the computation method. Options are
* ``"auto"``: (default) always uses the ``"direct"`` method.
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
* ``"fft"``: compute the result via a fast Fourier transform.
precision: Specify the precision of the computation. Refer to
:class:`jax.lax.Precision` for a description of available values.
Returns:
Array containing the convolved result.
See Also:
- :func:`jax.numpy.convolve`: 1D convolution
- :func:`jax.scipy.signal.convolve`: ND convolution
- :func:`jax.scipy.signal.correlate`: ND correlation
"""
if boundary != 'fill' or fillvalue != 0:
raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
@ -154,15 +292,79 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill
return _convolve_nd(in1, in2, mode, precision=precision)
@implements(osp_signal.correlate)
def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
precision: PrecisionLike = None) -> Array:
"""Cross-correlation of two N-dimensional arrays.
JAX implementation of :func:`jax.scipy.signal.correlate`.
Args:
in1: left-hand input to the cross-correlation.
in2: right-hand input to the cross-correlation. Must have ``in1.ndim == in2.ndim``.
mode: controls the size of the output. Available operations are:
* ``"full"``: (default) output the full cross-correlation of the inputs.
* ``"same"``: return a centered portion of the ``"full"`` output which
is the same size as ``in1``.
* ``"valid"``: return the portion of the ``"full"`` output which do not
depend on padding at the array edges.
method: controls the computation method. Options are
* ``"auto"``: (default) always uses the ``"direct"`` method.
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
* ``"fft"``: compute the result via a fast Fourier transform.
precision: Specify the precision of the computation. Refer to
:class:`jax.lax.Precision` for a description of available values.
Returns:
Array containing the cross-correlation result.
See Also:
- :func:`jax.numpy.correlate`: 1D cross-correlation
- :func:`jax.scipy.signal.correlate2d`: 2D cross-correlation
- :func:`jax.scipy.signal.convolve`: ND convolution
"""
return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method)
@implements(osp_signal.correlate2d)
def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
"""Cross-correlation of two 2-dimensional arrays.
JAX implementation of :func:`jax.scipy.signal.correlate2d`.
Args:
in1: left-hand input to the cross-correlation. Must have ``in1.ndim == 2``.
in2: right-hand input to the cross-correlation. Must have ``in2.ndim == 2``.
mode: controls the size of the output. Available operations are:
* ``"full"``: (default) output the full cross-correlation of the inputs.
* ``"same"``: return a centered portion of the ``"full"`` output which
is the same size as ``in1``.
* ``"valid"``: return the portion of the ``"full"`` output which do not
depend on padding at the array edges.
boundary: only ``"fill"`` is supported.
fillvalue: only ``0`` is supported.
method: controls the computation method. Options are
* ``"auto"``: (default) always uses the ``"direct"`` method.
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
* ``"fft"``: compute the result via a fast Fourier transform.
precision: Specify the precision of the computation. Refer to
:class:`jax.lax.Precision` for a description of available values.
Returns:
Array containing the cross-correlation result.
See Also:
- :func:`jax.numpy.correlate`: 1D cross-correlation
- :func:`jax.scipy.signal.correlate`: ND cross-correlation
- :func:`jax.scipy.signal.convolve`: ND convolution
"""
if boundary != 'fill' or fillvalue != 0:
raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0")
if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
@ -191,9 +393,51 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil
return result
@implements(osp_signal.detrend)
def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0,
overwrite_data: None = None) -> Array:
"""
Remove linear or piecewise linear trends from data.
JAX implementation of :func:`scipy.signal.detrend`.
Args:
data: The input array containing the data to detrend.
axis: The axis along which to detrend. Default is -1 (the last axis).
type: The type of detrending. Can be:
* ``'linear'``: Fit a single linear trend for the entire data.
* ``'constant'``: Remove the mean value of the data.
bp: A sequence of breakpoints. If given, piecewise linear trends
are fit between these breakpoints.
overwrite_data: This argument is not supported by JAX's implementation.
Returns:
The detrended data array.
Example:
A simple detrend operation in one dimension:
>>> data = jnp.array([1., 4., 8., 8., 9.])
Removing a linear trend from the data:
>>> detrended = jax.scipy.signal.detrend(data)
>>> with jnp.printoptions(precision=3, suppress=True): # suppress float error
... print("Detrended:", detrended)
... print("Underlying trend:", data - detrended)
Detrended: [-1. -0. 2. -0. -1.]
Underlying trend: [ 2. 4. 6. 8. 10.]
Removing a constant trend from the data:
>>> detrended = jax.scipy.signal.detrend(data, type='constant')
>>> with jnp.printoptions(precision=3): # suppress float error
... print("Detrended:", detrended)
... print("Underlying trend:", data - detrended)
Detrended: [-5. -2. 2. 2. 3.]
Underlying trend: [6. 6. 6. 6. 6.]
"""
if overwrite_data is not None:
raise NotImplementedError("overwrite_data argument not implemented.")
if type not in ['constant', 'linear']:
@ -499,11 +743,44 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
return freqs, time, result
@implements(osp_signal.stft)
def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256,
noverlap: int | None = None, nfft: int | None = None,
detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros',
padded: bool = True, axis: int = -1) -> tuple[Array, Array, Array]:
"""
Compute the short-time Fourier transform (STFT).
JAX implementation of :func:`scipy.signal.stft`.
Args:
x: Array representing a time series of input values.
fs: Sampling frequency of the time series (default: 1.0).
window: Data tapering window to apply to each segment. Can be a window function name,
a tuple specifying a window length and function, or an array (default: ``'hann'``).
nperseg: Length of each segment (default: 256).
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default),
the FFT length is ``nperseg``.
detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending),
``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable
accepting a segment and returning a detrended segment.
return_onesided: If True (default), return a one-sided spectrum for real inputs.
If False, return a two-sided spectrum.
boundary: Specifies whether the input signal is extended at both ends, and how.
Options are ``None`` (no extension), ``'zeros'`` (default), ``'even'``, ``'odd'``,
or ``'constant'``.
padded: Specifies whether the input signal is zero-padded at the end to make its
length a multiple of `nperseg`. If True (default), the padded signal length is
the next multiple of ``nperseg``.
axis: Axis along which the STFT is computed; the default is over the last axis (-1).
Returns:
A length-3 tuple of arrays ``(f, t, Zxx)``. ``f`` is the Array of sample frequencies.
``t`` is the Array of segment times, and ``Zxx`` is the STFT of ``x``.
See Also:
:func:`jax.scipy.signal.istft`: inverse short-time Fourier transform.
"""
return _spectral_helper(x, None, fs, window, nperseg, noverlap,
nfft, detrend, return_onesided,
scaling='spectrum', axis=axis,
@ -511,19 +788,56 @@ def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256
padded=padded)
_csd_description = """
The original SciPy function exhibits slightly different behavior between
``csd(x, x)``` and ```csd(x, x.copy())```. The LAX-backend version is designed
to follow the latter behavior. For using the former behavior, call this
function as `csd(x, None)`."""
@implements(osp_signal.csd, lax_description=_csd_description)
def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: int | None = None, noverlap: int | None = None,
nfft: int | None = None, detrend: str = 'constant',
return_onesided: bool = True, scaling: str = 'density',
axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
"""
Estimate cross power spectral density (CSD) using Welch's method.
This is a JAX implementation of :func:`scipy.signal.csd`. It is similar to
:func:`jax.scipy.signal.welch`, but it operates on two input signals and
estimates their cross-spectral density instead of the power spectral density
(PSD).
Args:
x: Array representing a time series of input values.
y: Array representing the second time series of input values, the same length as ``x``
along the specified ``axis``. If not specified, then assume ``y = x`` and compute
the PSD ``Pxx`` of ``x`` via Welch's method.
fs: Sampling frequency of the inputs (default: 1.0).
window: Data tapering window to apply to each segment. Can be a window function name,
a tuple specifying a window length and function, or an array (default: ``'hann'``).
nperseg: Length of each segment (default: 256).
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default),
the FFT length is ``nperseg``.
detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending),
``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable
accepting a segment and returning a detrended segment.
return_onesided: If True (default), return a one-sided spectrum for real inputs.
If False, return a two-sided spectrum.
scaling: Selects between computing the power spectral density (``'density'``, default)
or the power spectrum (``'spectrum'``)
axis: Axis along which the CSD is computed (default: -1).
average: The type of averaging to use on the periodograms; one of ``'mean'`` (default)
or ``'median'``.
Returns:
A length-2 tuple of arrays ``(f, Pxy)``. ``f`` is the array of sample frequencies,
and ``Pxy`` is the cross spectral density of `x` and `y`
Notes:
The original SciPy function exhibits slightly different behavior between
``csd(x, x)`` and ``csd(x, x.copy())``. The LAX-backend version is designed
to follow the latter behavior. To replicate the former, call this function
function as ``csd(x, None)``.
See Also:
- :func:`jax.scipy.signal.welch`: Power spectral density.
- :func:`jax.scipy.signal.stft`: Short-time Fourier transform.
"""
freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
detrend, return_onesided, scaling, axis,
mode='psd')
@ -551,12 +865,46 @@ def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann'
return freqs, Pxy
@implements(osp_signal.welch)
def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: int | None = None, noverlap: int | None = None,
nfft: int | None = None, detrend: str = 'constant',
return_onesided: bool = True, scaling: str = 'density',
axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
"""
Estimate power spectral density (PSD) using Welch's method.
This is a JAX implementation of :func:`scipy.signal.welch`. It divides the
input signal into overlapping segments, computes the modified periodogram for
each segment, and averages the results to obtain a smoother estimate of the PSD.
Args:
x: Array representing a time series of input values.
fs: Sampling frequency of the inputs (default: 1.0).
window: Data tapering window to apply to each segment. Can be a window function name,
a tuple specifying a window length and function, or an array (default: ``'hann'``).
nperseg: Length of each segment (default: 256).
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default),
the FFT length is ``nperseg``.
detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending),
``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable
accepting a segment and returning a detrended segment.
return_onesided: If True (default), return a one-sided spectrum for real inputs.
If False, return a two-sided spectrum.
scaling: Selects between computing the power spectral density (``'density'``, default)
or the power spectrum (``'spectrum'``)
axis: Axis along which the PSD is computed (default: -1).
average: The type of averaging to use on the periodograms; one of ``'mean'`` (default)
or ``'median'``.
Returns:
A length-2 tuple of arrays ``(f, Pxx)``. ``f`` is the array of sample frequencies,
and ``Pxx`` is the power spectral density of ``x``.
See Also:
- :func:`jax.scipy.signal.csd`: Cross power spectral density.
- :func:`jax.scipy.signal.stft`: Short-time Fourier transform.
"""
freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg,
noverlap=noverlap, nfft=nfft, detrend=detrend,
return_onesided=return_onesided, scaling=scaling,
@ -613,12 +961,54 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
return x.reshape(tuple(batch_shape) + (-1,))
@implements(osp_signal.istft)
def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
nperseg: int | None = None, noverlap: int | None = None,
nfft: int | None = None, input_onesided: bool = True,
boundary: bool = True, time_axis: int = -1,
freq_axis: int = -2) -> tuple[Array, Array]:
"""
Perform the inverse short-time Fourier transform (ISTFT).
JAX implementation of :func:`scipy.signal.istft`; computes the inverse of
:func:`jax.scipy.signal.stft`.
Args:
Zxx: STFT of the signal to be reconstructed.
fs: Sampling frequency of the time series (default: 1.0)
window: Data tapering window to apply to each segment. Can be a window function name,
a tuple specifying a window length and function, or an array (default: ``'hann'``).
nperseg: Number of data points per segment in the STFT. If ``None`` (default), the
value is determined from the size of ``Zxx``.
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
nfft: Number of FFT points used in the STFT. If ``None`` (default), the
value is determined from the size of ``Zxx``.
input_onesided: If Tru` (default), interpret the input as a one-sided STFT
(positive frequencies only). If False, interpret the input as a two-sided STFT.
boundary: If True (default), it is assumed that the input signal was extended at
its boundaries by ``stft``. If `False`, the input signal is assumed to have been truncated at the boundaries by `stft`.
time_axis: Axis in `Zxx` corresponding to time segments (default: -1).
freq_axis: Axis in `Zxx` corresponding to frequency bins (default: -2).
Returns:
A length-2 tuple of arrays ``(t, x)``. ``t`` is the Array of signal times, and ``x``
is the reconstructed time series.
See Also:
:func:`jax.scipy.signal.stft`: short-time Fourier transform.
Example:
Demonstrate that this gives the inverse of :func:`~jax.scipy.signal.stft`:
>>> x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.])
>>> f, t, Zxx = jax.scipy.signal.stft(x, nperseg=4)
>>> print(Zxx)
[[ 1. +0.j 2.5+0.j 1. +0.j 1. +0.j 0.5+0.j ]
[-0.5+0.5j -1.5+0.j -0.5-0.5j -0.5+0.5j 0. -0.5j]
[ 0. +0.j 0.5+0.j 0. +0.j 0. +0.j -0.5+0.j ]]
>>> t, x_reconstructed = jax.scipy.signal.istft(Zxx)
>>> print(x_reconstructed)
[1. 2. 3. 2. 1. 0. 1. 2.]
"""
# Input validation
check_arraylike("istft", Zxx)
if Zxx.ndim < 2: