Relax some test tolerances in N-D FFT tests.

PiperOrigin-RevId: 708259645
This commit is contained in:
Dan Foreman-Mackey 2024-12-20 03:21:11 -08:00 committed by jax authors
parent 5031b6f599
commit 0b190bb665

View File

@ -158,12 +158,13 @@ class FftTest(jtu.JaxTestCase):
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6})
self._CompileAndCheck(jnp_fn, args_maker, atol={np.complex64: 2e-6},
rtol={np.float32: 2e-6})
# Test gradient for differentiable types.
if (config.enable_x64.value and
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
# TODO(skye): can we be more precise?
tol = 0.15
tol = 0.16
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
# check dtypes