diff --git a/tests/lax_test.py b/tests/lax_test.py index 513a46bd3..141f5e226 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -108,7 +108,7 @@ class LaxTest(jtu.JaxTestCase): if dtype in (np.float32, np.complex64) and op_name in ( "acosh", "asinh", "cos", "cosh", "digamma", "exp", "exp2", "igamma", "igammac", "log", "log1p", "logistic", "pow", "sin", "sinh", "tan"): - tol = jtu.join_tolerance(tol, 1e-4) + tol = jtu.join_tolerance(tol, 2e-4) if op_name == "asinh" and dtype == np.float16: tol = jtu.join_tolerance(tol, 1e-3) self._CheckAgainstNumpy(numpy_op, op, args_maker, tol=tol)