2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2020 The JAX Authors.
|
2020-04-10 11:54:10 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial
|
2022-02-03 14:18:44 +09:00
|
|
|
import unittest
|
2020-04-10 11:54:10 -07:00
|
|
|
|
|
|
|
from absl.testing import absltest, parameterized
|
|
|
|
|
2020-07-14 13:03:24 -07:00
|
|
|
import numpy as np
|
2022-06-17 12:26:15 -07:00
|
|
|
import scipy.signal as osp_signal
|
2020-04-10 11:54:10 -07:00
|
|
|
|
2020-04-14 18:23:19 -04:00
|
|
|
from jax import lax
|
2022-06-17 12:26:15 -07:00
|
|
|
import jax.numpy as jnp
|
|
|
|
from jax._src import dtypes
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2020-04-10 11:54:10 -07:00
|
|
|
import jax.scipy.signal as jsp_signal
|
|
|
|
|
|
|
|
from jax.config import config
|
|
|
|
# Hand JAX's configuration flags to absl for parsing (presumably so options
# can be set from the test command line).
config.parse_flags_with_absl()
|
|
|
|
|
|
|
|
# Operand shapes for the convolution tests; a test draws both the x and the
# y operand shape from the same list.
onedim_shapes = [(1,), (2,), (5,), (10,)]
twodim_shapes = [(1, 1), (2, 2), (2, 3), (3, 4), (4, 4)]
threedim_shapes = [(2, 2, 2), (3, 3, 2), (4, 4, 2), (5, 5, 2)]

# STFT configurations. Entry layout:
#   (input_shape, nperseg, noverlap, axis)
stft_test_shapes = [
    ((50,), 17, 5, -1),
    ((2, 13), 7, 0, -1),
    ((3, 17, 2), 9, 3, 1),
    ((2, 3, 389, 5), 17, 13, 2),
    ((2, 1, 133, 3), 17, 13, -2),
    ((3, 7), 1, 0, 1),
]

# Cross-spectral-density configurations. Entry layout:
#   (x_input_shape, y_input_shape, nperseg, noverlap, axis)
csd_test_shapes = [
    ((50,), (13,), 17, 5, -1),
    ((2, 13), (2, 13), 7, 0, -1),
    ((3, 17, 2), (3, 12, 2), 9, 3, 1),
]

# Welch tests reuse the STFT configurations (this is the same list object).
welch_test_shapes = stft_test_shapes

# Inverse-STFT configurations. Entry layout:
#   (input_shape, nperseg, noverlap, timeaxis, freqaxis)
# input_shape describes a one-sided spectrogram; the two-sided istft cases
# derive their frequency-axis length from it.
istft_test_shapes = [
    ((3, 2, 64, 31), 100, 75, -1, -2),
    ((17, 8, 5), 13, 7, 0, 1),
    ((65, 24), 24, 7, -2, -1),
]
|
2020-04-10 11:54:10 -07:00
|
|
|
|
|
|
|
|
2021-06-11 16:40:39 -06:00
|
|
|
# Dtypes exercised by most tests below: the concatenation of the floating,
# integer and complex dtype lists provided by `jtu.dtypes`.
default_dtypes = (jtu.dtypes.floating
                  + jtu.dtypes.integer
                  + jtu.dtypes.complex)

# Single scalar tolerance that replaces the per-dtype tolerance dicts for
# FFT-based comparisons when running on TPU.
_TPU_FFT_TOL = 0.15
|
2020-04-10 11:54:10 -07:00
|
|
|
|
2022-06-17 12:26:15 -07:00
|
|
|
def _real_dtype(dtype):
  """Return the real floating-point dtype that goes with `dtype`.

  Non-inexact dtypes are first promoted via `dtypes.to_inexact_dtype`; the
  result is then mapped through `jnp.finfo`, which yields the real component
  dtype for complex inputs.
  """
  inexact = dtypes.to_inexact_dtype(dtype)
  info = jnp.finfo(inexact)
  return info.dtype
|
2022-06-17 12:26:15 -07:00
|
|
|
|
|
|
|
def _complex_dtype(dtype):
  """Return the complex dtype corresponding to `dtype`.

  Thin wrapper over `dtypes.to_complex_dtype`, used to cast scipy outputs so
  they match the dtypes JAX returns.
  """
  as_complex = dtypes.to_complex_dtype(dtype)
  return as_complex
