Merge pull request #21418 from rajasekharporeddy:test_branch6

PiperOrigin-RevId: 640617948
This commit is contained in:
jax authors 2024-06-05 12:53:17 -07:00
commit 913ff50000

View File

@ -64,6 +64,37 @@ def dct(x: Array, type: int = 2, n: int | None = None,
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
Example:
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x))
[[-0.58 -0.33 -1.08]
[-0.88 -1.01 -1.79]
[-1.06 -2.43 1.24]]
When ``n`` smaller than ``x.shape[axis]``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x, n=2))
[[-0.22 -0.9 ]
[-0.57 -1.68]
[-2.52 -0.11]]
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x, n=2, axis=0))
[[-2.22 1.43 -0.67]
[ 0.52 -0.26 -0.04]]
When ``n`` larger than ``x.shape[axis]`` and ``axis=1``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x, n=4, axis=1))
[[-0.58 -0.35 -0.64 -1.11]
[-0.88 -0.9 -1.46 -1.68]
[-1.06 -2.25 -1.15 1.93]]
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')