mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #18010 from hawkinsp:macfailures
PiperOrigin-RevId: 571975100
This commit is contained in:
commit
1ec51f4cba
@ -1662,6 +1662,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
assert b.shape == ()
|
||||
return c, b
|
||||
|
||||
if scan is scan_with_new_checkpoint:
|
||||
rtol = {np.float32: 5e-5, np.float64: 1e-13}
|
||||
atol = 1e-5
|
||||
elif scan is scan_with_for:
|
||||
rtol = {np.float32: 2e-5, np.float64: 1e-13}
|
||||
atol = {np.float32: 6e-2, np.float64: 1e-13}
|
||||
else:
|
||||
rtol = {np.float32: 2e-5, np.float64: 1e-13}
|
||||
atol = {np.float32: 5e-5, np.float64: 1e-13}
|
||||
|
||||
if jit_f:
|
||||
f = jax.jit(f)
|
||||
if jit_scan:
|
||||
@ -1672,15 +1682,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
ans = jax.grad(lambda c, as_: list( scan(f, c, as_))[0].sum())(c, as_)
|
||||
expected = jax.grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
|
||||
if scan is scan_with_new_checkpoint:
|
||||
rtol = {np.float32: 5e-5, np.float64: 1e-13}
|
||||
atol = 1e-5
|
||||
elif scan is scan_with_for:
|
||||
rtol = {np.float32: 2e-5, np.float64: 1e-13}
|
||||
atol = {np.float32: 6e-2, np.float64: 1e-13}
|
||||
else:
|
||||
rtol = {np.float32: 2e-5, np.float64: 1e-13}
|
||||
atol = 1e-5
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol, atol=atol)
|
||||
|
||||
rtol = 5e-3 if scan is not scan_with_new_checkpoint2 else 5e-2
|
||||
|
Loading…
x
Reference in New Issue
Block a user