mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #10058 from yotarok:istft
PiperOrigin-RevId: 438832534
This commit is contained in:
commit
e766b96063
@ -79,6 +79,7 @@ jax.scipy.signal
|
||||
correlate
|
||||
correlate2d
|
||||
csd
|
||||
istft
|
||||
stft
|
||||
welch
|
||||
|
||||
|
@ -492,3 +492,138 @@ def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
|
||||
axis=axis, average=average)
|
||||
|
||||
return freqs, Pxx.real
|
||||
|
||||
|
||||
def _overlap_and_add(x, step_size):
  """Overlap-add framed data, in the manner of tf.signal.overlap_and_add.

  Args:
    x: An array with `(..., frames, frame_length)`-shape.
    step_size: An integer giving the offset between the starts of successive
      frames. Must be less than `frame_length`.

  Returns:
    An array with `(..., output_size)`-shape holding the summed signal, where
    `output_size = step_size * (frames - 1) + frame_length`.
  """
  _check_arraylike("_overlap_and_add", x)
  step_size = jax.core.concrete_or_error(int, step_size,
                                        "step_size for overlap_and_add")
  if x.ndim < 2:
    raise ValueError('Input must have (..., frames, frame_length) shape.')

  lead_shape = tuple(x.shape[:-2])
  num_frames, frame_len = x.shape[-2:]

  # Axis shorthand used below:
  #   B: flattened batch size, N: number of frames,
  #   S: number of hops needed to span one frame, T: hop length (step_size).
  batch = np.prod(lead_shape, dtype=np.int64)
  hops_per_frame = 1 + (frame_len - 1) // step_size
  padded_frame_len = hops_per_frame * step_size
  total_len = step_size * (num_frames - 1) + frame_len

  # Zero-pad every frame up to a whole number of hops, then cut it into hops.
  frames = x.reshape((batch, num_frames, frame_len))
  frames = jnp.pad(
      frames, ((0, 0), (0, 0), (0, padded_frame_len - frame_len)))
  frames = frames.reshape((batch, num_frames, hops_per_frame, step_size))

  # Shifting trick (see `overlap_and_add` in TensorFlow for details): append
  # N rows of zeros, flatten, and re-read the buffer with rows that are one
  # hop shorter. The leftover elements of each row spill into the start of
  # the next one, so row s ends up delayed by s hops.
  frames = frames.transpose((0, 2, 1, 3))  # (B, S, N, T)
  frames = jnp.pad(
      frames, ((0, 0), (0, 0), (0, num_frames), (0, 0)))  # (B, S, 2N, T)
  shrunk_rows = 2 * num_frames - 1
  flat = frames.reshape((batch, -1))
  flat = flat[:, :(hops_per_frame * shrunk_rows * step_size)]
  shifted = flat.reshape((batch, hops_per_frame, shrunk_rows * step_size))

  # Add the shifted copies together and drop the samples past the signal end.
  summed = shifted.sum(axis=1)[:, :total_len]
  return summed.reshape(lead_shape + (-1,))
