mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #21418 from rajasekharporeddy:test_branch6
PiperOrigin-RevId: 640617948
This commit is contained in:
commit
913ff50000
@ -64,6 +64,37 @@ def dct(x: Array, type: int = 2, n: int | None = None,
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: inverse DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
|
||||
Example:
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x))
|
||||
[[-0.58 -0.33 -1.08]
|
||||
[-0.88 -1.01 -1.79]
|
||||
[-1.06 -2.43 1.24]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x, n=2))
|
||||
[[-0.22 -0.9 ]
|
||||
[-0.57 -1.68]
|
||||
[-2.52 -0.11]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x, n=2, axis=0))
|
||||
[[-2.22 1.43 -0.67]
|
||||
[ 0.52 -0.26 -0.04]]
|
||||
|
||||
When ``n`` larger than ``x.shape[axis]`` and ``axis=1``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x, n=4, axis=1))
|
||||
[[-0.58 -0.35 -0.64 -1.11]
|
||||
[-0.88 -0.9 -1.46 -1.68]
|
||||
[-1.06 -2.25 -1.15 1.93]]
|
||||
"""
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user