diff --git a/tests/fft_test.py b/tests/fft_test.py index 3668f534d..26b69f5be 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -158,12 +158,13 @@ class FftTest(jtu.JaxTestCase): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6}) + self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6}, + rtol={np.float32: 2e-6}) # Test gradient for differentiable types. if (config.enable_x64.value and dtype in (float_dtypes if real and not inverse else inexact_dtypes)): # TODO(skye): can we be more precise? - tol = 0.15 + tol = 0.16 jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) # check dtypes