mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[x64] more type safety in jax.scipy.signal
This commit is contained in:
parent
b3b7eb68f1
commit
d25a96caea
@ -229,9 +229,13 @@ def irfft2(a: ArrayLike, s: Optional[Shape] = None, axes: Sequence[int] = (-2,-1
|
||||
norm=norm)
|
||||
|
||||
|
||||
@_wraps(np.fft.fftfreq)
|
||||
def fftfreq(n: int, d: ArrayLike = 1.0) -> Array:
|
||||
dtype = dtypes.canonicalize_dtype(jnp.float_)
|
||||
@_wraps(np.fft.fftfreq, extra_params="""
|
||||
dtype : Optional
|
||||
The dtype of the returned frequencies. If not specified, JAX's default
|
||||
floating point dtype will be used.
|
||||
""")
|
||||
def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
|
||||
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
|
||||
if isinstance(n, (list, tuple)):
|
||||
raise ValueError(
|
||||
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
|
||||
@ -257,12 +261,16 @@ def fftfreq(n: int, d: ArrayLike = 1.0) -> Array:
|
||||
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
|
||||
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype))
|
||||
|
||||
return k / (d * n)
|
||||
return k / jnp.array(d * n, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.fft.rfftfreq)
|
||||
def rfftfreq(n: int, d: ArrayLike = 1.0) -> Array:
|
||||
dtype = dtypes.canonicalize_dtype(jnp.float_)
|
||||
@_wraps(np.fft.rfftfreq, extra_params="""
|
||||
dtype : Optional
|
||||
The dtype of the returned frequencies. If not specified, JAX's default
|
||||
floating point dtype will be used.
|
||||
""")
|
||||
def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
|
||||
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
|
||||
if isinstance(n, (list, tuple)):
|
||||
raise ValueError(
|
||||
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
|
||||
@ -279,7 +287,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0) -> Array:
|
||||
else:
|
||||
k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)
|
||||
|
||||
return k / (d * n)
|
||||
return k / jnp.array(d * n, dtype=dtype)
|
||||
|
||||
|
||||
@_wraps(np.fft.fftshift)
|
||||
|
@ -397,9 +397,9 @@ def _spectral_helper(x: Array, y: Optional[ArrayLike], fs: ArrayLike = 1.0,
|
||||
sides = 'twosided'
|
||||
|
||||
if sides == 'twosided':
|
||||
freqs = jax.numpy.fft.fftfreq(nfft_int, 1/fs).astype(freq_dtype)
|
||||
freqs = jax.numpy.fft.fftfreq(nfft_int, 1/fs, dtype=freq_dtype)
|
||||
elif sides == 'onesided':
|
||||
freqs = jax.numpy.fft.rfftfreq(nfft_int, 1/fs).astype(freq_dtype)
|
||||
freqs = jax.numpy.fft.rfftfreq(nfft_int, 1/fs, dtype=freq_dtype)
|
||||
|
||||
# Perform the windowed FFTs
|
||||
result = _fft_helper(x, win, detrend_func,
|
||||
|
Loading…
x
Reference in New Issue
Block a user