From 0b190bb665060ea241650e77a7c311fec5d853bb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 20 Dec 2024 03:21:11 -0800 Subject: [PATCH] Relax some test tolerances in N-D FFT tests. PiperOrigin-RevId: 708259645 --- tests/fft_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/fft_test.py b/tests/fft_test.py index 3668f534d..26b69f5be 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -158,12 +158,13 @@ class FftTest(jtu.JaxTestCase): # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6}) + self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6}, + rtol={np.float32: 2e-6}) # Test gradient for differentiable types. if (config.enable_x64.value and dtype in (float_dtypes if real and not inverse else inexact_dtypes)): # TODO(skye): can we be more precise? - tol = 0.15 + tol = 0.16 jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) # check dtypes