diff --git a/tests/ode_test.py b/tests/ode_test.py index acdfa1fc6..4e9224d33 100644 --- a/tests/ode_test.py +++ b/tests/ode_test.py @@ -90,8 +90,7 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_decay(self): def decay(_np, y, t, arg1, arg2): - return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2) - + return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2) rng = self.rng() args = (rng.randn(3), rng.randn(3)) @@ -202,7 +201,7 @@ class ODETest(jtu.JaxTestCase): ans = h[11], g[11] expected_h = experiment(t[11]) - expected_g = (experiment(t[11] + 1e-5) - expected_h) / 1e-5 + expected_g = (experiment(t[11] + 5e-6) - experiment(t[11] - 5e-6)) / 1e-5 expected = expected_h, expected_g self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)