mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: Improve docstrings for jax.scipy.signal
This commit is contained in:
parent
e691c19bb2
commit
18703d9385
@ -93,6 +93,7 @@ jax.scipy.signal
|
||||
correlate
|
||||
correlate2d
|
||||
csd
|
||||
detrend
|
||||
istft
|
||||
stft
|
||||
welch
|
||||
|
@ -34,15 +34,66 @@ from jax._src import dtypes
|
||||
from jax._src.lax.lax import PrecisionLike
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, implements, promote_dtypes_inexact, promote_dtypes_complex)
|
||||
check_arraylike, promote_dtypes_inexact, promote_dtypes_complex)
|
||||
from jax._src.third_party.scipy import signal_helper
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert
|
||||
|
||||
|
||||
@implements(osp_signal.fftconvolve)
|
||||
def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full",
|
||||
axes: Sequence[int] | None = None) -> Array:
|
||||
"""
|
||||
Convolve two N-dimensional arrays using Fast Fourier Transform (FFT).
|
||||
|
||||
JAX implementation of :func:`scipy.signal.fftconvolve`.
|
||||
|
||||
Args:
|
||||
in1: left-hand input to the convolution.
|
||||
in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``.
|
||||
mode: controls the size of the output. Available operations are:
|
||||
|
||||
* ``"full"``: (default) output the full convolution of the inputs.
|
||||
* ``"same"``: return a centered portion of the ``"full"`` output which
|
||||
is the same size as ``in1``.
|
||||
* ``"valid"``: return the portion of the ``"full"`` output which do not
|
||||
depend on padding at the array edges.
|
||||
|
||||
axes: optional sequence of axes along which to apply the convolution.
|
||||
|
||||
Returns:
|
||||
Array containing the convolved result.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.convolve`: 1D convolution
|
||||
- :func:`jax.scipy.signal.convolve`: direct convolution
|
||||
|
||||
Examples:
|
||||
A few 1D convolution examples. Because FFT-based convolution is approximate,
|
||||
We use :func:`jax.numpy.printoptions` below to adjust the printing precision:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 2, 1])
|
||||
>>> y = jnp.array([1, 1, 1])
|
||||
|
||||
Full convolution uses implicit zero-padding at the edges:
|
||||
|
||||
>>> with jax.numpy.printoptions(precision=3):
|
||||
... print(jax.scipy.signal.fftconvolve(x, y, mode='full'))
|
||||
[1. 3. 6. 7. 6. 3. 1.]
|
||||
|
||||
Specifying ``mode = 'same'`` returns a centered convolution the same size
|
||||
as the first input:
|
||||
|
||||
>>> with jax.numpy.printoptions(precision=3):
|
||||
... print(jax.scipy.signal.fftconvolve(x, y, mode='same'))
|
||||
[3. 6. 7. 6. 3.]
|
||||
|
||||
Specifying ``mode = 'valid'`` returns only the portion where the two arrays
|
||||
fully overlap:
|
||||
|
||||
>>> with jax.numpy.printoptions(precision=3):
|
||||
... print(jax.scipy.signal.fftconvolve(x, y, mode='valid'))
|
||||
[6. 7. 6.]
|
||||
"""
|
||||
check_arraylike('fftconvolve', in1, in2)
|
||||
in1, in2 = promote_dtypes_inexact(in1, in2)
|
||||
if in1.ndim != in2.ndim:
|
||||
@ -133,9 +184,63 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike)
|
||||
return result[0, 0]
|
||||
|
||||
|
||||
@implements(osp_signal.convolve)
|
||||
def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
|
||||
precision: PrecisionLike = None) -> Array:
|
||||
"""Convolution of two N-dimensional arrays.
|
||||
|
||||
JAX implementation of :func:`jax.scipy.signal.convolve`.
|
||||
|
||||
Args:
|
||||
in1: left-hand input to the convolution.
|
||||
in2: right-hand input to the convolution. Must have ``in1.ndim == in2.ndim``.
|
||||
mode: controls the size of the output. Available operations are:
|
||||
|
||||
* ``"full"``: (default) output the full convolution of the inputs.
|
||||
* ``"same"``: return a centered portion of the ``"full"`` output which
|
||||
is the same size as ``in1``.
|
||||
* ``"valid"``: return the portion of the ``"full"`` output which do not
|
||||
depend on padding at the array edges.
|
||||
|
||||
method: controls the computation method. Options are
|
||||
|
||||
* ``"auto"``: (default) always uses the ``"direct"`` method.
|
||||
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
|
||||
* ``"fft"``: compute the result via a fast Fourier transform.
|
||||
|
||||
precision: Specify the precision of the computation. Refer to
|
||||
:class:`jax.lax.Precision` for a description of available values.
|
||||
|
||||
Returns:
|
||||
Array containing the convolved result.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.convolve`: 1D convolution
|
||||
- :func:`jax.scipy.signal.convolve2d`: 2D convolution
|
||||
- :func:`jax.scipy.signal.correlate`: ND correlation
|
||||
|
||||
Examples:
|
||||
A few 1D convolution examples:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3, 2, 1])
|
||||
>>> y = jnp.array([1, 1, 1])
|
||||
|
||||
Full convolution uses implicit zero-padding at the edges:
|
||||
|
||||
>>> jax.scipy.signal.convolve(x, y, mode='full')
|
||||
Array([1., 3., 6., 7., 6., 3., 1.], dtype=float32)
|
||||
|
||||
Specifying ``mode = 'same'`` returns a centered convolution the same size
|
||||
as the first input:
|
||||
|
||||
>>> jax.scipy.signal.convolve(x, y, mode='same')
|
||||
Array([3., 6., 7., 6., 3.], dtype=float32)
|
||||
|
||||
Specifying ``mode = 'valid'`` returns only the portion where the two arrays
|
||||
fully overlap:
|
||||
|
||||
>>> jax.scipy.signal.convolve(x, y, mode='valid')
|
||||
Array([6., 7., 6.], dtype=float32)
|
||||
"""
|
||||
if method == 'fft':
|
||||
return fftconvolve(in1, in2, mode=mode)
|
||||
elif method in ['direct', 'auto']:
|
||||
@ -144,9 +249,42 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
|
||||
raise ValueError(f"Got {method=}; expected 'auto', 'fft', or 'direct'.")
|
||||
|
||||
|
||||
@implements(osp_signal.convolve2d)
|
||||
def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
|
||||
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
|
||||
"""Convolution of two 2-dimensional arrays.
|
||||
|
||||
JAX implementation of :func:`jax.scipy.signal.convolve2d`.
|
||||
|
||||
Args:
|
||||
in1: left-hand input to the convolution. Must have ``in1.ndim == 2``.
|
||||
in2: right-hand input to the convolution. Must have ``in2.ndim == 2``.
|
||||
mode: controls the size of the output. Available operations are:
|
||||
|
||||
* ``"full"``: (default) output the full convolution of the inputs.
|
||||
* ``"same"``: return a centered portion of the ``"full"`` output which
|
||||
is the same size as ``in1``.
|
||||
* ``"valid"``: return the portion of the ``"full"`` output which do not
|
||||
depend on padding at the array edges.
|
||||
|
||||
boundary: only ``"fill"`` is supported.
|
||||
fillvalue: only ``0`` is supported.
|
||||
method: controls the computation method. Options are
|
||||
|
||||
* ``"auto"``: (default) always uses the ``"direct"`` method.
|
||||
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
|
||||
* ``"fft"``: compute the result via a fast Fourier transform.
|
||||
|
||||
precision: Specify the precision of the computation. Refer to
|
||||
:class:`jax.lax.Precision` for a description of available values.
|
||||
|
||||
Returns:
|
||||
Array containing the convolved result.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.convolve`: 1D convolution
|
||||
- :func:`jax.scipy.signal.convolve`: ND convolution
|
||||
- :func:`jax.scipy.signal.correlate`: ND correlation
|
||||
"""
|
||||
if boundary != 'fill' or fillvalue != 0:
|
||||
raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
|
||||
if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
|
||||
@ -154,15 +292,79 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill
|
||||
return _convolve_nd(in1, in2, mode, precision=precision)
|
||||
|
||||
|
||||
@implements(osp_signal.correlate)
|
||||
def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto',
|
||||
precision: PrecisionLike = None) -> Array:
|
||||
"""Cross-correlation of two N-dimensional arrays.
|
||||
|
||||
JAX implementation of :func:`jax.scipy.signal.correlate`.
|
||||
|
||||
Args:
|
||||
in1: left-hand input to the cross-correlation.
|
||||
in2: right-hand input to the cross-correlation. Must have ``in1.ndim == in2.ndim``.
|
||||
mode: controls the size of the output. Available operations are:
|
||||
|
||||
* ``"full"``: (default) output the full cross-correlation of the inputs.
|
||||
* ``"same"``: return a centered portion of the ``"full"`` output which
|
||||
is the same size as ``in1``.
|
||||
* ``"valid"``: return the portion of the ``"full"`` output which do not
|
||||
depend on padding at the array edges.
|
||||
|
||||
method: controls the computation method. Options are
|
||||
|
||||
* ``"auto"``: (default) always uses the ``"direct"`` method.
|
||||
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
|
||||
* ``"fft"``: compute the result via a fast Fourier transform.
|
||||
|
||||
precision: Specify the precision of the computation. Refer to
|
||||
:class:`jax.lax.Precision` for a description of available values.
|
||||
|
||||
Returns:
|
||||
Array containing the cross-correlation result.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.correlate`: 1D cross-correlation
|
||||
- :func:`jax.scipy.signal.correlate2d`: 2D cross-correlation
|
||||
- :func:`jax.scipy.signal.convolve`: ND convolution
|
||||
"""
|
||||
return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method)
|
||||
|
||||
|
||||
@implements(osp_signal.correlate2d)
|
||||
def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill',
|
||||
fillvalue: float = 0, precision: PrecisionLike = None) -> Array:
|
||||
"""Cross-correlation of two 2-dimensional arrays.
|
||||
|
||||
JAX implementation of :func:`jax.scipy.signal.correlate2d`.
|
||||
|
||||
Args:
|
||||
in1: left-hand input to the cross-correlation. Must have ``in1.ndim == 2``.
|
||||
in2: right-hand input to the cross-correlation. Must have ``in2.ndim == 2``.
|
||||
mode: controls the size of the output. Available operations are:
|
||||
|
||||
* ``"full"``: (default) output the full cross-correlation of the inputs.
|
||||
* ``"same"``: return a centered portion of the ``"full"`` output which
|
||||
is the same size as ``in1``.
|
||||
* ``"valid"``: return the portion of the ``"full"`` output which do not
|
||||
depend on padding at the array edges.
|
||||
|
||||
boundary: only ``"fill"`` is supported.
|
||||
fillvalue: only ``0`` is supported.
|
||||
method: controls the computation method. Options are
|
||||
|
||||
* ``"auto"``: (default) always uses the ``"direct"`` method.
|
||||
* ``"direct"``: lower to :func:`jax.lax.conv_general_dilated`.
|
||||
* ``"fft"``: compute the result via a fast Fourier transform.
|
||||
|
||||
precision: Specify the precision of the computation. Refer to
|
||||
:class:`jax.lax.Precision` for a description of available values.
|
||||
|
||||
Returns:
|
||||
Array containing the cross-correlation result.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.correlate`: 1D cross-correlation
|
||||
- :func:`jax.scipy.signal.correlate`: ND cross-correlation
|
||||
- :func:`jax.scipy.signal.convolve`: ND convolution
|
||||
"""
|
||||
if boundary != 'fill' or fillvalue != 0:
|
||||
raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0")
|
||||
if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
|
||||
@ -191,9 +393,51 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil
|
||||
return result
|
||||
|
||||
|
||||
@implements(osp_signal.detrend)
|
||||
def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0,
|
||||
overwrite_data: None = None) -> Array:
|
||||
"""
|
||||
Remove linear or piecewise linear trends from data.
|
||||
|
||||
JAX implementation of :func:`scipy.signal.detrend`.
|
||||
|
||||
Args:
|
||||
data: The input array containing the data to detrend.
|
||||
axis: The axis along which to detrend. Default is -1 (the last axis).
|
||||
type: The type of detrending. Can be:
|
||||
|
||||
* ``'linear'``: Fit a single linear trend for the entire data.
|
||||
* ``'constant'``: Remove the mean value of the data.
|
||||
|
||||
bp: A sequence of breakpoints. If given, piecewise linear trends
|
||||
are fit between these breakpoints.
|
||||
overwrite_data: This argument is not supported by JAX's implementation.
|
||||
|
||||
Returns:
|
||||
The detrended data array.
|
||||
|
||||
Example:
|
||||
A simple detrend operation in one dimension:
|
||||
|
||||
>>> data = jnp.array([1., 4., 8., 8., 9.])
|
||||
|
||||
Removing a linear trend from the data:
|
||||
|
||||
>>> detrended = jax.scipy.signal.detrend(data)
|
||||
>>> with jnp.printoptions(precision=3, suppress=True): # suppress float error
|
||||
... print("Detrended:", detrended)
|
||||
... print("Underlying trend:", data - detrended)
|
||||
Detrended: [-1. -0. 2. -0. -1.]
|
||||
Underlying trend: [ 2. 4. 6. 8. 10.]
|
||||
|
||||
Removing a constant trend from the data:
|
||||
|
||||
>>> detrended = jax.scipy.signal.detrend(data, type='constant')
|
||||
>>> with jnp.printoptions(precision=3): # suppress float error
|
||||
... print("Detrended:", detrended)
|
||||
... print("Underlying trend:", data - detrended)
|
||||
Detrended: [-5. -2. 2. 2. 3.]
|
||||
Underlying trend: [6. 6. 6. 6. 6.]
|
||||
"""
|
||||
if overwrite_data is not None:
|
||||
raise NotImplementedError("overwrite_data argument not implemented.")
|
||||
if type not in ['constant', 'linear']:
|
||||
@ -499,11 +743,44 @@ def _spectral_helper(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0,
|
||||
return freqs, time, result
|
||||
|
||||
|
||||
@implements(osp_signal.stft)
|
||||
def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256,
|
||||
noverlap: int | None = None, nfft: int | None = None,
|
||||
detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros',
|
||||
padded: bool = True, axis: int = -1) -> tuple[Array, Array, Array]:
|
||||
"""
|
||||
Compute the short-time Fourier transform (STFT).
|
||||
|
||||
JAX implementation of :func:`scipy.signal.stft`.
|
||||
|
||||
Args:
|
||||
x: Array representing a time series of input values.
|
||||
fs: Sampling frequency of the time series (default: 1.0).
|
||||
window: Data tapering window to apply to each segment. Can be a window function name,
|
||||
a tuple specifying a window length and function, or an array (default: ``'hann'``).
|
||||
nperseg: Length of each segment (default: 256).
|
||||
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
|
||||
nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default),
|
||||
the FFT length is ``nperseg``.
|
||||
detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending),
|
||||
``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable
|
||||
accepting a segment and returning a detrended segment.
|
||||
return_onesided: If True (default), return a one-sided spectrum for real inputs.
|
||||
If False, return a two-sided spectrum.
|
||||
boundary: Specifies whether the input signal is extended at both ends, and how.
|
||||
Options are ``None`` (no extension), ``'zeros'`` (default), ``'even'``, ``'odd'``,
|
||||
or ``'constant'``.
|
||||
padded: Specifies whether the input signal is zero-padded at the end to make its
|
||||
length a multiple of `nperseg`. If True (default), the padded signal length is
|
||||
the next multiple of ``nperseg``.
|
||||
axis: Axis along which the STFT is computed; the default is over the last axis (-1).
|
||||
|
||||
Returns:
|
||||
A length-3 tuple of arrays ``(f, t, Zxx)``. ``f`` is the Array of sample frequencies.
|
||||
``t`` is the Array of segment times, and ``Zxx`` is the STFT of ``x``.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.signal.istft`: inverse short-time Fourier transform.
|
||||
"""
|
||||
return _spectral_helper(x, None, fs, window, nperseg, noverlap,
|
||||
nfft, detrend, return_onesided,
|
||||
scaling='spectrum', axis=axis,
|
||||
@ -511,19 +788,56 @@ def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256
|
||||
padded=padded)
|
||||
|
||||
|
||||
_csd_description = """
|
||||
The original SciPy function exhibits slightly different behavior between
|
||||
``csd(x, x)``` and ```csd(x, x.copy())```. The LAX-backend version is designed
|
||||
to follow the latter behavior. For using the former behavior, call this
|
||||
function as `csd(x, None)`."""
|
||||
|
||||
|
||||
@implements(osp_signal.csd, lax_description=_csd_description)
|
||||
def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
nperseg: int | None = None, noverlap: int | None = None,
|
||||
nfft: int | None = None, detrend: str = 'constant',
|
||||
return_onesided: bool = True, scaling: str = 'density',
|
||||
axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
|
||||
"""
|
||||
Estimate cross power spectral density (CSD) using Welch's method.
|
||||
|
||||
This is a JAX implementation of :func:`scipy.signal.csd`. It is similar to
|
||||
:func:`jax.scipy.signal.welch`, but it operates on two input signals and
|
||||
estimates their cross-spectral density instead of the power spectral density
|
||||
(PSD).
|
||||
|
||||
Args:
|
||||
x: Array representing a time series of input values.
|
||||
y: Array representing the second time series of input values, the same length as ``x``
|
||||
along the specified ``axis``. If not specified, then assume ``y = x`` and compute
|
||||
the PSD ``Pxx`` of ``x`` via Welch's method.
|
||||
fs: Sampling frequency of the inputs (default: 1.0).
|
||||
window: Data tapering window to apply to each segment. Can be a window function name,
|
||||
a tuple specifying a window length and function, or an array (default: ``'hann'``).
|
||||
nperseg: Length of each segment (default: 256).
|
||||
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
|
||||
nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default),
|
||||
the FFT length is ``nperseg``.
|
||||
detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending),
|
||||
``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable
|
||||
accepting a segment and returning a detrended segment.
|
||||
return_onesided: If True (default), return a one-sided spectrum for real inputs.
|
||||
If False, return a two-sided spectrum.
|
||||
scaling: Selects between computing the power spectral density (``'density'``, default)
|
||||
or the power spectrum (``'spectrum'``)
|
||||
axis: Axis along which the CSD is computed (default: -1).
|
||||
average: The type of averaging to use on the periodograms; one of ``'mean'`` (default)
|
||||
or ``'median'``.
|
||||
|
||||
Returns:
|
||||
A length-2 tuple of arrays ``(f, Pxy)``. ``f`` is the array of sample frequencies,
|
||||
and ``Pxy`` is the cross spectral density of `x` and `y`
|
||||
|
||||
Notes:
|
||||
The original SciPy function exhibits slightly different behavior between
|
||||
``csd(x, x)`` and ``csd(x, x.copy())``. The LAX-backend version is designed
|
||||
to follow the latter behavior. To replicate the former, call this function
|
||||
function as ``csd(x, None)``.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.signal.welch`: Power spectral density.
|
||||
- :func:`jax.scipy.signal.stft`: Short-time Fourier transform.
|
||||
"""
|
||||
freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft,
|
||||
detrend, return_onesided, scaling, axis,
|
||||
mode='psd')
|
||||
@ -551,12 +865,46 @@ def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann'
|
||||
return freqs, Pxy
|
||||
|
||||
|
||||
@implements(osp_signal.welch)
|
||||
def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
nperseg: int | None = None, noverlap: int | None = None,
|
||||
nfft: int | None = None, detrend: str = 'constant',
|
||||
return_onesided: bool = True, scaling: str = 'density',
|
||||
axis: int = -1, average: str = 'mean') -> tuple[Array, Array]:
|
||||
"""
|
||||
Estimate power spectral density (PSD) using Welch's method.
|
||||
|
||||
This is a JAX implementation of :func:`scipy.signal.welch`. It divides the
|
||||
input signal into overlapping segments, computes the modified periodogram for
|
||||
each segment, and averages the results to obtain a smoother estimate of the PSD.
|
||||
|
||||
Args:
|
||||
x: Array representing a time series of input values.
|
||||
fs: Sampling frequency of the inputs (default: 1.0).
|
||||
window: Data tapering window to apply to each segment. Can be a window function name,
|
||||
a tuple specifying a window length and function, or an array (default: ``'hann'``).
|
||||
nperseg: Length of each segment (default: 256).
|
||||
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
|
||||
nfft: Length of the FFT used, if a zero-padded FFT is desired. If ``None`` (default),
|
||||
the FFT length is ``nperseg``.
|
||||
detrend: Specifies how to detrend each segment. Can be ``False`` (default: no detrending),
|
||||
``'constant'`` (remove mean), ``'linear'`` (remove linear trend), or a callable
|
||||
accepting a segment and returning a detrended segment.
|
||||
return_onesided: If True (default), return a one-sided spectrum for real inputs.
|
||||
If False, return a two-sided spectrum.
|
||||
scaling: Selects between computing the power spectral density (``'density'``, default)
|
||||
or the power spectrum (``'spectrum'``)
|
||||
axis: Axis along which the PSD is computed (default: -1).
|
||||
average: The type of averaging to use on the periodograms; one of ``'mean'`` (default)
|
||||
or ``'median'``.
|
||||
|
||||
Returns:
|
||||
A length-2 tuple of arrays ``(f, Pxx)``. ``f`` is the array of sample frequencies,
|
||||
and ``Pxx`` is the power spectral density of ``x``.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.signal.csd`: Cross power spectral density.
|
||||
- :func:`jax.scipy.signal.stft`: Short-time Fourier transform.
|
||||
"""
|
||||
freqs, Pxx = csd(x, None, fs=fs, window=window, nperseg=nperseg,
|
||||
noverlap=noverlap, nfft=nfft, detrend=detrend,
|
||||
return_onesided=return_onesided, scaling=scaling,
|
||||
@ -613,12 +961,54 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
|
||||
return x.reshape(tuple(batch_shape) + (-1,))
|
||||
|
||||
|
||||
@implements(osp_signal.istft)
|
||||
def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
|
||||
nperseg: int | None = None, noverlap: int | None = None,
|
||||
nfft: int | None = None, input_onesided: bool = True,
|
||||
boundary: bool = True, time_axis: int = -1,
|
||||
freq_axis: int = -2) -> tuple[Array, Array]:
|
||||
"""
|
||||
Perform the inverse short-time Fourier transform (ISTFT).
|
||||
|
||||
JAX implementation of :func:`scipy.signal.istft`; computes the inverse of
|
||||
:func:`jax.scipy.signal.stft`.
|
||||
|
||||
Args:
|
||||
Zxx: STFT of the signal to be reconstructed.
|
||||
fs: Sampling frequency of the time series (default: 1.0)
|
||||
window: Data tapering window to apply to each segment. Can be a window function name,
|
||||
a tuple specifying a window length and function, or an array (default: ``'hann'``).
|
||||
nperseg: Number of data points per segment in the STFT. If ``None`` (default), the
|
||||
value is determined from the size of ``Zxx``.
|
||||
noverlap: Number of points to overlap between segments (default: ``nperseg // 2``).
|
||||
nfft: Number of FFT points used in the STFT. If ``None`` (default), the
|
||||
value is determined from the size of ``Zxx``.
|
||||
input_onesided: If Tru` (default), interpret the input as a one-sided STFT
|
||||
(positive frequencies only). If False, interpret the input as a two-sided STFT.
|
||||
boundary: If True (default), it is assumed that the input signal was extended at
|
||||
its boundaries by ``stft``. If `False`, the input signal is assumed to have been truncated at the boundaries by `stft`.
|
||||
time_axis: Axis in `Zxx` corresponding to time segments (default: -1).
|
||||
freq_axis: Axis in `Zxx` corresponding to frequency bins (default: -2).
|
||||
|
||||
Returns:
|
||||
A length-2 tuple of arrays ``(t, x)``. ``t`` is the Array of signal times, and ``x``
|
||||
is the reconstructed time series.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.signal.stft`: short-time Fourier transform.
|
||||
|
||||
Example:
|
||||
Demonstrate that this gives the inverse of :func:`~jax.scipy.signal.stft`:
|
||||
|
||||
>>> x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.])
|
||||
>>> f, t, Zxx = jax.scipy.signal.stft(x, nperseg=4)
|
||||
>>> print(Zxx)
|
||||
[[ 1. +0.j 2.5+0.j 1. +0.j 1. +0.j 0.5+0.j ]
|
||||
[-0.5+0.5j -1.5+0.j -0.5-0.5j -0.5+0.5j 0. -0.5j]
|
||||
[ 0. +0.j 0.5+0.j 0. +0.j 0. +0.j -0.5+0.j ]]
|
||||
>>> t, x_reconstructed = jax.scipy.signal.istft(Zxx)
|
||||
>>> print(x_reconstructed)
|
||||
[1. 2. 3. 2. 1. 0. 1. 2.]
|
||||
"""
|
||||
# Input validation
|
||||
check_arraylike("istft", Zxx)
|
||||
if Zxx.ndim < 2:
|
||||
|
Loading…
x
Reference in New Issue
Block a user