From c5d4aba2a939ae1a2171a33f2e723490d9197b71 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 9 Mar 2022 13:03:45 -0500 Subject: [PATCH] Fix fft dtype for norm='ortho' --- jax/_src/numpy/fft.py | 3 ++- tests/fft_test.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index d365506be..cbb50be44 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -86,7 +86,8 @@ def _fft_core(func_name, fft_type, a, s, axes, norm): s += [max(0, 2 * (a.shape[axes[-1]] - 1))] else: s = [a.shape[axis] for axis in axes] - transformed = lax.fft(a, fft_type, tuple(s)) * _fft_norm(jnp.array(s), func_name, norm) + transformed = lax.fft(a, fft_type, tuple(s)) + transformed *= _fft_norm(jnp.array(s, dtype=transformed.real.dtype), func_name, norm) if orig_axes is not None: transformed = jnp.moveaxis(transformed, axes, orig_axes) diff --git a/tests/fft_test.py b/tests/fft_test.py index 83e512946..8b15276f9 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -138,6 +138,11 @@ class FftTest(jtu.JaxTestCase): tol = 0.15 jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) + # check dtypes + dtype = jnp_fn(rng(shape, dtype)).dtype + expected_dtype = jnp.promote_types(float if inverse and real else complex, dtype) + self.assertEqual(dtype, expected_dtype) + def testIrfftTranspose(self): # regression test for https://github.com/google/jax/issues/6223 def build_matrix(linear_func, size):