mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Relax GMRES test tolerances
This commit is contained in:
parent
7e62270e5a
commit
fca5666382
@ -313,9 +313,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
tol = shape[0] * jnp.finfo(dtype).eps
|
||||
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
|
||||
restart=restart,
|
||||
M=M,
|
||||
qr_mode=qr_mode)
|
||||
self.assertAllClose(x, solution, atol=300*jnp.finfo(x.dtype).eps)
|
||||
M=M, qr_mode=qr_mode)
|
||||
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
|
||||
solution_tol = 1e-8 if using_x64 else 1e-4
|
||||
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
@ -345,10 +346,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
restart = shape[-1]
|
||||
tol = shape[0] * jnp.finfo(A.dtype).eps
|
||||
x, info = jax.scipy.sparse.linalg._gmres(A_mv, b, tol=tol, atol=tol,
|
||||
restart=restart,
|
||||
M=M,
|
||||
qr_mode=qr_mode)
|
||||
solution_tol = 300*jnp.finfo(x.dtype).eps
|
||||
restart=restart,
|
||||
M=M, qr_mode=qr_mode)
|
||||
using_x64 = solution.dtype.kind in {np.float64, np.complex128}
|
||||
solution_tol = 1e-8 if using_x64 else 1e-4
|
||||
self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
|
||||
|
||||
def test_gmres_pytree(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user