jax.scipy.fft: manually document functions to avoid scipy import

This commit is contained in:
Jake VanderPlas 2024-04-17 10:06:25 -07:00
parent ba540ca735
commit 75921162ab

View File

@ -18,11 +18,10 @@ from collections.abc import Sequence
from functools import partial
import math
import scipy.fft as osp_fft
from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import implements, promote_dtypes_complex
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.typing import Array
def _W4(N: int, k: Array) -> Array:
@ -42,9 +41,30 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array:
# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
@implements(osp_fft.dct)
def dct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.dct`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
@ -81,11 +101,31 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
return out
@implements(osp_fft.dctn)
def dctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.dctn`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
@ -109,9 +149,29 @@ def dctn(x: Array, type: int = 2,
return x
@implements(osp_fft.idct)
def idct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the inverse discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.idct`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the inverse discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
@ -126,7 +186,6 @@ def idct(x: Array, type: int = 2, n: int | None = None,
x = _dct_ortho_norm(x, axis)
x = _dct_ortho_norm(x, axis)
k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis])
# everything is complex from here...
w4 = _W4(N,k)
@ -139,11 +198,32 @@ def idct(x: Array, type: int = 2, n: int | None = None,
out = _dct_deinterleave(x.real, axis)
return out
@implements(osp_fft.idctn)
def idctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional inverse discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.idctn`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the inverse discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')