1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 23:56:06 +00:00

[XLA:CPU] Use central difference to calculate numerical gradient

PiperOrigin-RevId: 718383754
This commit is contained in:
Will Froom 2025-01-22 07:48:58 -08:00 committed by jax authors
parent e304e9ea16
commit dc16721b52

@ -92,7 +92,6 @@ class ODETest(jtu.JaxTestCase):
def decay(_np, y, t, arg1, arg2):
return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
rng = self.rng()
args = (rng.randn(3), rng.randn(3))
y0 = rng.randn(3)
@ -202,7 +201,7 @@ class ODETest(jtu.JaxTestCase):
ans = h[11], g[11]
expected_h = experiment(t[11])
expected_g = (experiment(t[11] + 1e-5) - expected_h) / 1e-5
expected_g = (experiment(t[11] + 5e-6) - experiment(t[11] - 5e-6)) / 1e-5
expected = expected_h, expected_g
self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)