|
2022-06-17 12:26:15 -07:00
|
|
|
|
2020-04-10 11:54:10 -07:00
|
|
|
|
|
|
|
class LaxBackedScipySignalTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.signal implementations.

  Each test compares a `jax.scipy.signal` function against its `scipy.signal`
  counterpart with `_CheckAgainstNumpy`, then runs `_CompileAndCheck` on the
  JAX function.
  """

  @staticmethod
  def _fft_tol():
    """Return the tolerance for comparing FFT-based routines against scipy.

    On TPU this is the scalar `_TPU_FFT_TOL`; elsewhere a per-dtype dict.
    """
    if jtu.device_under_test() == 'tpu':
      return _TPU_FFT_TOL
    return {
        np.float32: 1e-5, np.float64: 1e-12,
        np.complex64: 1e-5, np.complex128: 1e-12
    }

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,
          jtu.format_shape_dtype_string(xshape, dtype),
          jtu.format_shape_dtype_string(yshape, dtype),
          mode),
       "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
       "jsp_op": getattr(jsp_signal, op),
       "osp_op": getattr(osp_signal, op)}
      for mode in ['full', 'same', 'valid']
      for op in ['convolve', 'correlate']
      for dtype in default_dtypes
      for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
      for xshape in shapeset
      for yshape in shapeset))
  def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12,
           np.complex64: 1e-2, np.complex128: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      # Leading underscore matches the naming used by the other tests.
      {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,
          jtu.format_shape_dtype_string(xshape, dtype),
          jtu.format_shape_dtype_string(yshape, dtype),
          mode),
       "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
       "jsp_op": getattr(jsp_signal, op),
       "osp_op": getattr(osp_signal, op)}
      for mode in ['full', 'same', 'valid']
      for op in ['convolve2d', 'correlate2d']
      for dtype in default_dtypes
      for xshape in twodim_shapes
      for yshape in twodim_shapes))
  def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12,
           np.complex64: 1e-2, np.complex128: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
       "shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp}
      for shape in [(5,), (4, 5), (3, 4, 5)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.integer
      for axis in [0, -1]
      for type in ['constant', 'linear']
      for bp in [0, [0, 2]]))
  def testDetrend(self, shape, dtype, axis, type, bp):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    kwds = dict(axis=axis, type=type, bp=bp)

    def osp_fun(x):
      # Cast scipy's result to the inexact counterpart of the input dtype so
      # the dtype check against the JAX result is meaningful.
      return osp_signal.detrend(x, **kwds).astype(dtypes.to_inexact_dtype(x.dtype))
    jsp_fun = partial(jsp_signal.detrend, **kwds)

    if jtu.device_under_test() == 'tpu':
      tol = {np.float32: 3e-2, np.float64: 1e-12}
    else:
      tol = {np.float32: 1e-5, np.float64: 1e-12}

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
       f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
       f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
       f"_axis={timeaxis}_nfft={nfft}",
       "shape": shape, "dtype": dtype, "fs": fs, "window": window,
       "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
       "detrend": detrend, "boundary": boundary, "padded": padded,
       "timeaxis": timeaxis}
      for shape, nperseg, noverlap, timeaxis in stft_test_shapes
      for dtype in default_dtypes
      for fs in [1.0, 16000.0]
      for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
      for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
      for detrend in ['constant', 'linear', False]
      for boundary in [None, 'even', 'odd', 'zeros']
      for padded in [True, False]))
  def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                           noverlap, nfft, detrend, boundary, padded,
                           timeaxis):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    # NOTE(review): `detrend` is never None in these cases ('constant',
    # 'linear' or False), so every complex case is skipped, including
    # detrend=False. Consider `if is_complex and detrend:`.
    if is_complex and detrend is not None:
      self.skipTest("Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nfft=nfft, boundary=boundary, padded=padded,
                detrend=detrend, nperseg=nperseg, noverlap=noverlap, axis=timeaxis,
                return_onesided=not is_complex)

    def osp_fun(x):
      freqs, time, Zxx = osp_signal.stft(x, **kwds)
      # Make type-casting the same as JAX.
      return (freqs.astype(_real_dtype(dtype)),
              time.astype(_real_dtype(dtype)),
              Zxx.astype(_complex_dtype(dtype)))
    jsp_fun = partial(jsp_signal.stft, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  # Tests with `average == 'median'` are excluded from `testCsd*`
  # because of this scipy issue:
  # https://github.com/scipy/scipy/issues/15601
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
       f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
       f"_average={average}_scaling={scaling}_nfft={nfft}"
       f"_fs={fs}_window={window}_detrend={detrend}"
       f"_nperseg={nperseg}_noverlap={noverlap}"
       f"_axis={timeaxis}",
       "xshape": xshape, "yshape": yshape, "dtype": dtype, "fs": fs,
       "window": window, "nperseg": nperseg, "noverlap": noverlap,
       "nfft": nfft, "detrend": detrend, "scaling": scaling,
       "timeaxis": timeaxis, "average": average}
      for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
      for dtype in default_dtypes
      for fs in [1.0, 16000.0]
      for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
      for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
      for detrend in ['constant', 'linear', False]
      for scaling in ['density', 'spectrum']
      for average in ['mean']))
  def testCsdAgainstNumpy(
      self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
      detrend, scaling, timeaxis, average):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    # NOTE(review): `detrend` is never None here, so all complex cases skip.
    if is_complex and detrend is not None:
      self.skipTest("Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                nfft=nfft, detrend=detrend, return_onesided=not is_complex,
                scaling=scaling, axis=timeaxis, average=average)

    def osp_fun(x, y):
      freqs, Pxy = osp_signal.csd(x, y, **kwds)
      # Make type-casting the same as JAX.
      return freqs.astype(_real_dtype(dtype)), Pxy.astype(_complex_dtype(dtype))
    jsp_fun = partial(jsp_signal.csd, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
       f"_average={average}_scaling={scaling}_nfft={nfft}"
       f"_fs={fs}_window={window}_detrend={detrend}"
       f"_nperseg={nperseg}_noverlap={noverlap}"
       f"_axis={timeaxis}",
       "shape": shape, "dtype": dtype, "fs": fs,
       "window": window, "nperseg": nperseg, "noverlap": noverlap,
       "nfft": nfft, "detrend": detrend, "scaling": scaling,
       "timeaxis": timeaxis, "average": average}
      for shape, unused_yshape, nperseg, noverlap, timeaxis in csd_test_shapes
      for dtype in default_dtypes
      for fs in [1.0, 16000.0]
      for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
      for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
      for detrend in ['constant', 'linear', False]
      for scaling in ['density', 'spectrum']
      for average in ['mean']))
  def testCsdWithSameParamAgainstNumpy(
      self, *, shape, dtype, fs, window, nperseg, noverlap, nfft,
      detrend, scaling, timeaxis, average):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    # NOTE(review): `detrend` is never None here, so all complex cases skip.
    if is_complex and detrend is not None:
      self.skipTest("Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                nfft=nfft, detrend=detrend, return_onesided=not is_complex,
                scaling=scaling, axis=timeaxis, average=average)

    def osp_fun(x, y):
      # When x and y are the same array, the JAX version behaves as if y
      # were a separate copy, so give scipy an explicit copy to match.
      freqs, Pxy = osp_signal.csd(x, y.copy(), **kwds)
      # Make type-casting the same as JAX.
      return freqs.astype(_real_dtype(dtype)), Pxy.astype(_complex_dtype(dtype))
    jsp_fun = partial(jsp_signal.csd, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    # The same array object is passed as both x and y.
    args_maker = lambda: [rng(shape, dtype)] * 2

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
       f"_fs={fs}_window={window}"
       f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
       f"_detrend={detrend}_return_onesided={return_onesided}"
       f"_scaling={scaling}_axis={timeaxis}_average={average}",
       "shape": shape, "dtype": dtype, "fs": fs, "window": window,
       "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
       "detrend": detrend, "return_onesided": return_onesided,
       "scaling": scaling, "timeaxis": timeaxis, "average": average}
      for shape, nperseg, noverlap, timeaxis in welch_test_shapes
      for dtype in default_dtypes
      for fs in [1.0, 16000.0]
      for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
      for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
      for detrend in ['constant', 'linear', False]
      for return_onesided in [True, False]
      for scaling in ['density', 'spectrum']
      for average in ['mean', 'median']))
  def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                            noverlap, nfft, detrend, return_onesided,
                            scaling, timeaxis, average):
    if np.dtype(dtype).kind == 'c':
      # Complex input only supports a two-sided spectrum.
      return_onesided = False
      # NOTE(review): `detrend` is never None here, so all complex cases skip.
      if detrend is not None:
        raise unittest.SkipTest(
            "Complex signal is not supported in lax-backed `signal.detrend`.")

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                detrend=detrend, return_onesided=return_onesided, scaling=scaling,
                axis=timeaxis, average=average)

    def osp_fun(x):
      freqs, Pxx = osp_signal.welch(x, **kwds)
      # Make type-casting the same as JAX (the PSD is real-valued).
      return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype))
    jsp_fun = partial(jsp_signal.welch, **kwds)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
       f"_nperseg={nperseg}_noverlap={noverlap}"
       f"_use_nperseg={use_nperseg}_use_noverlap={use_noverlap}"
       f"_axis={timeaxis}",
       "shape": shape, "dtype": dtype,
       "nperseg": nperseg, "noverlap": noverlap,
       "use_nperseg": use_nperseg, "use_noverlap": use_noverlap,
       "timeaxis": timeaxis}
      for shape, nperseg, noverlap, timeaxis in welch_test_shapes
      for use_nperseg in [False, True]
      for use_noverlap in [False, True]
      for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
  def testWelchWithDefaultStepArgsAgainstNumpy(
      self, *, shape, dtype, nperseg, noverlap, use_nperseg, use_noverlap,
      timeaxis):
    kwargs = {'axis': timeaxis}

    if use_nperseg:
      kwargs['nperseg'] = nperseg
    else:
      # Without `nperseg`, the segment length comes from the explicit window.
      # NOTE(review): the window is built with a complex dtype even for real
      # inputs; confirm this is intended.
      kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg),
                                   dtype=dtypes.to_complex_dtype(dtype))
    if use_noverlap:
      kwargs['noverlap'] = noverlap

    def osp_fun(x):
      freqs, Pxx = osp_signal.welch(x, **kwargs)
      return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype))
    jsp_fun = partial(jsp_signal.welch, **kwargs)
    tol = self._fft_tol()

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
       f"_fs={fs}_window={window}_boundary={boundary}"
       f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}"
       f"_timeaxis={timeaxis}_freqaxis={freqaxis}_nfft={nfft}",
       "shape": shape, "dtype": dtype, "fs": fs, "window": window,
       "nperseg": nperseg, "noverlap": noverlap, "nfft": nfft,
       "onesided": onesided, "boundary": boundary,
       "timeaxis": timeaxis, "freqaxis": freqaxis}
      for shape, nperseg, noverlap, timeaxis, freqaxis in istft_test_shapes
      for dtype in default_dtypes
      for fs in [1.0, 16000.0]
      for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
      for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
      for onesided in [False, True]
      for boundary in [False, True]))
  def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                            noverlap, nfft, onesided, boundary,
                            timeaxis, freqaxis):
    if not onesided:
      # A one-sided frequency axis of length n corresponds to a two-sided one
      # of length 2 * (n - 1). Assigning by index (instead of slicing around
      # `freqaxis`) is also correct for negative axes: with freqaxis == -1,
      # `shape[freqaxis + 1:]` would be `shape[0:]`, i.e. the whole shape.
      shape = list(shape)
      shape[freqaxis] = (shape[freqaxis] - 1) * 2
      shape = tuple(shape)

    kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap,
                nfft=nfft, input_onesided=onesided, boundary=boundary,
                time_axis=timeaxis, freq_axis=freqaxis)

    osp_fun = partial(osp_signal.istft, **kwds)
    osp_fun = jtu.ignore_warning(message="NOLA condition failed, STFT may not be invertible")(osp_fun)
    jsp_fun = partial(jsp_signal.istft, **kwds)

    tol = {
        np.float32: 1e-4, np.float64: 1e-6,
        np.complex64: 1e-4, np.complex128: 1e-6
    }
    if jtu.device_under_test() == 'tpu':
      tol = _TPU_FFT_TOL

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

    # Here, dtype of output signal is different depending on osp versions,
    # and so depending on the test environment. Thus, dtype check is disabled.
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, rtol=tol, atol=tol,
                            check_dtypes=False)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
|
2020-04-10 11:54:10 -07:00
|
|
|
|
|
|
|
if __name__ == "__main__":
  # Run the test cases through absltest, using JAX's own test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|