Merge pull request #22567 from jakevdp:fft-norm-validation

PiperOrigin-RevId: 654828825
This commit is contained in:
jax authors 2024-07-22 11:15:43 -07:00
commit 0d7531b4f1

View File

@ -98,6 +98,8 @@ def dct(x: Array, type: int = 2, n: int | None = None,
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['ortho']:
raise ValueError(f"jax.scipy.fft.dct: {norm=!r} is not implemented")
axis = canonicalize_axis(axis, x.ndim)
if n is not None:
@ -196,6 +198,8 @@ def dctn(x: Array, type: int = 2,
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['ortho']:
raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented")
if axes is None:
axes = range(x.ndim)
@ -282,6 +286,8 @@ def idct(x: Array, type: int = 2, n: int | None = None,
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['ortho']:
raise ValueError(f"jax.scipy.fft.idct: {norm=!r} is not implemented")
axis = canonicalize_axis(axis, x.ndim)
if n is not None:
@ -378,6 +384,8 @@ def idctn(x: Array, type: int = 2,
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['ortho']:
raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented")
if axes is None:
axes = range(x.ndim)