mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[XLA:CPU] Use central difference to calculate numerical gradient
PiperOrigin-RevId: 718383754
This commit is contained in:
parent
e304e9ea16
commit
dc16721b52
@ -90,8 +90,7 @@ class ODETest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_decay(self):
|
||||
def decay(_np, y, t, arg1, arg2):
|
||||
return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
|
||||
|
||||
return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
|
||||
|
||||
rng = self.rng()
|
||||
args = (rng.randn(3), rng.randn(3))
|
||||
@ -202,7 +201,7 @@ class ODETest(jtu.JaxTestCase):
|
||||
ans = h[11], g[11]
|
||||
|
||||
expected_h = experiment(t[11])
|
||||
expected_g = (experiment(t[11] + 1e-5) - expected_h) / 1e-5
|
||||
expected_g = (experiment(t[11] + 5e-6) - experiment(t[11] - 5e-6)) / 1e-5
|
||||
expected = expected_h, expected_g
|
||||
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user