mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #21657 from rajasekharporeddy:test_branch7
PiperOrigin-RevId: 640652642
This commit is contained in:
commit
bf59a67bf0
@ -159,7 +159,7 @@ def dctn(x: Array, type: int = 2,
|
||||
|
||||
Example:
|
||||
|
||||
``jax.scipy.fft.dctn`` computes the transform along both the axes by dafault
|
||||
``jax.scipy.fft.dctn`` computes the transform along both the axes by default
|
||||
when ``axes`` argument is ``None``.
|
||||
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
@ -239,6 +239,46 @@ def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
- :func:`jax.scipy.fft.dct`: DCT
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
|
||||
Example:
|
||||
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x))
|
||||
[[-0.02 -0. -0.17]
|
||||
[-0.02 -0.07 -0.28]
|
||||
[-0.16 -0.36 0.18]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x, n=2))
|
||||
[[ 0. -0.19]
|
||||
[-0.03 -0.34]
|
||||
[-0.38 0.04]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x, n=2, axis=0))
|
||||
[[-0.35 0.23 -0.1 ]
|
||||
[ 0.17 -0.09 0.01]]
|
||||
|
||||
When ``n`` larger than ``x.shape[axis]`` and ``axis=0``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x, n=4, axis=0))
|
||||
[[-0.34 0.03 0.07]
|
||||
[ 0. 0.18 -0.17]
|
||||
[ 0.14 0.09 -0.14]
|
||||
[ 0. -0.18 0.14]]
|
||||
|
||||
``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
|
||||
of ``jax.scipy.fft.dct``
|
||||
|
||||
>>> x_dct = jax.scipy.fft.dct(x)
|
||||
>>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
@ -291,6 +331,50 @@ def idctn(x: Array, type: int = 2,
|
||||
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
|
||||
|
||||
Example:
|
||||
|
||||
``jax.scipy.fft.idctn`` computes the transform along both the axes by default
|
||||
when ``axes`` argument is ``None``.
|
||||
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x))
|
||||
[[-0.03 -0.08 -0.08]
|
||||
[ 0.05 0.12 -0.09]
|
||||
[-0.02 -0.04 0.08]]
|
||||
|
||||
When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
|
||||
and dimension along ``axis 1`` will be the same as that of input.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x, s=[2]))
|
||||
[[-0.01 -0.03 -0.14]
|
||||
[ 0. 0.03 0.06]]
|
||||
|
||||
When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
|
||||
be ``2`` and dimension along ``axis 0`` will be same as that of input.
|
||||
Also when ``axes=[1]``, transform will be computed only along ``axis 1``.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x, s=[2], axes=[1]))
|
||||
[[ 0. -0.19]
|
||||
[-0.03 -0.34]
|
||||
[-0.38 0.04]]
|
||||
|
||||
When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x, s=[2, 4]))
|
||||
[[-0.01 -0.01 -0.05 -0.11]
|
||||
[ 0. 0.01 0.03 0.04]]
|
||||
|
||||
``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
|
||||
of ``jax.scipy.fft.dctn``
|
||||
|
||||
>>> x_dctn = jax.scipy.fft.dctn(x)
|
||||
>>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user