lax_test: adjust TPU tolerance for igamma & friends

PiperOrigin-RevId: 564859109
This commit is contained in:
Jake VanderPlas 2023-09-12 15:59:00 -07:00 committed by jax authors
parent a26125c49e
commit 56791eb9ec

View File

@ -108,7 +108,7 @@ class LaxTest(jtu.JaxTestCase):
if dtype in (np.float32, np.complex64) and op_name in (
"acosh", "asinh", "cos", "cosh", "digamma", "exp", "exp2", "igamma",
"igammac", "log", "log1p", "logistic", "pow", "sin", "sinh", "tan"):
tol = jtu.join_tolerance(tol, 1e-4)
tol = jtu.join_tolerance(tol, 2e-4)
if op_name == "asinh" and dtype == np.float16:
tol = jtu.join_tolerance(tol, 1e-3)
self._CheckAgainstNumpy(numpy_op, op, args_maker, tol=tol)