Use onp instead of np in ode_test (#3288)

* Use onp instead of np in ode_test

* other ode_test.py fixes

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
samuela 2020-06-02 09:54:51 -07:00 committed by GitHub
parent dd81a8dded
commit 1eb7f1b13d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -42,7 +42,7 @@ class ODETest(jtu.JaxTestCase):
scipy_result = jnp.asarray(osp_integrate.odeint(np_fun, y0, tspace, args))
y0, tspace = jnp.array(y0), jnp.array(tspace)
jax_fun = partial(fun, np)
jax_fun = partial(fun, jnp)
jax_result = odeint(jax_fun, y0, tspace, *args)
self.assertAllClose(jax_result, scipy_result, check_dtypes=False, atol=tol, rtol=tol)
@ -51,18 +51,16 @@ class ODETest(jtu.JaxTestCase):
def test_pend_grads(self):
def pend(_np, y, _, m, g):
theta, omega = y
return [omega, -m * omega - g * jnp.sin(theta)]
return [omega, -m * omega - g * _np.sin(theta)]
integrate = partial(odeint, partial(pend, np))
y0 = [jnp.pi - 0.1, 0.0]
ts = jnp.linspace(0., 1., 11)
y0 = [np.pi - 0.1, 0.0]
ts = np.linspace(0., 1., 11)
args = (0.25, 9.8)
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
self.check_against_scipy(pend, y0, ts, *args, tol=tol)
integrate = partial(odeint, partial(pend, jnp))
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
atol=tol, rtol=tol)
@ -72,10 +70,11 @@ class ODETest(jtu.JaxTestCase):
def dynamics(y, _t):
return tree_map(jnp.negative, y)
y0 = (jnp.array(-0.1), jnp.array([[[0.1]]]))
integrate = partial(odeint, dynamics)
ts = jnp.linspace(0., 1., 11)
y0 = (np.array(-0.1), np.array([[[0.1]]]))
ts = np.linspace(0., 1., 11)
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
integrate = partial(odeint, dynamics)
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
atol=tol, rtol=tol)
@ -83,59 +82,65 @@ class ODETest(jtu.JaxTestCase):
def test_weird_time_pendulum_grads(self):
"""Test that gradients are correct when the dynamics depend on t."""
def dynamics(_np, y, t):
return jnp.array([y[1] * -t, -1 * y[1] - 9.8 * jnp.sin(y[0])])
integrate = partial(odeint, partial(dynamics, np))
y0 = [jnp.pi - 0.1, 0.0]
ts = jnp.linspace(0., 1., 11)
return _np.array([y[1] * -t, -1 * y[1] - 9.8 * _np.sin(y[0])])
y0 = [np.pi - 0.1, 0.0]
ts = np.linspace(0., 1., 11)
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
self.check_against_scipy(dynamics, y0, ts, tol=tol)
integrate = partial(odeint, partial(dynamics, jnp))
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
rtol=tol, atol=tol)
@jtu.skip_on_devices("tpu")
def test_decay(self):
def decay(_np, y, t, arg1, arg2):
return -jnp.sqrt(t) - y + arg1 - jnp.mean((y + arg2)**2)
return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
integrate = partial(odeint, partial(decay, np))
rng = np.random.RandomState(0)
args = (rng.randn(3), rng.randn(3))
y0 = rng.randn(3)
ts = jnp.linspace(0.1, 0.2, 4)
ts = np.linspace(0.1, 0.2, 4)
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
self.check_against_scipy(decay, y0, ts, *args, tol=tol)
integrate = partial(odeint, partial(decay, jnp))
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
rtol=tol, atol=tol)
@jtu.skip_on_devices("tpu")
def test_swoop(self):
def swoop(_np, y, t, arg1, arg2):
return jnp.array(y - jnp.sin(t) - jnp.cos(t) * arg1 + arg2)
return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)
integrate = partial(odeint, partial(swoop, np))
ts = jnp.array([0.1, 0.2])
ts = np.array([0.1, 0.2])
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
y0 = jnp.linspace(0.1, 0.9, 10)
y0 = np.linspace(0.1, 0.9, 10)
args = (0.1, 0.2)
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
integrate = partial(odeint, partial(swoop, jnp))
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
rtol=tol, atol=tol)
big_y0 = jnp.linspace(1.1, 10.9, 10)
@jtu.skip_on_devices("tpu")
def test_swoop_bigger(self):
def swoop(_np, y, t, arg1, arg2):
return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)
ts = np.array([0.1, 0.2])
tol = 1e-1 if num_float_bits(np.float64) == 32 else 1e-3
big_y0 = np.linspace(1.1, 10.9, 10)
args = (0.1, 0.3)
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
self.check_against_scipy(swoop, big_y0, ts, *args, tol=tol)
integrate = partial(odeint, partial(swoop, jnp))
jtu.check_grads(integrate, (big_y0, ts, *args), modes=["rev"], order=2,
rtol=tol, atol=tol)
@ -164,15 +169,15 @@ class ODETest(jtu.JaxTestCase):
ans = jax.grad(g)(2.) # don't crash
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
atol = {np.float64: 5e-15}
rtol = {np.float64: 2e-15}
atol = {jnp.float64: 5e-15}
rtol = {jnp.float64: 2e-15}
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
def test_disable_jit_odeint_with_vmap(self):
# https://github.com/google/jax/issues/2598
with jax.disable_jit():
t = jax.numpy.array([0.0, 1.0])
x0_eval = jax.numpy.zeros((5, 2))
t = jnp.array([0.0, 1.0])
x0_eval = jnp.zeros((5, 2))
f = lambda x0: odeint(lambda x, _t: x, x0, t)
jax.vmap(f)(x0_eval) # doesn't crash