# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
import math
import operator

from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
from jax._src.typing import Array


def _W4(N: int, k: Array) -> Array:
  """Returns ``exp(-0.5j * pi * k / N)`` after promoting inputs to complex."""
  N_arr, k = promote_dtypes_complex(N, k)
  return jnp.exp(-.5j * jnp.pi * k / N_arr)


def _dct_interleave(x: Array, axis: int) -> Array:
  """Reorders ``x`` along ``axis``: even-indexed entries, then reversed odd ones."""
  v0 = lax.slice_in_dim(x, None, None, 2, axis)
  v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis,))
  return lax.concatenate([v0, v1], axis)


def _dct_ortho_norm(out: Array, axis: int) -> Array:
  """Divides ``out`` along ``axis`` by ``sqrt(4 * N)`` at index 0 and ``sqrt(2 * N)`` elsewhere.

  ``N`` is ``out.shape[axis]``.
  """
  factor = lax.concatenate([lax.full((1,), 4, out.dtype),
                            lax.full((out.shape[axis] - 1,), 2, out.dtype)], 0)
  factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
  return out / lax.sqrt(factor * out.shape[axis])


def _as_int_tuple(value: int | Sequence[int]) -> tuple[int, ...]:
  """Wraps a scalar integer in a 1-tuple; converts any other iterable to a tuple."""
  try:
    return (operator.index(value),)
  except TypeError:
    return tuple(value)


def _resize_axes(x: Array, sizes: dict[int, int]) -> Array:
  """Zero-pads or truncates ``x`` at the high end of each axis listed in ``sizes``.

  Args:
    x: array to resize.
    sizes: mapping from non-negative axis index to the desired length of that
      axis. Axes that are not listed are left unchanged.
  """
  # A negative "high" padding amount makes lax.pad drop trailing entries.
  pads = [(0, sizes[a] - x.shape[a] if a in sizes else 0, 0)
          for a in range(x.ndim)]
  return lax.pad(x, jnp.array(0, x.dtype), pads)


def _canonicalize_axes_and_shape(
    x: Array, s: int | Sequence[int] | None, axes: int | Sequence[int] | None
) -> tuple[tuple[int, ...], tuple[int, ...] | None]:
  """Normalizes the ``axes`` and ``s`` arguments of the n-dimensional transforms.

  Returns:
    A pair ``(axes, s)`` where ``axes`` is a tuple of non-negative axis indices
    (all axes of ``x`` when the argument is ``None``) and ``s`` is either
    ``None`` or a tuple.
  """
  raw_axes = range(x.ndim) if axes is None else _as_int_tuple(axes)
  # Canonicalize up front so that the axis -> size mapping built from ``s`` is
  # keyed by the same non-negative indices that are used when padding.
  canonical_axes = tuple(canonicalize_axis(a, x.ndim) for a in raw_axes)
  shape = None if s is None else _as_int_tuple(s)
  return canonical_axes, shape


# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)

def dct(x: Array, type: int = 2, n: int | None = None,
        axis: int = -1, norm: str | None = None) -> Array:
  """Computes the discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.dct`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    n: integer, default = x.shape[axis]. The length of the transform.
      If larger than ``x.shape[axis]``, the input will be zero-padded, if
      smaller, the input will be truncated.
    axis: integer, default=-1. The axis along which the dct will be performed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  Raises:
    NotImplementedError: if ``type`` is not 2.
    ValueError: if ``norm`` is not one of the supported modes.

  See Also:
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:
    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x))
    [[ 6.43  3.56 -2.86]
     [-1.75  1.55 -1.4 ]
     [ 1.33 -2.01 -0.82]]

    When ``n`` is smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``n`` is smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2, axis=0))
    [[ 3.09  4.4  -2.81]
     [ 2.41  2.62  0.76]]

    When ``n`` is larger than ``x.shape[axis]`` and ``axis=1``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=4, axis=1))
    [[ 6.43  4.88  0.04 -3.3 ]
     [-1.75  0.73  1.01 -2.18]
     [ 1.33 -1.05 -2.34 -0.07]]
  """
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dct: {norm=!r} is not implemented")

  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    x = _resize_axes(x, {axis: n})

  N = x.shape[axis]
  v = _dct_interleave(x, axis)
  V = jnp.fft.fft(v, axis=axis)
  k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype),
                      [a for a in range(x.ndim) if a != axis])
  out = V * _W4(N, k)
  out = 2 * out.real
  if norm == 'ortho':
    out = _dct_ortho_norm(out, axis)
  return out


