diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index ab55eb5a7..565639f91 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -313,9 +313,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase): tol = shape[0] * jnp.finfo(dtype).eps x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol, restart=restart, - M=M, - qr_mode=qr_mode) - self.assertAllClose(x, solution, atol=300*jnp.finfo(x.dtype).eps) + M=M, qr_mode=qr_mode) + using_x64 = solution.dtype.kind in {np.float64, np.complex128} + solution_tol = 1e-8 if using_x64 else 1e-4 + self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -345,10 +346,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase): restart = shape[-1] tol = shape[0] * jnp.finfo(A.dtype).eps x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol, - restart=restart, - M=M, - qr_mode=qr_mode) - solution_tol = 300*jnp.finfo(x.dtype).eps + restart=restart, + M=M, qr_mode=qr_mode) + using_x64 = solution.dtype.kind in {np.float64, np.complex128} + solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) def test_gmres_pytree(self):