diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index a826d4746..a0050cc81 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -21,7 +21,7 @@ import math from jax import lax import jax.numpy as jnp from jax._src.util import canonicalize_axis -from jax._src.numpy.util import promote_dtypes_complex +from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact from jax._src.typing import Array def _W4(N: int, k: Array) -> Array: @@ -298,12 +298,12 @@ def idct(x: Array, type: int = 2, n: int | None = None, [(0, n - x.shape[axis] if a == axis else 0, 0) for a in range(x.ndim)]) N = x.shape[axis] - x = x.astype(jnp.float32) + x, = promote_dtypes_inexact(x) if norm is None or norm == 'backward': x = _dct_ortho_norm(x, axis) x = _dct_ortho_norm(x, axis) - k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis]) + k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis]) # everything is complex from here... w4 = _W4(N,k) x = x.astype(w4.dtype) diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index 6c549f5ed..a6fdd1b79 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -13,9 +13,12 @@ # limitations under the License. import itertools +import numpy as np + from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.scipy.fft as jsp_fft import scipy.fft as osp_fft @@ -117,5 +120,15 @@ class LaxBackedScipyFftTests(jtu.JaxTestCase): tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) + def testIdctNormalizationPrecision(self): + # reported in https://github.com/jax-ml/jax/issues/23895 + if not config.enable_x64.value: + raise self.skipTest("requires jax_enable_x64=true") + x = np.ones(3, dtype="float64") + n = 10 + expected = osp_fft.idct(x, n=n, type=2) + actual = jsp_fft.idct(x, n=n, type=2) + self.assertArraysAllClose(actual, expected, atol=1e-14) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())