mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix dtype bug in jax.scipy.fft.idct
This commit is contained in:
parent
b49d8b2615
commit
96268dcae6
@ -21,7 +21,7 @@ import math
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.numpy.util import promote_dtypes_complex
|
||||
from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact
|
||||
from jax._src.typing import Array
|
||||
|
||||
def _W4(N: int, k: Array) -> Array:
|
||||
@ -298,12 +298,12 @@ def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
[(0, n - x.shape[axis] if a == axis else 0, 0)
|
||||
for a in range(x.ndim)])
|
||||
N = x.shape[axis]
|
||||
x = x.astype(jnp.float32)
|
||||
x, = promote_dtypes_inexact(x)
|
||||
if norm is None or norm == 'backward':
|
||||
x = _dct_ortho_norm(x, axis)
|
||||
x = _dct_ortho_norm(x, axis)
|
||||
|
||||
k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis])
|
||||
k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis])
|
||||
# everything is complex from here...
|
||||
w4 = _W4(N,k)
|
||||
x = x.astype(w4.dtype)
|
||||
|
@ -13,9 +13,12 @@
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
import jax.scipy.fft as jsp_fft
|
||||
import scipy.fft as osp_fft
|
||||
@ -117,5 +120,15 @@ class LaxBackedScipyFftTests(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
|
||||
|
||||
def testIdctNormalizationPrecision(self):
|
||||
# reported in https://github.com/jax-ml/jax/issues/23895
|
||||
if not config.enable_x64.value:
|
||||
raise self.skipTest("requires jax_enable_x64=true")
|
||||
x = np.ones(3, dtype="float64")
|
||||
n = 10
|
||||
expected = osp_fft.idct(x, n=n, type=2)
|
||||
actual = jsp_fft.idct(x, n=n, type=2)
|
||||
self.assertArraysAllClose(actual, expected, atol=1e-14)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user