From fca5666382ace4b621ce895422a7e0ff8ffe27af Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 12 Nov 2020 10:50:47 -0800 Subject: [PATCH] Relax GMRES test tolerances --- tests/lax_scipy_sparse_test.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index ab55eb5a7..565639f91 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -313,9 +313,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase): tol = shape[0] * jnp.finfo(dtype).eps x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol, restart=restart, - M=M, - qr_mode=qr_mode) - self.assertAllClose(x, solution, atol=300*jnp.finfo(x.dtype).eps) + M=M, qr_mode=qr_mode) + using_x64 = solution.dtype.kind in {np.float64, np.complex128} + solution_tol = 1e-8 if using_x64 else 1e-4 + self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -345,10 +346,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase): restart = shape[-1] tol = shape[0] * jnp.finfo(A.dtype).eps x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol, - restart=restart, - M=M, - qr_mode=qr_mode) - solution_tol = 300*jnp.finfo(x.dtype).eps + restart=restart, + M=M, qr_mode=qr_mode) + using_x64 = solution.dtype.kind in {np.float64, np.complex128} + solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) def test_gmres_pytree(self):