diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index d365506be..cbb50be44 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -86,7 +86,8 @@ def _fft_core(func_name, fft_type, a, s, axes, norm): s += [max(0, 2 * (a.shape[axes[-1]] - 1))] else: s = [a.shape[axis] for axis in axes] - transformed = lax.fft(a, fft_type, tuple(s)) * _fft_norm(jnp.array(s), func_name, norm) + transformed = lax.fft(a, fft_type, tuple(s)) + transformed *= _fft_norm(jnp.array(s, dtype=transformed.real.dtype), func_name, norm) if orig_axes is not None: transformed = jnp.moveaxis(transformed, axes, orig_axes) diff --git a/tests/fft_test.py b/tests/fft_test.py index 83e512946..8b15276f9 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -138,6 +138,11 @@ class FftTest(jtu.JaxTestCase): tol = 0.15 jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) + # check dtypes + dtype = jnp_fn(rng(shape, dtype)).dtype + expected_dtype = jnp.promote_types(float if inverse and real else complex, dtype) + self.assertEqual(dtype, expected_dtype) + def testIrfftTranspose(self): # regression test for https://github.com/google/jax/issues/6223 def build_matrix(linear_func, size):