Merge pull request #18010 from hawkinsp:macfailures

PiperOrigin-RevId: 571975100
This commit is contained in:
jax authors 2023-10-09 10:26:55 -07:00
commit 1ec51f4cba

View File

@ -1662,6 +1662,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
assert b.shape == ()
return c, b
if scan is scan_with_new_checkpoint:
rtol = {np.float32: 5e-5, np.float64: 1e-13}
atol = 1e-5
elif scan is scan_with_for:
rtol = {np.float32: 2e-5, np.float64: 1e-13}
atol = {np.float32: 6e-2, np.float64: 1e-13}
else:
rtol = {np.float32: 2e-5, np.float64: 1e-13}
atol = {np.float32: 5e-5, np.float64: 1e-13}
if jit_f:
f = jax.jit(f)
if jit_scan:
@ -1672,15 +1682,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
ans = jax.grad(lambda c, as_: list( scan(f, c, as_))[0].sum())(c, as_)
expected = jax.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
if scan is scan_with_new_checkpoint:
rtol = {np.float32: 5e-5, np.float64: 1e-13}
atol = 1e-5
elif scan is scan_with_for:
rtol = {np.float32: 2e-5, np.float64: 1e-13}
atol = {np.float32: 6e-2, np.float64: 1e-13}
else:
rtol = {np.float32: 2e-5, np.float64: 1e-13}
atol = 1e-5
self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol)
rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2