Merge pull request #21657 from rajasekharporeddy:test_branch7

PiperOrigin-RevId: 640652642
This commit is contained in:
jax authors 2024-06-05 14:35:35 -07:00
commit bf59a67bf0

View File

@ -159,7 +159,7 @@ def dctn(x: Array, type: int = 2,
Example:
``jax.scipy.fft.dctn`` computes the transform along both the axes by dafault
``jax.scipy.fft.dctn`` computes the transform along both the axes by default
when ``axes`` argument is ``None``.
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
@ -239,6 +239,46 @@ def idct(x: Array, type: int = 2, n: int | None = None,
- :func:`jax.scipy.fft.dct`: DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
Example:
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x))
[[-0.02 -0. -0.17]
[-0.02 -0.07 -0.28]
[-0.16 -0.36 0.18]]
When ``n`` smaller than ``x.shape[axis]``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x, n=2))
[[ 0. -0.19]
[-0.03 -0.34]
[-0.38 0.04]]
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x, n=2, axis=0))
[[-0.35 0.23 -0.1 ]
[ 0.17 -0.09 0.01]]
When ``n`` larger than ``x.shape[axis]`` and ``axis=0``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x, n=4, axis=0))
[[-0.34 0.03 0.07]
[ 0. 0.18 -0.17]
[ 0.14 0.09 -0.14]
[ 0. -0.18 0.14]]
``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
of ``jax.scipy.fft.dct``
>>> x_dct = jax.scipy.fft.dct(x)
>>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
Array(True, dtype=bool)
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
@ -291,6 +331,50 @@ def idctn(x: Array, type: int = 2,
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
Example:
``jax.scipy.fft.idctn`` computes the transform along both the axes by default
when ``axes`` argument is ``None``.
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x))
[[-0.03 -0.08 -0.08]
[ 0.05 0.12 -0.09]
[-0.02 -0.04 0.08]]
When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
and dimension along ``axis 1`` will be the same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x, s=[2]))
[[-0.01 -0.03 -0.14]
[ 0. 0.03 0.06]]
When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
be ``2`` and dimension along ``axis 0`` will be same as that of input.
Also when ``axes=[1]``, transform will be computed only along ``axis 1``.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x, s=[2], axes=[1]))
[[ 0. -0.19]
[-0.03 -0.34]
[-0.38 0.04]]
When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x, s=[2, 4]))
[[-0.01 -0.01 -0.05 -0.11]
[ 0. 0.01 0.03 0.04]]
``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
of ``jax.scipy.fft.dctn``
>>> x_dctn = jax.scipy.fft.dctn(x)
>>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
Array(True, dtype=bool)
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')