def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
  """Computes a type-2 DCT over exactly two axes of ``x`` using one 2D FFT."""
  axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
  N1, N2 = x.shape[axis1], x.shape[axis2]
  v = _dct_interleave(_dct_interleave(x, axis1), axis2)
  V = jnp.fft.fftn(v, axes=axes)
  k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
                       [a for a in range(x.ndim) if a != axis1])
  k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype),
                       [a for a in range(x.ndim) if a != axis2])
  out = _W4(N1, k1) * (
      _W4(N2, k2) * V
      + _W4(N2, -k2) * jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2))
  out = 2 * out.real
  if norm == 'ortho':
    return _dct_ortho_norm(_dct_ortho_norm(out, axis1), axis2)
  return out


def dctn(x: Array, type: int = 2,
         s: int | Sequence[int] | None = None,
         axes: int | Sequence[int] | None = None,
         norm: str | None = None) -> Array:
  """Computes the multidimensional discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.dctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If
      not specified, it will default to the shape of ``x`` along the specified
      ``axes``. Entries of ``s`` are paired in order with entries of ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed. Negative values count from the last axis.
      Defaults to all axes of ``x``.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  Raises:
    NotImplementedError: if ``type`` is not 2.
    ValueError: if ``norm`` is not one of the supported modes.

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:

    ``jax.scipy.fft.dctn`` computes the transform along both the axes by default
    when ``axes`` argument is ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x))
    [[ 12.01   6.2  -10.17]
     [  8.84   9.65  -3.54]
     [ 11.25  -1.54  -0.88]]

    When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
    and dimension along ``axis 1`` will be same as that of input.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2]))
    [[ 9.36 10.22 -8.53]
     [11.57  2.85 -2.06]]

    When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
    be ``2`` and dimension along ``axis 0`` will be same as that of input.
    Also when ``axes=[1]``, transform will be computed only along ``axis 1``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2], axes=[1]))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2, 4]))
    [[  9.36  11.23   2.12 -10.97]
     [ 11.57   5.86  -1.37  -1.58]]
  """
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented")

  axes, s = _canonicalize_axes_and_shape(x, s, axes)

  if len(axes) == 1:
    return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

  if s is not None:
    x = _resize_axes(x, dict(zip(axes, s)))

  if len(axes) == 2:
    return _dct2(x, axes=axes, norm=norm)

  # compose high-D DCTs from 2D and 1D DCTs:
  for start in range(0, len(axes), 2):
    x = dctn(x, axes=axes[start:start + 2], norm=norm)
  return x


