mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use onp instead of np in ode_test (#3288)
* Use onp instead of np in ode_test * other ode_test.py fixes Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
dd81a8dded
commit
1eb7f1b13d
@ -42,7 +42,7 @@ class ODETest(jtu.JaxTestCase):
|
||||
scipy_result = jnp.asarray(osp_integrate.odeint(np_fun, y0, tspace, args))
|
||||
|
||||
y0, tspace = jnp.array(y0), jnp.array(tspace)
|
||||
jax_fun = partial(fun, np)
|
||||
jax_fun = partial(fun, jnp)
|
||||
jax_result = odeint(jax_fun, y0, tspace, *args)
|
||||
|
||||
self.assertAllClose(jax_result, scipy_result, check_dtypes=False, atol=tol, rtol=tol)
|
||||
@ -51,18 +51,16 @@ class ODETest(jtu.JaxTestCase):
|
||||
def test_pend_grads(self):
|
||||
def pend(_np, y, _, m, g):
|
||||
theta, omega = y
|
||||
return [omega, -m * omega - g * jnp.sin(theta)]
|
||||
return [omega, -m * omega - g * _np.sin(theta)]
|
||||
|
||||
integrate = partial(odeint, partial(pend, np))
|
||||
|
||||
y0 = [jnp.pi - 0.1, 0.0]
|
||||
ts = jnp.linspace(0., 1., 11)
|
||||
y0 = [np.pi - 0.1, 0.0]
|
||||
ts = np.linspace(0., 1., 11)
|
||||
args = (0.25, 9.8)
|
||||
|
||||
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
|
||||
|
||||
self.check_against_scipy(pend, y0, ts, *args, tol=tol)
|
||||
|
||||
integrate = partial(odeint, partial(pend, jnp))
|
||||
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@ -72,10 +70,11 @@ class ODETest(jtu.JaxTestCase):
|
||||
def dynamics(y, _t):
|
||||
return tree_map(jnp.negative, y)
|
||||
|
||||
y0 = (jnp.array(-0.1), jnp.array([[[0.1]]]))
|
||||
integrate = partial(odeint, dynamics)
|
||||
ts = jnp.linspace(0., 1., 11)
|
||||
y0 = (np.array(-0.1), np.array([[[0.1]]]))
|
||||
ts = np.linspace(0., 1., 11)
|
||||
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
|
||||
|
||||
integrate = partial(odeint, dynamics)
|
||||
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@ -83,59 +82,65 @@ class ODETest(jtu.JaxTestCase):
|
||||
def test_weird_time_pendulum_grads(self):
|
||||
"""Test that gradients are correct when the dynamics depend on t."""
|
||||
def dynamics(_np, y, t):
|
||||
return jnp.array([y[1] * -t, -1 * y[1] - 9.8 * jnp.sin(y[0])])
|
||||
|
||||
integrate = partial(odeint, partial(dynamics, np))
|
||||
|
||||
y0 = [jnp.pi - 0.1, 0.0]
|
||||
ts = jnp.linspace(0., 1., 11)
|
||||
return _np.array([y[1] * -t, -1 * y[1] - 9.8 * _np.sin(y[0])])
|
||||
|
||||
y0 = [np.pi - 0.1, 0.0]
|
||||
ts = np.linspace(0., 1., 11)
|
||||
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
|
||||
|
||||
self.check_against_scipy(dynamics, y0, ts, tol=tol)
|
||||
|
||||
integrate = partial(odeint, partial(dynamics, jnp))
|
||||
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_decay(self):
|
||||
def decay(_np, y, t, arg1, arg2):
|
||||
return -jnp.sqrt(t) - y + arg1 - jnp.mean((y + arg2)**2)
|
||||
return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
|
||||
|
||||
integrate = partial(odeint, partial(decay, np))
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
args = (rng.randn(3), rng.randn(3))
|
||||
|
||||
y0 = rng.randn(3)
|
||||
ts = jnp.linspace(0.1, 0.2, 4)
|
||||
|
||||
ts = np.linspace(0.1, 0.2, 4)
|
||||
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
|
||||
|
||||
self.check_against_scipy(decay, y0, ts, *args, tol=tol)
|
||||
|
||||
integrate = partial(odeint, partial(decay, jnp))
|
||||
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_swoop(self):
|
||||
def swoop(_np, y, t, arg1, arg2):
|
||||
return jnp.array(y - jnp.sin(t) - jnp.cos(t) * arg1 + arg2)
|
||||
return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)
|
||||
|
||||
integrate = partial(odeint, partial(swoop, np))
|
||||
|
||||
ts = jnp.array([0.1, 0.2])
|
||||
ts = np.array([0.1, 0.2])
|
||||
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
|
||||
|
||||
y0 = jnp.linspace(0.1, 0.9, 10)
|
||||
y0 = np.linspace(0.1, 0.9, 10)
|
||||
args = (0.1, 0.2)
|
||||
|
||||
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
|
||||
|
||||
integrate = partial(odeint, partial(swoop, jnp))
|
||||
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
big_y0 = jnp.linspace(1.1, 10.9, 10)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_swoop_bigger(self):
|
||||
def swoop(_np, y, t, arg1, arg2):
|
||||
return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)
|
||||
|
||||
ts = np.array([0.1, 0.2])
|
||||
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
|
||||
big_y0 = np.linspace(1.1, 10.9, 10)
|
||||
args = (0.1, 0.3)
|
||||
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
|
||||
|
||||
self.check_against_scipy(swoop, big_y0, ts, *args, tol=tol)
|
||||
|
||||
integrate = partial(odeint, partial(swoop, jnp))
|
||||
jtu.check_grads(integrate, (big_y0, ts, *args), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@ -164,15 +169,15 @@ class ODETest(jtu.JaxTestCase):
|
||||
ans = jax.grad(g)(2.) # don't crash
|
||||
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
|
||||
|
||||
atol = {np.float64: 5e-15}
|
||||
rtol = {np.float64: 2e-15}
|
||||
atol = {jnp.float64: 5e-15}
|
||||
rtol = {jnp.float64: 2e-15}
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
|
||||
|
||||
def test_disable_jit_odeint_with_vmap(self):
|
||||
# https://github.com/google/jax/issues/2598
|
||||
with jax.disable_jit():
|
||||
t = jax.numpy.array([0.0, 1.0])
|
||||
x0_eval = jax.numpy.zeros((5, 2))
|
||||
t = jnp.array([0.0, 1.0])
|
||||
x0_eval = jnp.zeros((5, 2))
|
||||
f = lambda x0: odeint(lambda x, _t: x, x0, t)
|
||||
jax.vmap(f)(x0_eval) # doesn't crash
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user