Relax GMRES test tolerances

This commit is contained in:
Stephan Hoyer 2020-11-12 10:50:47 -08:00
parent 7e62270e5a
commit fca5666382

View File

@ -313,9 +313,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
tol = shape[0] * jnp.finfo(dtype).eps
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M,
qr_mode=qr_mode)
self.assertAllClose(x, solution, atol=300*jnp.finfo(x.dtype).eps)
M=M, qr_mode=qr_mode)
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -345,10 +346,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
restart = shape[-1]
tol = shape[0] * jnp.finfo(A.dtype).eps
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
restart=restart,
M=M,
qr_mode=qr_mode)
solution_tol = 300*jnp.finfo(x.dtype).eps
restart=restart,
M=M, qr_mode=qr_mode)
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
solution_tol = 1e-8 if using_x64 else 1e-4
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
def test_gmres_pytree(self):