refactor ode tests, add scipy benchmark (#2824)

* refactor ode tests, add scipy benchmark

remove double import

rename to scipy merge vmap test properly

* clean up more global trace state after errors

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Jacob Kelly 2020-04-28 00:53:38 -04:00 committed by GitHub
parent e6df98de55
commit cc0e9a3189
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 81 additions and 130 deletions

View File

@ -623,8 +623,10 @@ def find_top_trace(xs):
@contextmanager
def initial_style_staging():
prev, trace_state.initial_style = trace_state.initial_style, True
yield
trace_state.initial_style = prev
try:
yield
finally:
trace_state.initial_style = prev
# -------------------- abstract values --------------------

View File

@ -25,7 +25,6 @@ Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
from functools import partial
import operator as op
import time
import jax
import jax.numpy as np
@ -33,11 +32,8 @@ from jax import lax
from jax import ops
from jax.util import safe_map, safe_zip
from jax.flatten_util import ravel_pytree
from jax.test_util import check_grads
from jax.tree_util import tree_map
from jax import linear_util as lu
import numpy as onp
import scipy.integrate as osp_integrate
map = safe_map
zip = safe_zip
@ -240,69 +236,3 @@ def _odeint_rev(func, rtol, atol, mxstep, res, g):
return (y_bar, ts_bar, *args_bar)
_odeint.defvjp(_odeint_fwd, _odeint_rev)
def pend(np, y, _, m, g):
theta, omega = y
return [omega, -m * omega - g * np.sin(theta)]
def benchmark_odeint(fun, y0, tspace, *args):
"""Time performance of JAX odeint method against scipy.integrate.odeint."""
n_trials = 10
n_repeat = 100
y0, tspace = onp.array(y0), onp.array(tspace)
onp_fun = partial(fun, onp)
scipy_times = []
for k in range(n_trials):
start = time.time()
for _ in range(n_repeat):
scipy_result = osp_integrate.odeint(onp_fun, y0, tspace, args)
end = time.time()
print('scipy odeint elapsed time ({} of {}): {}'.format(k+1, n_trials, end-start))
scipy_times.append(end - start)
y0, tspace = np.array(y0), np.array(tspace)
jax_fun = partial(fun, np)
jax_times = []
for k in range(n_trials):
start = time.time()
for _ in range(n_repeat):
jax_result = odeint(jax_fun, y0, tspace, *args)
jax_result.block_until_ready()
end = time.time()
print('JAX odeint elapsed time ({} of {}): {}'.format(k+1, n_trials, end-start))
jax_times.append(end - start)
print('(avg scipy time) / (avg jax time) = {}'.format(
onp.mean(scipy_times[1:]) / onp.mean(jax_times[1:])))
print('norm(scipy result-jax result): {}'.format(
np.linalg.norm(np.asarray(scipy_result) - jax_result)))
return scipy_result, jax_result
def pend_benchmark_odeint():
_, _ = benchmark_odeint(pend, [np.pi - 0.1, 0.0], np.linspace(0., 10., 101),
0.25, 9.8)
def pend_check_grads():
def f(y0, ts, *args):
return odeint(partial(pend, np), y0, ts, *args)
y0 = [np.pi - 0.1, 0.0]
ts = np.linspace(0., 1., 11)
args = (0.25, 9.8)
check_grads(f, (y0, ts, *args), modes=["rev"], order=2,
atol=1e-1, rtol=1e-1)
def weird_time_pendulum_check_grads():
"""Test that gradients are correct when the dynamics depend on t."""
def f(y0, ts):
return odeint(lambda y, t: np.array([y[1] * -t, -1 * y[1] - 9.8 * np.sin(y[0])]), y0, ts)
y0 = [np.pi - 0.1, 0.0]
ts = np.linspace(0., 1., 11)
check_grads(f, (y0, ts), modes=["rev"], order=2)
if __name__ == '__main__':
pend_benchmark_odeint()
pend_check_grads()
weird_time_pendulum_check_grads()

View File

@ -2834,40 +2834,6 @@ class CustomVJPTest(jtu.JaxTestCase):
foo = lambda x: api.vmap(np.linalg.det, (0,))(x)
api.jit(foo)(arr) # doesn't crash
def test_odeint_vmap_grad(self):
# https://github.com/google/jax/issues/2531
# TODO(mattjj): factor out an ode tests file
try:
from jax.experimental.ode import odeint
except ImportError:
raise unittest.SkipTest("missing jax.experimental") from None
def dx_dt(x, *args):
return 0.1 * x
def f(x, y):
y0 = np.array([x, y])
t = np.array([0., 5.])
y = odeint(dx_dt, y0, t)
return y[-1].sum()
def g(x):
# Two initial values for the ODE
y0_arr = np.array([[x, 0.1],
[x, 0.2]])
# Run ODE twice
t = np.array([0., 5.])
y = jax.vmap(lambda y0: odeint(dx_dt, y0, t))(y0_arr)
return y[:,-1].sum()
ans = jax.grad(g)(2.) # don't crash
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
atol = {onp.float64: 5e-15}
rtol = {onp.float64: 2e-15}
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
def test_lowering_out_of_traces(self):
# https://github.com/google/jax/issues/2578

View File

@ -13,16 +13,19 @@
# limitations under the License.
from functools import partial
import unittest
from absl.testing import absltest
import numpy as onp
from jax import jit
import jax
from jax import dtypes
from jax import test_util as jtu
import jax.numpy as np
from jax.experimental.ode import odeint
import scipy.integrate as osp_integrate
from jax.config import config
config.parse_flags_with_absl()
@ -31,77 +34,127 @@ def num_float_bits(dtype):
return dtypes.finfo(dtypes.canonicalize_dtype(dtype)).bits
class JetTest(jtu.JaxTestCase):
class ODETest(jtu.JaxTestCase):
def check_against_scipy(self, fun, y0, tspace, *args, tol=1e-1):
y0, tspace = onp.array(y0), onp.array(tspace)
onp_fun = partial(fun, onp)
scipy_result = np.asarray(osp_integrate.odeint(onp_fun, y0, tspace, args))
y0, tspace = np.array(y0), np.array(tspace)
jax_fun = partial(fun, np)
jax_result = odeint(jax_fun, y0, tspace, *args)
self.assertAllClose(jax_result, scipy_result, check_dtypes=False, atol=tol, rtol=tol)
@jtu.skip_on_devices("tpu")
def test_pend_grads(self):
def pend(y, _, m, g):
def pend(_np, y, _, m, g):
theta, omega = y
return [omega, -m * omega - g * np.sin(theta)]
return [omega, -m * omega - g * _np.sin(theta)]
def f(y0, ts, *args):
return odeint(pend, y0, ts, *args)
integrate = partial(odeint, partial(pend, np))
y0 = [np.pi - 0.1, 0.0]
ts = np.linspace(0., 1., 11)
args = (0.25, 9.8)
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
jtu.check_grads(f, (y0, ts, *args), modes=["rev"], order=2,
self.check_against_scipy(pend, y0, ts, *args, tol=tol)
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
atol=tol, rtol=tol)
@jtu.skip_on_devices("tpu")
def test_weird_time_pendulum_grads(self):
"""Test that gradients are correct when the dynamics depend on t."""
def dynamics(y, t):
return np.array([y[1] * -t, -1 * y[1] - 9.8 * np.sin(y[0])])
def dynamics(_np, y, t):
return _np.array([y[1] * -t, -1 * y[1] - 9.8 * _np.sin(y[0])])
integrate = partial(odeint, dynamics)
integrate = partial(odeint, partial(dynamics, np))
y0 = [np.pi - 0.1, 0.0]
ts = np.linspace(0., 1., 11)
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
self.check_against_scipy(dynamics, y0, ts, tol=tol)
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
rtol=tol, atol=tol)
@jtu.skip_on_devices("tpu")
def test_decay(self):
def decay(y, t, arg1, arg2):
return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
def decay(_np, y, t, arg1, arg2):
return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
integrate = partial(odeint, partial(decay, np))
rng = onp.random.RandomState(0)
arg1 = rng.randn(3)
arg2 = rng.randn(3)
def integrate(y0, ts):
return odeint(decay, y0, ts, arg1, arg2)
args = (rng.randn(3), rng.randn(3))
y0 = rng.randn(3)
ts = np.linspace(0.1, 0.2, 4)
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
self.check_against_scipy(decay, y0, ts, *args, tol=tol)
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
rtol=tol, atol=tol)
@jtu.skip_on_devices("tpu")
def test_swoop(self):
def swoop(y, t, arg1, arg2):
return np.array(y - np.sin(t) - np.cos(t) * arg1 + arg2)
def swoop(_np, y, t, arg1, arg2):
return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)
integrate = partial(odeint, partial(swoop, np))
ts = np.array([0.1, 0.2])
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
y0 = np.linspace(0.1, 0.9, 10)
integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.2)
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
args = (0.1, 0.2)
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
rtol=tol, atol=tol)
big_y0 = np.linspace(1.1, 10.9, 10)
integrate = lambda y0, ts: odeint(swoop, big_y0, ts, 0.1, 0.3)
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
args = (0.1, 0.3)
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
jtu.check_grads(integrate, (big_y0, ts, *args), modes=["rev"], order=2,
rtol=tol, atol=tol)
def test_odeint_vmap_grad(self):
# https://github.com/google/jax/issues/2531
def dx_dt(x, *args):
return 0.1 * x
def f(x, y):
y0 = np.array([x, y])
t = np.array([0., 5.])
y = odeint(dx_dt, y0, t)
return y[-1].sum()
def g(x):
# Two initial values for the ODE
y0_arr = np.array([[x, 0.1],
[x, 0.2]])
# Run ODE twice
t = np.array([0., 5.])
y = jax.vmap(lambda y0: odeint(dx_dt, y0, t))(y0_arr)
return y[:,-1].sum()
ans = jax.grad(g)(2.) # don't crash
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
atol = {onp.float64: 5e-15}
rtol = {onp.float64: 2e-15}
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
if __name__ == '__main__':
absltest.main()