Fix dtype bug in jax.scipy.fft.idct

This commit is contained in:
Dan Foreman-Mackey 2024-09-25 12:55:43 -04:00
parent b49d8b2615
commit 96268dcae6
2 changed files with 16 additions and 3 deletions

View File

@ -21,7 +21,7 @@ import math
from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
from jax._src.typing import Array
def _W4(N: int, k: Array) -> Array:
@ -298,12 +298,12 @@ def idct(x: Array, type: int = 2, n: int | None = None,
[(0, n - x.shape[axis] if a == axis else 0, 0)
for a in range(x.ndim)])
N = x.shape[axis]
x = x.astype(jnp.float32)
x, = promote_dtypes_inexact(x)
if norm is None or norm == 'backward':
x = _dct_ortho_norm(x, axis)
x = _dct_ortho_norm(x, axis)
k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis])
k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis])
# everything is complex from here...
w4 = _W4(N,k)
x = x.astype(w4.dtype)

View File

@ -13,9 +13,12 @@
# limitations under the License.
import itertools
import numpy as np
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
import jax.scipy.fft as jsp_fft
import scipy.fft as osp_fft
@ -117,5 +120,15 @@ class LaxBackedScipyFftTests(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
def testIdctNormalizationPrecision(self):
# reported in https://github.com/jax-ml/jax/issues/23895
if not config.enable_x64.value:
raise self.skipTest("requires jax_enable_x64=true")
x = np.ones(3, dtype="float64")
n = 10
expected = osp_fft.idct(x, n=n, type=2)
actual = jsp_fft.idct(x, n=n, type=2)
self.assertArraysAllClose(actual, expected, atol=1e-14)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())