mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax.scipy.fft: manually document functions to avoid scipy import
This commit is contained in:
parent
ba540ca735
commit
75921162ab
@ -18,11 +18,10 @@ from collections.abc import Sequence
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import scipy.fft as osp_fft
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.numpy.util import implements, promote_dtypes_complex
|
||||
from jax._src.numpy.util import promote_dtypes_complex
|
||||
from jax._src.typing import Array
|
||||
|
||||
def _W4(N: int, k: Array) -> Array:
|
||||
@ -42,9 +41,30 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array:
|
||||
# Implementation based on
|
||||
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
|
||||
|
||||
@implements(osp_fft.dct)
|
||||
|
||||
def dct(x: Array, type: int = 2, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
"""Computes the discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.dct`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
n: integer, default = x.shape[axis]. The length of the transform.
|
||||
If larger than ``x.shape[axis]``, the input will be zero-padded, if
|
||||
smaller, the input will be truncated.
|
||||
axis: integer, default=-1. The axis along which the dct will be performed.
|
||||
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
|
||||
|
||||
Returns:
|
||||
array containing the discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: inverse DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
"""
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
|
||||
@ -81,11 +101,31 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
|
||||
return out
|
||||
|
||||
|
||||
@implements(osp_fft.dctn)
|
||||
def dctn(x: Array, type: int = 2,
|
||||
s: Sequence[int] | None=None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
"""Computes the multidimensional discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.dctn`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
s: integer or sequence of integers. Specifies the shape of the result. If not
|
||||
specified, it will default to the shape of ``x`` along the specified ``axes``.
|
||||
axes: integer or sequence of integers. Specifies the axes along which the
|
||||
transform will be computed.
|
||||
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
|
||||
|
||||
Returns:
|
||||
array containing the discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
"""
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
|
||||
@ -109,9 +149,29 @@ def dctn(x: Array, type: int = 2,
|
||||
return x
|
||||
|
||||
|
||||
@implements(osp_fft.idct)
|
||||
def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
"""Computes the inverse discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.idct`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
n: integer, default = x.shape[axis]. The length of the transform.
|
||||
If larger than ``x.shape[axis]``, the input will be zero-padded, if
|
||||
smaller, the input will be truncated.
|
||||
axis: integer, default=-1. The axis along which the dct will be performed.
|
||||
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
|
||||
|
||||
Returns:
|
||||
array containing the inverse discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dct`: DCT
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
"""
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
|
||||
@ -126,7 +186,6 @@ def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
x = _dct_ortho_norm(x, axis)
|
||||
x = _dct_ortho_norm(x, axis)
|
||||
|
||||
|
||||
k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis])
|
||||
# everything is complex from here...
|
||||
w4 = _W4(N,k)
|
||||
@ -139,11 +198,32 @@ def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
out = _dct_deinterleave(x.real, axis)
|
||||
return out
|
||||
|
||||
@implements(osp_fft.idctn)
|
||||
|
||||
def idctn(x: Array, type: int = 2,
|
||||
s: Sequence[int] | None=None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
s: Sequence[int] | None=None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
"""Computes the multidimensional inverse discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.idctn`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
s: integer or sequence of integers. Specifies the shape of the result. If not
|
||||
specified, it will default to the shape of ``x`` along the specified ``axes``.
|
||||
axes: integer or sequence of integers. Specifies the axes along which the
|
||||
transform will be computed.
|
||||
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
|
||||
|
||||
Returns:
|
||||
array containing the inverse discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
|
||||
"""
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user