Merge pull request #9815 from eelregit:fft_norm_dtype_patch

PiperOrigin-RevId: 434772719
This commit is contained in:
jax authors 2022-03-15 09:29:03 -07:00
commit 98ad016794
2 changed files with 7 additions and 1 deletions

View File

@ -86,7 +86,8 @@ def _fft_core(func_name, fft_type, a, s, axes, norm):
s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
else:
s = [a.shape[axis] for axis in axes]
transformed = lax.fft(a, fft_type, tuple(s)) * _fft_norm(jnp.array(s), func_name, norm)
transformed = lax.fft(a, fft_type, tuple(s))
transformed *= _fft_norm(jnp.array(s, dtype=transformed.real.dtype), func_name, norm)
if orig_axes is not None:
transformed = jnp.moveaxis(transformed, axes, orig_axes)

View File

@ -138,6 +138,11 @@ class FftTest(jtu.JaxTestCase):
tol = 0.15
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
# check dtypes
dtype = jnp_fn(rng(shape, dtype)).dtype
expected_dtype = jnp.promote_types(float if inverse and real else complex, dtype)
self.assertEqual(dtype, expected_dtype)
def testIrfftTranspose(self):
# regression test for https://github.com/google/jax/issues/6223
def build_matrix(linear_func, size):