Relax tolerances slightly for MKL.

Fix https://github.com/google/jax/issues/9705.
This commit is contained in:
Samuel Ainsworth 2022-02-25 22:02:55 +00:00
parent 7ec04c8311
commit bf59b7d872

View File

@ -224,7 +224,7 @@ class CustomRootTest(jtu.JaxTestCase):
fwd_val, fwd_aux = fwd(a, b)
expected_fwd_val = expected_fwd(a, b)
self.assertAllClose(fwd_val, expected_fwd_val, rtol={np.float32: 1E-6, np.float64: 1E-12})
self.assertAllClose(fwd_val, expected_fwd_val, rtol={np.float32: 5E-6, np.float64: 5E-12})
jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))