|
||||
|
||||
|
||||
@_wraps(osp_signal.istft)
def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2):
  # Inverse short-time Fourier transform.
  #
  # Steps: validate arguments, move the (freq, time) axes to the end, inverse
  # FFT every frame, apply the synthesis window, overlap-add the frames, and
  # divide by the overlap-added squared window.
  #
  # Returns a `(time, x)` pair: the sample times `arange(...) / fs` and the
  # reconstructed signal.
  #
  # Raises ValueError for inputs with fewer than two dimensions, identical
  # time/frequency axes, non-positive `nperseg`, `nfft < nperseg`,
  # `noverlap >= nperseg`, or a window array that is not 1-D of length
  # `nperseg`.

  # Input validation
  _check_arraylike("istft", Zxx)
  if Zxx.ndim < 2:
    raise ValueError('Input stft must be at least 2d!')
  freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
  time_axis = canonicalize_axis(time_axis, Zxx.ndim)
  if freq_axis == time_axis:
    raise ValueError('Must specify differing time and frequency axes!')

  # Promote to (at least) complex64, respecting the active precision mode.
  Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype(
      np.result_type(Zxx, np.complex64)))

  # Segment length implied by the number of frequency bins.
  n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
               else Zxx.shape[freq_axis])

  # Only `None` selects the default. An explicit 0 must reach the range check
  # below instead of being silently replaced by `n_default`.
  nperseg = jax.core.concrete_or_error(
      int, n_default if nperseg is None else nperseg,
      "nperseg: segment length of STFT")
  if nperseg < 1:
    raise ValueError('nperseg must be a positive integer')

  if nfft is None:
    nfft = n_default
    if input_onesided and nperseg == n_default + 1:
      nfft += 1  # Odd nperseg, no FFT padding
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
  if nfft < nperseg:
    raise ValueError(
        f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

  # Only `None` selects the default overlap. `noverlap=0` is a legitimate
  # request for non-overlapping frames and must not fall back to nperseg // 2.
  noverlap = jax.core.concrete_or_error(
      int, nperseg // 2 if noverlap is None else noverlap,
      "noverlap of STFT")
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Rearrange axes if necessary so that Zxx is (..., freq, time).
  if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
    outer_idxs = tuple(
        idx for idx in range(Zxx.ndim) if idx not in {time_axis, freq_axis})
    Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

  # Perform IFFT
  ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
  # xsubs: [..., T, N], N is the number of frames, T is the frame length.
  xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

  # Get window as array
  if isinstance(window, (str, tuple)):
    win = osp_signal.get_window(window, nperseg)
    win = jnp.asarray(win)
  else:
    win = jnp.asarray(window)
    if len(win.shape) != 1:
      raise ValueError('window must be 1-D')
    if win.shape[0] != nperseg:
      raise ValueError('window must have length of {0}'.format(nperseg))
  win = win.astype(xsubs.dtype)

  xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

  # make win broadcastable over xsubs
  win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1,))
  x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
  # Overlap-add the squared window once per frame to get the normaliser.
  win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
  norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

  # Remove extension points
  if boundary:
    x = x[..., nperseg//2:-(nperseg//2)]
    norm = norm[..., nperseg//2:-(nperseg//2)]
  # Leave samples with a (near-)zero normaliser undivided.
  x /= jnp.where(norm > 1e-10, norm, 1.0)

  # Put axes back
  if x.ndim > 1:
    if time_axis != Zxx.ndim - 1:
      # The frequency axis no longer exists in x, so a time axis that sat
      # after it moves down by one position.
      if freq_axis < time_axis:
        time_axis -= 1
      x = jnp.moveaxis(x, -1, time_axis)

  # NOTE(review): the length is taken from the leading axis of x, which is the
  # time axis only for 1-D outputs or when time was moved to position 0.
  time = jnp.arange(x.shape[0]) / fs
  return time, x
|
||||
|
@ -21,6 +21,7 @@ from jax._src.scipy.signal import (
|
||||
correlate2d as correlate2d,
|
||||
detrend as detrend,
|
||||
csd as csd,
|
||||
istft as istft,
|
||||
stft as stft,
|
||||
welch as welch,
|
||||
)
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
@ -47,6 +48,12 @@ csd_test_shapes = [
|
||||
((3, 17, 2), (3, 12, 2), 9, 3, 1),
|
||||
]
|
||||
welch_test_shapes = stft_test_shapes
|
||||
# Inverse-STFT test cases. Each entry is laid out as
#   (input_shape, nperseg, noverlap, timeaxis, freqaxis)
istft_test_shapes = [
    ((3, 2, 64, 31), 100, 75, -1, -2),
    ((17, 8, 5), 13, 7, 0, 1),
    ((65, 24), 24, 7, -2, -1),
]
|
||||
|
||||
|
||||
default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
|
||||
@ -376,6 +383,63 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
  {"testcase_name":
   f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
   f"_fs={fs}_window={window}_boundary={boundary}"
   f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}"
   f"_timeaxis={timeaxis}_freqaxis={freqaxis}_nfft={nfft}",
   "shape": shape, "dtype": dtype, "fs": fs, "window": window,
   "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
   "onesided": onesided, "boundary": boundary,
   "timeaxis": timeaxis, "freqaxis": freqaxis}
  for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes
  for dtype in default_dtypes
  for fs in [1.0, 16000.0]
  for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
  for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
  for onesided in [False, True]
  for boundary in [False, True]))
@jtu.skip_on_devices("rocm")  # will be fixed in ROCm 5.1
def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                          noverlap, nfft, onesided, boundary,
                          timeaxis, freqaxis):
  # Compares jsp_signal.istft with osp_signal.istft on random spectra.
  #
  # The shapes in `istft_test_shapes` give the one-sided frequency length.
  # For two-sided input the frequency axis is widened to 2 * (n - 1).
  # The `fs` test parameter only contributes to the case name. The sampling
  # rate actually passed to both functions is drawn by `args_maker`.
  if not onesided:
    # Normalise the axis before slicing: with a negative index,
    # `shape[freqaxis + 1:]` is `shape[0:]` when freqaxis == -1 and would
    # append the entire original shape instead of the trailing dimensions.
    axis = freqaxis % len(shape)
    new_freq_len = (shape[axis] - 1) * 2
    shape = shape[:axis] + (new_freq_len,) + shape[axis + 1:]

  def osp_fun(x, fs):
    # Ignore UserWarning in osp so we can also test over ill-posed cases.
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      result = osp_signal.istft(
          x,
          fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
          nfft=nfft, input_onesided=onesided, boundary=boundary,
          time_axis=timeaxis, freq_axis=freqaxis)
    return result

  jsp_fun = partial(jsp_signal.istft,
                    window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                    input_onesided=onesided, boundary=boundary,
                    time_axis=timeaxis, freq_axis=freqaxis)

  tol = {
      np.float32: 1e-4, np.float64: 1e-6,
      np.complex64: 1e-4, np.complex128: 1e-6
  }
  if jtu.device_under_test() == 'tpu':
    tol = _TPU_FFT_TOL

  rng = jtu.rand_default(self.rng())
  rng_fs = jtu.rand_uniform(self.rng(), 1.0, 16000.0)
  # `np.float` (an alias of the builtin float, i.e. float64) was removed in
  # NumPy 1.24; spell the dtype out explicitly.
  args_maker = lambda: [rng(shape, dtype), rng_fs((), np.float64)]

  # Here, dtype of output signal is different depending on osp versions,
  # and so depending on the test environment. Thus, dtype check is disabled.
  self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol,
                          check_dtypes=False)
  self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user