[x64] more type safety in jax.scipy.signal

This commit is contained in:
Jake VanderPlas 2022-12-01 13:40:38 -08:00
parent b3b7eb68f1
commit d25a96caea
2 changed files with 18 additions and 10 deletions

View File

@ -229,9 +229,13 @@ def irfft2(a: ArrayLike, s: Optional[Shape] = None, axes: Sequence[int] = (-2,-1
norm=norm)
@_wraps(np.fft.fftfreq)
def fftfreq(n: int, d: ArrayLike = 1.0) -> Array:
dtype = dtypes.canonicalize_dtype(jnp.float_)
@_wraps(np.fft.fftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
""")
def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
@ -257,12 +261,16 @@ def fftfreq(n: int, d: ArrayLike = 1.0) -> Array:
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype))
return k / (d * n)
return k / jnp.array(d * n, dtype=dtype)
@_wraps(np.fft.rfftfreq)
def rfftfreq(n: int, d: ArrayLike = 1.0) -> Array:
dtype = dtypes.canonicalize_dtype(jnp.float_)
@_wraps(np.fft.rfftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
""")
def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
@ -279,7 +287,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0) -> Array:
else:
k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)
return k / (d * n)
return k / jnp.array(d * n, dtype=dtype)
@_wraps(np.fft.fftshift)

View File

@ -397,9 +397,9 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
sides = 'twosided'
if sides == 'twosided':
freqs = jax.numpy.fft.fftfreq(nfft_int, 1/fs).astype(freq_dtype)
freqs = jax.numpy.fft.fftfreq(nfft_int, 1/fs, dtype=freq_dtype)
elif sides == 'onesided':
freqs = jax.numpy.fft.rfftfreq(nfft_int, 1/fs).astype(freq_dtype)
freqs = jax.numpy.fft.rfftfreq(nfft_int, 1/fs, dtype=freq_dtype)
# Perform the windowed FFTs
result = _fft_helper(x, win, detrend_func,