def idct(x: Array, type: int = 2, n: int | None = None,
         axis: int = -1, norm: str | None = None) -> Array:
  """Computes the inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idct`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    n: integer, default = x.shape[axis]. The length of the transform.
      If larger than ``x.shape[axis]``, the input will be zero-padded, if
      smaller, the input will be truncated.
    axis: integer, default=-1. The axis along which the dct will be performed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  Raises:
    NotImplementedError: if ``type`` is not 2.
    ValueError: if ``norm`` is not one of the supported modes.

  See Also:
    - :func:`jax.scipy.fft.dct`: DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...    print(jax.scipy.fft.idct(x))
    [[ 0.78  0.41 -0.39]
     [-0.12  0.31 -0.23]
     [ 0.17 -0.3  -0.11]]

    When ``n`` is smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=2))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``n`` is smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=2, axis=0))
    [[ 0.38  0.57 -0.45]
     [ 0.43  0.44  0.24]]

    When ``n`` is larger than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=4, axis=0))
    [[ 0.1   0.38 -0.16]
     [ 0.28  0.18 -0.26]
     [ 0.3   0.15 -0.08]
     [ 0.13  0.3   0.29]]

    ``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dct``

    >>> x_dct = jax.scipy.fft.dct(x)
    >>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
    Array(True, dtype=bool)
  """
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idct: {norm=!r} is not implemented")

  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    x = _resize_axes(x, {axis: n})

  N = x.shape[axis]
  x, = promote_dtypes_inexact(x)
  # The ortho scaling is applied once for norm='ortho' and twice for the
  # default 'backward' normalization.
  if norm is None or norm == 'backward':
    x = _dct_ortho_norm(x, axis)
  x = _dct_ortho_norm(x, axis)

  k = lax.expand_dims(jnp.arange(N, dtype=x.dtype),
                      [a for a in range(x.ndim) if a != axis])
  # everything is complex from here...
  w4 = _W4(N, k)
  x = x.astype(w4.dtype)
  x = x / w4
  x = x * 2 * N

  x = jnp.fft.ifft(x, axis=axis)
  # convert back to reals..
  out = _dct_deinterleave(x.real, axis)
  return out


def idctn(x: Array, type: int = 2,
          s: int | Sequence[int] | None = None,
          axes: int | Sequence[int] | None = None,
          norm: str | None = None) -> Array:
  """Computes the multidimensional inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If
      not specified, it will default to the shape of ``x`` along the specified
      ``axes``. Entries of ``s`` are paired in order with entries of ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed. Negative values count from the last axis.
      Defaults to all axes of ``x``.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  Raises:
    NotImplementedError: if ``type`` is not 2.
    ValueError: if ``norm`` is not one of the supported modes.

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT

  Examples:

    ``jax.scipy.fft.idctn`` computes the transform along both the axes by default
    when ``axes`` argument is ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x))
    [[ 0.12  0.11 -0.15]
     [ 0.07  0.17 -0.03]
     [ 0.19 -0.07 -0.02]]

    When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
    and dimension along ``axis 1`` will be the same as that of input.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2]))
    [[ 0.15  0.21 -0.18]
     [ 0.24 -0.01 -0.02]]

    When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
    be ``2`` and dimension along ``axis 0`` will be same as that of input.
    Also when ``axes=[1]``, transform will be computed only along ``axis 1``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2], axes=[1]))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2, 4]))
    [[ 0.1   0.18  0.07 -0.16]
     [ 0.2   0.06 -0.03 -0.01]]

    ``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dctn``

    >>> x_dctn = jax.scipy.fft.dctn(x)
    >>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
    Array(True, dtype=bool)
  """
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented")

  axes, s = _canonicalize_axes_and_shape(x, s, axes)

  if len(axes) == 1:
    return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

  if s is not None:
    x = _resize_axes(x, dict(zip(axes, s)))

  # compose high-D DCTs from 1D DCTs:
  for axis in axes:
    x = idct(x, axis=axis, norm=norm)

  return x


def _dct_deinterleave(x: Array, axis: int) -> Array:
  """Undoes ``_dct_interleave`` along ``axis``.

  The first ``ceil(N / 2)`` entries are written to the even positions of the
  output and the remaining entries, reversed, to the odd positions.
  """
  half = math.ceil(x.shape[axis] / 2)

  def along_axis(sl: slice) -> tuple[slice, ...]:
    # Index tuple that applies ``sl`` on ``axis`` and keeps every other axis whole.
    return tuple(sl if i == axis else slice(None, None, None)
                 for i in range(len(x.shape)))

  v0 = x[along_axis(slice(None, half, 1))]
  v1 = lax.rev(x[along_axis(slice(half, None, 1))], (axis,))
  out = jnp.zeros(x.shape, dtype=x.dtype)
  out = out.at[along_axis(slice(None, None, 2))].set(v0)
  out = out.at[along_axis(slice(1, None, 2))].set(v1)
  return out