mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
refactor ode tests, add scipy benchmark (#2824)
* refactor ode tests, add scipy benchmark remove double import rename to scipy merge vmap test properly * clean up more global trace state after errors Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
e6df98de55
commit
cc0e9a3189
@ -623,8 +623,10 @@ def find_top_trace(xs):
|
||||
@contextmanager
|
||||
def initial_style_staging():
|
||||
prev, trace_state.initial_style = trace_state.initial_style, True
|
||||
yield
|
||||
trace_state.initial_style = prev
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
trace_state.initial_style = prev
|
||||
|
||||
|
||||
# -------------------- abstract values --------------------
|
||||
|
@ -25,7 +25,6 @@ Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
|
||||
|
||||
from functools import partial
|
||||
import operator as op
|
||||
import time
|
||||
|
||||
import jax
|
||||
import jax.numpy as np
|
||||
@ -33,11 +32,8 @@ from jax import lax
|
||||
from jax import ops
|
||||
from jax.util import safe_map, safe_zip
|
||||
from jax.flatten_util import ravel_pytree
|
||||
from jax.test_util import check_grads
|
||||
from jax.tree_util import tree_map
|
||||
from jax import linear_util as lu
|
||||
import numpy as onp
|
||||
import scipy.integrate as osp_integrate
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
@ -240,69 +236,3 @@ def _odeint_rev(func, rtol, atol, mxstep, res, g):
|
||||
return (y_bar, ts_bar, *args_bar)
|
||||
|
||||
_odeint.defvjp(_odeint_fwd, _odeint_rev)
|
||||
|
||||
|
||||
def pend(np, y, _, m, g):
|
||||
theta, omega = y
|
||||
return [omega, -m * omega - g * np.sin(theta)]
|
||||
|
||||
def benchmark_odeint(fun, y0, tspace, *args):
|
||||
"""Time performance of JAX odeint method against scipy.integrate.odeint."""
|
||||
n_trials = 10
|
||||
n_repeat = 100
|
||||
y0, tspace = onp.array(y0), onp.array(tspace)
|
||||
onp_fun = partial(fun, onp)
|
||||
scipy_times = []
|
||||
for k in range(n_trials):
|
||||
start = time.time()
|
||||
for _ in range(n_repeat):
|
||||
scipy_result = osp_integrate.odeint(onp_fun, y0, tspace, args)
|
||||
end = time.time()
|
||||
print('scipy odeint elapsed time ({} of {}): {}'.format(k+1, n_trials, end-start))
|
||||
scipy_times.append(end - start)
|
||||
y0, tspace = np.array(y0), np.array(tspace)
|
||||
jax_fun = partial(fun, np)
|
||||
jax_times = []
|
||||
for k in range(n_trials):
|
||||
start = time.time()
|
||||
for _ in range(n_repeat):
|
||||
jax_result = odeint(jax_fun, y0, tspace, *args)
|
||||
jax_result.block_until_ready()
|
||||
end = time.time()
|
||||
print('JAX odeint elapsed time ({} of {}): {}'.format(k+1, n_trials, end-start))
|
||||
jax_times.append(end - start)
|
||||
print('(avg scipy time) / (avg jax time) = {}'.format(
|
||||
onp.mean(scipy_times[1:]) / onp.mean(jax_times[1:])))
|
||||
print('norm(scipy result-jax result): {}'.format(
|
||||
np.linalg.norm(np.asarray(scipy_result) - jax_result)))
|
||||
return scipy_result, jax_result
|
||||
|
||||
def pend_benchmark_odeint():
|
||||
_, _ = benchmark_odeint(pend, [np.pi - 0.1, 0.0], np.linspace(0., 10., 101),
|
||||
0.25, 9.8)
|
||||
|
||||
def pend_check_grads():
|
||||
def f(y0, ts, *args):
|
||||
return odeint(partial(pend, np), y0, ts, *args)
|
||||
|
||||
y0 = [np.pi - 0.1, 0.0]
|
||||
ts = np.linspace(0., 1., 11)
|
||||
args = (0.25, 9.8)
|
||||
|
||||
check_grads(f, (y0, ts, *args), modes=["rev"], order=2,
|
||||
atol=1e-1, rtol=1e-1)
|
||||
|
||||
def weird_time_pendulum_check_grads():
|
||||
"""Test that gradients are correct when the dynamics depend on t."""
|
||||
def f(y0, ts):
|
||||
return odeint(lambda y, t: np.array([y[1] * -t, -1 * y[1] - 9.8 * np.sin(y[0])]), y0, ts)
|
||||
|
||||
y0 = [np.pi - 0.1, 0.0]
|
||||
ts = np.linspace(0., 1., 11)
|
||||
|
||||
check_grads(f, (y0, ts), modes=["rev"], order=2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
pend_benchmark_odeint()
|
||||
pend_check_grads()
|
||||
weird_time_pendulum_check_grads()
|
||||
|
@ -2834,40 +2834,6 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
foo = lambda x: api.vmap(np.linalg.det, (0,))(x)
|
||||
api.jit(foo)(arr) # doesn't crash
|
||||
|
||||
def test_odeint_vmap_grad(self):
|
||||
# https://github.com/google/jax/issues/2531
|
||||
# TODO(mattjj): factor out an ode tests file
|
||||
try:
|
||||
from jax.experimental.ode import odeint
|
||||
except ImportError:
|
||||
raise unittest.SkipTest("missing jax.experimental") from None
|
||||
|
||||
def dx_dt(x, *args):
|
||||
return 0.1 * x
|
||||
|
||||
def f(x, y):
|
||||
y0 = np.array([x, y])
|
||||
t = np.array([0., 5.])
|
||||
y = odeint(dx_dt, y0, t)
|
||||
return y[-1].sum()
|
||||
|
||||
def g(x):
|
||||
# Two initial values for the ODE
|
||||
y0_arr = np.array([[x, 0.1],
|
||||
[x, 0.2]])
|
||||
|
||||
# Run ODE twice
|
||||
t = np.array([0., 5.])
|
||||
y = jax.vmap(lambda y0: odeint(dx_dt, y0, t))(y0_arr)
|
||||
return y[:,-1].sum()
|
||||
|
||||
ans = jax.grad(g)(2.) # don't crash
|
||||
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
|
||||
|
||||
atol = {onp.float64: 5e-15}
|
||||
rtol = {onp.float64: 2e-15}
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
|
||||
|
||||
def test_lowering_out_of_traces(self):
|
||||
# https://github.com/google/jax/issues/2578
|
||||
|
||||
|
@ -13,16 +13,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as onp
|
||||
|
||||
from jax import jit
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import test_util as jtu
|
||||
import jax.numpy as np
|
||||
from jax.experimental.ode import odeint
|
||||
|
||||
import scipy.integrate as osp_integrate
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -31,77 +34,127 @@ def num_float_bits(dtype):
|
||||
return dtypes.finfo(dtypes.canonicalize_dtype(dtype)).bits
|
||||
|
||||
|
||||
class JetTest(jtu.JaxTestCase):
|
||||
class ODETest(jtu.JaxTestCase):
|
||||
|
||||
def check_against_scipy(self, fun, y0, tspace, *args, tol=1e-1):
|
||||
y0, tspace = onp.array(y0), onp.array(tspace)
|
||||
onp_fun = partial(fun, onp)
|
||||
scipy_result = np.asarray(osp_integrate.odeint(onp_fun, y0, tspace, args))
|
||||
|
||||
y0, tspace = np.array(y0), np.array(tspace)
|
||||
jax_fun = partial(fun, np)
|
||||
jax_result = odeint(jax_fun, y0, tspace, *args)
|
||||
|
||||
self.assertAllClose(jax_result, scipy_result, check_dtypes=False, atol=tol, rtol=tol)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_pend_grads(self):
|
||||
def pend(y, _, m, g):
|
||||
def pend(_np, y, _, m, g):
|
||||
theta, omega = y
|
||||
return [omega, -m * omega - g * np.sin(theta)]
|
||||
return [omega, -m * omega - g * _np.sin(theta)]
|
||||
|
||||
def f(y0, ts, *args):
|
||||
return odeint(pend, y0, ts, *args)
|
||||
integrate = partial(odeint, partial(pend, np))
|
||||
|
||||
y0 = [np.pi - 0.1, 0.0]
|
||||
ts = np.linspace(0., 1., 11)
|
||||
args = (0.25, 9.8)
|
||||
|
||||
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
|
||||
jtu.check_grads(f, (y0, ts, *args), modes=["rev"], order=2,
|
||||
|
||||
self.check_against_scipy(pend, y0, ts, *args, tol=tol)
|
||||
|
||||
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_weird_time_pendulum_grads(self):
|
||||
"""Test that gradients are correct when the dynamics depend on t."""
|
||||
def dynamics(y, t):
|
||||
return np.array([y[1] * -t, -1 * y[1] - 9.8 * np.sin(y[0])])
|
||||
def dynamics(_np, y, t):
|
||||
return _np.array([y[1] * -t, -1 * y[1] - 9.8 * _np.sin(y[0])])
|
||||
|
||||
integrate = partial(odeint, dynamics)
|
||||
integrate = partial(odeint, partial(dynamics, np))
|
||||
|
||||
y0 = [np.pi - 0.1, 0.0]
|
||||
ts = np.linspace(0., 1., 11)
|
||||
|
||||
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
|
||||
|
||||
self.check_against_scipy(dynamics, y0, ts, tol=tol)
|
||||
|
||||
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_decay(self):
|
||||
def decay(y, t, arg1, arg2):
|
||||
return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)
|
||||
def decay(_np, y, t, arg1, arg2):
|
||||
return -_np.sqrt(t) - y + arg1 - _np.mean((y + arg2)**2)
|
||||
|
||||
integrate = partial(odeint, partial(decay, np))
|
||||
|
||||
rng = onp.random.RandomState(0)
|
||||
arg1 = rng.randn(3)
|
||||
arg2 = rng.randn(3)
|
||||
|
||||
def integrate(y0, ts):
|
||||
return odeint(decay, y0, ts, arg1, arg2)
|
||||
args = (rng.randn(3), rng.randn(3))
|
||||
|
||||
y0 = rng.randn(3)
|
||||
ts = np.linspace(0.1, 0.2, 4)
|
||||
|
||||
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
|
||||
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
|
||||
|
||||
self.check_against_scipy(decay, y0, ts, *args, tol=tol)
|
||||
|
||||
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_swoop(self):
|
||||
def swoop(y, t, arg1, arg2):
|
||||
return np.array(y - np.sin(t) - np.cos(t) * arg1 + arg2)
|
||||
def swoop(_np, y, t, arg1, arg2):
|
||||
return _np.array(y - _np.sin(t) - _np.cos(t) * arg1 + arg2)
|
||||
|
||||
integrate = partial(odeint, partial(swoop, np))
|
||||
|
||||
ts = np.array([0.1, 0.2])
|
||||
tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
|
||||
|
||||
y0 = np.linspace(0.1, 0.9, 10)
|
||||
integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.2)
|
||||
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
|
||||
args = (0.1, 0.2)
|
||||
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
|
||||
jtu.check_grads(integrate, (y0, ts, *args), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
big_y0 = np.linspace(1.1, 10.9, 10)
|
||||
integrate = lambda y0, ts: odeint(swoop, big_y0, ts, 0.1, 0.3)
|
||||
jtu.check_grads(integrate, (y0, ts), modes=["rev"], order=2,
|
||||
args = (0.1, 0.3)
|
||||
self.check_against_scipy(swoop, y0, ts, *args, tol=tol)
|
||||
jtu.check_grads(integrate, (big_y0, ts, *args), modes=["rev"], order=2,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
def test_odeint_vmap_grad(self):
|
||||
# https://github.com/google/jax/issues/2531
|
||||
|
||||
def dx_dt(x, *args):
|
||||
return 0.1 * x
|
||||
|
||||
def f(x, y):
|
||||
y0 = np.array([x, y])
|
||||
t = np.array([0., 5.])
|
||||
y = odeint(dx_dt, y0, t)
|
||||
return y[-1].sum()
|
||||
|
||||
def g(x):
|
||||
# Two initial values for the ODE
|
||||
y0_arr = np.array([[x, 0.1],
|
||||
[x, 0.2]])
|
||||
|
||||
# Run ODE twice
|
||||
t = np.array([0., 5.])
|
||||
y = jax.vmap(lambda y0: odeint(dx_dt, y0, t))(y0_arr)
|
||||
return y[:,-1].sum()
|
||||
|
||||
ans = jax.grad(g)(2.) # don't crash
|
||||
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
|
||||
|
||||
atol = {onp.float64: 5e-15}
|
||||
rtol = {onp.float64: 2e-15}
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user