[XLA:CPU] Use central difference to calculate numerical gradient

PiperOrigin-RevId: 718383754
Will Froom 2025-01-22 07:48:58 -08:00 committed by jax authors
parent e304e9ea16
commit dc16721b52


@@ -90,8 +90,7 @@ class ODETest(jtu.JaxTestCase):
   @jtu.skip_on_devices("tpu", "gpu")
   def test_decay(self):
     def decay(_np, y, t, arg1, arg2):
-      return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
       return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
     rng = self.rng()
     args = (rng.randn(3), rng.randn(3))
@@ -202,7 +201,7 @@ class ODETest(jtu.JaxTestCase):
     ans = h[11], g[11]
     expected_h = experiment(t[11])
-    expected_g = (experiment(t[11] + 1e-5) - expected_h) / 1e-5
+    expected_g = (experiment(t[11] + 5e-6) - experiment(t[11] - 5e-6)) / 1e-5
     expected = expected_h, expected_g
     self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)
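
For reference, a minimal sketch (illustration only, not part of the commit) of the two finite-difference schemes on a generic scalar function: the forward difference has O(h) truncation error, while the central difference used in the updated test has O(h**2), so it approximates the gradient more tightly for the same step size. The function and point below are arbitrary examples.

# Minimal sketch: forward vs. central difference for a numerical gradient check.
import numpy as np

def forward_diff(f, x, h=1e-5):
  # (f(x + h) - f(x)) / h: first-order accurate, O(h) truncation error.
  return (f(x + h) - f(x)) / h

def central_diff(f, x, h=1e-5):
  # (f(x + h/2) - f(x - h/2)) / h: second-order accurate, O(h**2) error,
  # same form as the new expected_g line above (offsets of 5e-6, divided by 1e-5).
  return (f(x + h / 2) - f(x - h / 2)) / h

f = np.sin
x = 1.3
exact = np.cos(x)
print(abs(forward_diff(f, x) - exact))   # roughly 5e-6
print(abs(central_diff(f, x) - exact))   # several orders of magnitude smaller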