mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 13:46:08 +00:00
630 lines
20 KiB
Python
630 lines
20 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""JAX-based Dormand-Prince ODE integration with adaptive stepsize.
|
|
|
|
Integrate systems of ordinary differential equations (ODEs) using the JAX
|
|
autograd/diff library and the Dormand-Prince method for adaptive integration
|
|
stepsize calculation. Provides improved integration accuracy over fixed
|
|
stepsize integration methods.
|
|
|
|
Adjoint algorithm based on Appendix C of https://arxiv.org/pdf/1806.07366.pdf
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
import time
|
|
|
|
import jax
|
|
from jax.flatten_util import ravel_pytree
|
|
import jax.lax
|
|
import jax.numpy as np
|
|
import jax.ops
|
|
from jax.test_util import check_vjp
|
|
import matplotlib.pyplot as plt
|
|
import numpy as onp
|
|
import scipy.integrate as osp_integrate
|
|
|
|
|
|
# Dopri5 Butcher tableaux.
#
# `alpha[i]` is the fraction of the step at which stage `i + 1` is evaluated
# and `beta[i]` holds the weights applied to the stage derivatives when
# building the state for stage `i + 1` (see `runge_kutta_step`).
alpha = np.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
beta = np.array([
    [1 / 5, 0, 0, 0, 0, 0, 0],
    [3 / 40, 9 / 40, 0, 0, 0, 0, 0],
    [44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
    [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0],
])

# Weights combining the stage derivatives into the step's solution.
c_sol = np.array([
    35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0,
])

# Weights combining the stage derivatives into the local error estimate.
c_error = np.array([
    35 / 384 - 1951 / 21600,
    0,
    500 / 1113 - 22642 / 50085,
    125 / 192 - 451 / 720,
    -2187 / 6784 - -12231 / 42400,
    11 / 84 - 649 / 6300,
    -1. / 60.,
])

# Weights giving the state at the middle of the step (used for
# interpolation).  Halving is exact in binary floating point, so the common
# factor is applied once to the whole array.
dps_c_mid = np.array([
    6025192743 / 30085553152,
    0,
    51252292925 / 65400821598,
    -2691868925 / 45128329728,
    187940372067 / 1594534317056,
    -1776094331 / 19743644256,
    11237099 / 235043384,
]) / 2
|
|
|
|
|
|
@jax.jit
def interp_fit_dopri(y0, y1, k, dt):
  """Fit an interpolating polynomial to the result of a Runge-Kutta step.

  Args:
    y0: state at the start of the step.
    y1: state at the end of the step.
    k: stage derivatives of the step; `k[0]` and `k[-1]` are used as the
      derivatives at the start and end of the step.
    dt: width of the step.

  Returns:
    Array stacking the coefficients returned by `fit_4th_order_polynomial`.
  """
  # Estimate the state at the middle of the step from the stage derivatives.
  midpoint = y0 + dt * np.dot(dps_c_mid, k)
  coefficients = fit_4th_order_polynomial(y0, y1, midpoint, k[0], k[-1], dt)
  return np.array(coefficients)
|
|
|
|
|
|
@jax.jit
def fit_4th_order_polynomial(y0, y1, y_mid, dy0, dy1, dt):
  """Fit a quartic polynomial through the end points and mid-point of a step.

  Args:
    y0: function value at the start of the interval.
    y1: function value at the end of the interval.
    y_mid: function value at the mid-point of the interval.
    dy0: derivative value at the start of the interval.
    dy1: derivative value at the end of the interval.
    dt: width of the interval.

  Returns:
    Tuple of coefficients `(a, b, c, d, e)` for the polynomial
      p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
  """
  def _weights(w_dy0, w_dy1, value_weights):
    # Weights for [dy0, dy1, y0, y1, y_mid]; the derivative weights scale
    # with the interval width.
    return np.hstack([w_dy0 * dt, w_dy1 * dt, np.array(value_weights)])

  samples = np.stack([dy0, dy1, y0, y1, y_mid])
  a = np.dot(_weights(-2., 2., [-8., -8., 16.]), samples)
  b = np.dot(_weights(5., -3., [18., 14., -32.]), samples)
  c = np.dot(_weights(-4., 1., [-11., -5., 16.]), samples)
  return a, b, c, dt * dy0, y0
|
|
|
|
|
|
@functools.partial(jax.jit, static_argnums=(0,))
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
  """Empirically choose initial step size.

  Args:
    fun: Function to evaluate like `fun(y, t)` to compute the time
      derivative of `y`.
    t0: initial time.
    y0: initial value for the state.
    order: order of interpolation
    rtol: relative local error tolerance for solver.
    atol: absolute local error tolerance for solver.
    f0: initial value for the derivative, computed from `fun(y0, t0)`.

  Returns:
    Initial step size for odeint algorithm.

  Algorithm from:
  E. Hairer, S. P. Norsett G. Wanner,
  Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
  """
  scale = atol + np.abs(y0) * rtol
  d0 = np.linalg.norm(y0 / scale)
  d1 = np.linalg.norm(f0 / scale)
  order_pow = 1. / (order + 1.)

  # First guess; fall back to a tiny step when the state or its derivative
  # is negligible relative to the tolerance scale.
  h0 = np.where(np.logical_or(d0 < 1e-5, d1 < 1e-5),
                1e-6,
                0.01 * d0 / d1)

  # Take one explicit Euler step to estimate the second derivative.
  y1 = y0 + h0 * f0
  f1 = fun(y1, t0 + h0)
  d2 = np.linalg.norm((f1 - f0) / scale) / h0

  # The reference algorithm uses the larger of the first- and
  # second-derivative norms, not their sum.
  h1 = np.where(np.logical_and(d1 <= 1e-15, d2 <= 1e-15),
                np.maximum(1e-6, h0 * 1e-3),
                (0.01 / np.maximum(d1, d2))**order_pow)

  return np.minimum(100. * h0, h1)
|
|
|
|
|
|
@functools.partial(jax.jit, static_argnums=(0,))
def runge_kutta_step(func, y0, f0, t0, dt):
  """Take one Runge-Kutta step with the module's tableau and estimate error.

  The step is described by the module-level arrays `alpha`, `beta`, `c_sol`
  and `c_error`.

  Args:
    func: Function to evaluate like `func(y, t)` to compute the time
      derivative of `y`.
    y0: initial value for the state.
    f0: initial value for the derivative, computed from `func(y0, t0)`.
    t0: initial time.
    dt: time step.

  Returns:
    y1: estimated function at t1 = t0 + dt
    f1: derivative of the state at t1
    y1_error: estimated error at t1
    k: array with one row per stage holding the Runge-Kutta derivatives
      used for calculating these terms.
  """
  def _fill_stage(stage, stage_derivs):
    # Stage `stage` uses row `stage - 1` of the tableau.
    row = stage - 1
    t_stage = t0 + dt * alpha[row]
    y_stage = y0 + dt * np.dot(beta[row, :], stage_derivs)
    f_stage = func(y_stage, t_stage)
    return jax.ops.index_update(stage_derivs, jax.ops.index[stage, :], f_stage)

  # Row 0 is the derivative at the start of the step; the loop fills rows 1-6.
  initial_derivs = jax.ops.index_update(
      np.zeros((7, f0.shape[0])), jax.ops.index[0, :], f0)
  k = jax.lax.fori_loop(1, 7, _fill_stage, initial_derivs)

  y1 = dt * np.dot(c_sol, k) + y0
  y1_error = dt * np.dot(c_error, k)
  f1 = k[-1]
  return y1, f1, y1_error, k
|
|
|
|
|
|
@jax.jit
def error_ratio(error_estimate, rtol, atol, y0, y1):
  """Return the mean squared ratio of the error estimate to its tolerance.

  Args:
    error_estimate: estimated local error of the step.
    rtol: relative local error tolerance.
    atol: absolute local error tolerance.
    y0: state at the start of the step.
    y1: state at the end of the step.

  Returns:
    Mean of `(error_estimate / tolerance) ** 2`, where the tolerance is
    `atol + rtol * max(|y0|, |y1|)` taken elementwise.
  """
  largest_magnitude = np.maximum(np.abs(y0), np.abs(y1))
  tolerance = atol + rtol * largest_magnitude
  scaled_error = error_estimate / tolerance
  return np.mean(scaled_error**2)
|
|
|
|
|
|
@jax.jit
def optimal_step_size(last_step,
                      mean_error_ratio,
                      safety=0.9,
                      ifactor=10.0,
                      dfactor=0.2,
                      order=5.0):
  """Compute the next Runge-Kutta step size from the last error ratio.

  Args:
    last_step: size of the step that was just attempted.
    mean_error_ratio: mean squared error ratio of that step (see
      `error_ratio`); its maximum is used.
    safety: safety factor applied to the error-based scaling.
    ifactor: largest factor by which the step may grow.
    dfactor: smallest factor by which the step may shrink.
    order: order used as the root of the error ratio.

  Returns:
    The new step size.
  """
  worst_ratio = np.max(mean_error_ratio)

  # When the error ratio is below one the step is never shrunk.
  shrink_limit = np.where(worst_ratio < 1, 1.0, dfactor)

  root_ratio = np.sqrt(worst_ratio)
  capped = np.minimum(root_ratio**(1.0 / order) / safety, 1.0 / shrink_limit)
  factor = np.maximum(1.0 / ifactor, capped)

  # A zero error ratio grows the step by the maximum factor.
  return np.where(worst_ratio == 0, last_step * ifactor, last_step / factor)
|
|
|
|
|
|
@functools.partial(jax.jit, static_argnums=(0,))
def odeint(ofunc, y0, t, *args, **kwargs):
  """Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.

  Args:
    ofunc: Function to evaluate `yt = ofunc(y, t, *args)` that
      returns the time derivative of `y`.
    y0: initial value for the state.
    t: Timespan for `ofunc` evaluation like `np.linspace(0., 10., 101)`.
    *args: Additional arguments to `ofunc` beyond y0 and t.
    **kwargs: Two relevant keyword arguments:
      'rtol': Relative local error tolerance for solver.
      'atol': Absolute local error tolerance for solver.

  Returns:
    Integrated system values at each timepoint.
  """
  rtol = kwargs.get('rtol', 1.4e-8)
  atol = kwargs.get('atol', 1.4e-8)

  @functools.partial(jax.jit, static_argnums=(0,))
  def _attempt_step(func, state):
    """while_loop body: try one adaptive step and update the step size."""
    cur_y, cur_f, cur_t, dt, last_t, interp_coeff = state
    next_t = cur_t + dt
    next_y, next_f, next_y_error, k = runge_kutta_step(
        func, cur_y, cur_f, cur_t, dt)
    error_ratios = error_ratio(next_y_error, rtol, atol, cur_y, next_y)
    new_interp_coeff = interp_fit_dopri(cur_y, next_y, k, dt)
    dt = optimal_step_size(dt, error_ratios)

    # Flatten both candidate states so one `where` selects between them.
    # The updated step size is carried forward in either case.
    accepted_flat, unravel = ravel_pytree(
        (next_y, next_f, next_t, dt, cur_t, new_interp_coeff))
    rejected_flat, _ = ravel_pytree(
        (cur_y, cur_f, cur_t, dt, last_t, interp_coeff))
    step_accepted = np.all(error_ratios <= 1.)
    return unravel(np.where(step_accepted, accepted_flat, rejected_flat))

  @functools.partial(jax.jit, static_argnums=(0,))
  def _advance_to_output(func, i, carry):
    """fori_loop body: integrate up to `t[i]` and interpolate the solution."""
    t, cur_y, cur_f, cur_t, dt, last_t, interp_coeff, solution = carry

    # Step until the integrator has reached the requested output time.
    cur_y, cur_f, cur_t, dt, last_t, interp_coeff = jax.lax.while_loop(
        lambda state: state[2] < t[i],
        functools.partial(_attempt_step, func),
        (cur_y, cur_f, cur_t, dt, last_t, interp_coeff))

    # Evaluate the last step's interpolant at the output time.
    relative_output_time = (t[i] - last_t) / (cur_t - last_t)
    out_x = np.polyval(interp_coeff, relative_output_time)
    solution = jax.ops.index_update(solution, jax.ops.index[i, :], out_x)
    return (t, cur_y, cur_f, cur_t, dt, last_t, interp_coeff, solution)

  func = lambda y, t: ofunc(y, t, *args)
  f0 = func(y0, t[0])
  dt = initial_step_size(func, t[0], y0, 4, rtol, atol, f0)
  interp_coeff = np.array([y0] * 5)

  # The first output row is the initial condition; the loop fills the rest.
  solution = jax.ops.index_update(
      np.zeros((t.shape[0], y0.shape[0])), jax.ops.index[0, :], y0)
  initial_carry = (t, y0, f0, t[0], dt, t[0], interp_coeff, solution)
  final_carry = jax.lax.fori_loop(
      1, t.shape[0], functools.partial(_advance_to_output, func),
      initial_carry)
  return final_carry[-1]
|
|
|
|
|
|
def vjp_odeint(ofunc, y0, t, *args, **kwargs):
  """Integrate the ODE and return the solution together with its VJP function.

  Args:
    ofunc: Function `ydot = ofunc(y, t, *args)` to compute the time
      derivative of `y`.
    y0: initial value for the state.
    t: Timespan for `ofunc` evaluation like `np.linspace(0., 10., 101)`.
    *args: Additional arguments to `ofunc` beyond y0 and t.
    **kwargs: Two relevant keyword arguments:
      'rtol': Relative local error tolerance for solver.
      'atol': Absolute local error tolerance for solver.
      Both are applied to the forward solve and to the backward solves.

  Returns:
    A pair `(yt, vjp_fun)`. `yt` is the result of
    `odeint(ofunc, y0, t, *args)`. `vjp_fun(g)` takes `g` shaped like `yt`
    and returns a tuple holding the VJP with respect to `y0`, the VJP with
    respect to `t`, and then one entry per element of the flattened `args`.
    To evaluate the gradient w/ the VJP, supply `g = np.ones_like(yt)`. To
    evaluate the reverse Jacobian do a vmap over the standard basis of yt.
  """
  rtol = kwargs.get('rtol', 1.4e-8)
  atol = kwargs.get('atol', 1.4e-8)

  flat_args, unravel_args = ravel_pytree(args)
  flat_func = lambda y, t, flat_args: ofunc(y, t, *unravel_args(flat_args))

  @jax.jit
  def aug_dynamics(augmented_state, t, flat_args):
    """Original system augmented with vjp_y, vjp_t and vjp_args."""
    # The augmented state is laid out as [y, vjp_y, vjp_t, vjp_args], where
    # vjp_t is a single entry.
    state_len = (augmented_state.shape[0] - flat_args.shape[0] - 1) // 2
    y = augmented_state[:state_len]
    adjoint = augmented_state[state_len:2*state_len]
    dy_dt, vjpfun = jax.vjp(flat_func, y, t, flat_args)
    return np.hstack([np.ravel(dy_dt), np.hstack(vjpfun(-adjoint))])

  # Time-reversed dynamics, so the augmented system can be integrated with
  # an increasing time array (the negated observation times).
  rev_aug_dynamics = lambda y, t, flat_args: -aug_dynamics(y, -t, flat_args)

  @jax.jit
  def _fori_body_fun(i, val):
    """fori_loop function for VJP calculation."""
    rev_yt, rev_t, rev_tarray, rev_gi, vjp_y, vjp_t0, vjp_args, time_vjp_list = val
    this_yt = rev_yt[i, :]
    this_t = rev_t[i]
    this_tarray = rev_tarray[i, :]
    this_gi = rev_gi[i, :]
    # this is g[i-1, :] when g has been reversed
    this_gim1 = rev_gi[i+1, :]
    state_len = this_yt.shape[0]
    vjp_cur_t = np.dot(flat_func(this_yt, this_t, flat_args), this_gi)
    vjp_t0 = vjp_t0 - vjp_cur_t
    # Run augmented system backwards to the previous observation.
    aug_y0 = np.hstack((this_yt, vjp_y, vjp_t0, vjp_args))
    aug_ans = odeint(rev_aug_dynamics,
                     aug_y0,
                     this_tarray,
                     flat_args,
                     rtol=rtol,
                     atol=atol)
    # Row 1 of the answer is the augmented state at the previous observation.
    vjp_y = aug_ans[1][state_len:2*state_len] + this_gim1
    vjp_t0 = aug_ans[1][2*state_len]
    vjp_args = aug_ans[1][2*state_len+1:]
    time_vjp_list = jax.ops.index_update(time_vjp_list, i, vjp_cur_t)
    return rev_yt, rev_t, rev_tarray, rev_gi, vjp_y, vjp_t0, vjp_args, time_vjp_list

  @jax.jit
  def vjp_all(g, yt, t):
    """Calculate the VJP g * Jac(odeint(ofunc, y0, t, *args))."""
    # Reverse everything along the time axis so the loop walks from the last
    # observation back to the first.
    rev_yt = yt[-1::-1, :]
    rev_t = t[-1::-1]
    # Row i holds the negated (start, end) times of the i-th backward solve.
    rev_tarray = -np.array([t[-1:0:-1], t[-2::-1]]).T
    rev_gi = g[-1::-1, :]

    vjp_y = g[-1, :]
    vjp_t0 = 0.
    vjp_args = np.zeros_like(flat_args)
    time_vjp_list = np.zeros_like(t)

    result = jax.lax.fori_loop(0,
                               rev_t.shape[0]-1,
                               _fori_body_fun,
                               (rev_yt,
                                rev_t,
                                rev_tarray,
                                rev_gi,
                                vjp_y,
                                vjp_t0,
                                vjp_args,
                                time_vjp_list))

    # The accumulated vjp_t0 becomes the entry for the initial time.
    time_vjp_list = jax.ops.index_update(result[-1], -1, result[-3])
    vjp_times = np.hstack(time_vjp_list)[::-1]

    return tuple([result[-4], vjp_times] + list(result[-2]))

  # Use the caller's tolerances for the forward solve as well, so it matches
  # the backward solves above.
  primals_out = odeint(flat_func, y0, t, flat_args, rtol=rtol, atol=atol)
  vjp_fun = lambda g: vjp_all(g, primals_out, t)

  return primals_out, vjp_fun
|
|
|
|
|
|
def build_odeint(ofunc, rtol=1.4e-8, atol=1.4e-8):
  """Return `f(y0, t, *args) = odeint(ofunc, y0, t, *args)` with a custom VJP.

  Given the function ofunc(y, t, *args), return the jitted function
  `f(y0, t, *args)` that integrates `ofunc` with `odeint` and whose VJP is
  defined using `vjp_odeint`, where:

    `y0` is the initial condition of the ODE integration,
    `t` is the time course of the integration, and
    `*args` are all other arguments to `ofunc`.

  Args:
    ofunc: The function to be wrapped into an ODE integration.
    rtol: relative local error tolerance for solver.
    atol: absolute local error tolerance for solver.

  Returns:
    The jitted function `f(y0, t, *args)`.
  """
  forward = lambda y0, t, *args: odeint(ofunc, y0, t, *args, rtol=rtol, atol=atol)
  forward_and_vjp = lambda y0, t, *args: vjp_odeint(
      ofunc, y0, t, *args, rtol=rtol, atol=atol)

  ct_odeint = jax.custom_transforms(forward)
  jax.defvjp_all(ct_odeint, forward_and_vjp)
  return jax.jit(ct_odeint)
|
|
|
|
|
|
def my_odeint_grad(fun):
  """Return a jitted function computing the gradient of a summed odeint.

  The returned function takes the same arguments as `vjp_odeint` after
  `fun` and evaluates the VJP of the solution against an all-ones cotangent.
  """
  @jax.jit
  def _gradfun(*args, **kwargs):
    ys, pullback = vjp_odeint(fun, *args, **kwargs)
    return pullback(np.ones_like(ys))
  return _gradfun
|
|
|
|
|
|
def my_odeint_jacrev(fun):
  """Return a jitted function computing the reverse-mode Jacobian of an odeint.

  The returned function takes the same arguments as `vjp_odeint` after
  `fun` and maps the VJP over the standard basis of the solution.
  """
  @jax.jit
  def _jacfun(*args, **kwargs):
    ys, pullback = vjp_odeint(fun, *args, **kwargs)
    # One pullback per standard-basis cotangent of the solution.
    basis = jax.api._std_basis(ys)
    stacked_vjps = jax.vmap(pullback)(basis)
    # Restore the solution's shape on the leading axis of every leaf.
    unravel_leaf = functools.partial(jax.api._unravel_array_into_pytree, ys, 0)
    per_arg = jax.api.tree_map(unravel_leaf, stacked_vjps)
    return jax.api.tree_transpose(
        jax.api.tree_structure(args), jax.api.tree_structure(ys), per_arg)
  return _jacfun
|
|
|
|
|
|
def nd(f, x, eps=0.0001):
  """Numerically differentiate `f` at `x` with central differences.

  Args:
    f: scalar-valued function of a pytree structured like `x`.
    x: pytree at which to differentiate.
    eps: perturbation applied to one flattened entry at a time.

  Returns:
    NumPy array with one derivative per entry of the flattened `x`.
  """
  flat_x, unravel = ravel_pytree(x)
  grad = onp.zeros_like(flat_x)
  for idx in range(len(flat_x)):
    # Perturb a single coordinate in both directions.
    step = onp.zeros_like(flat_x)
    step[idx] = eps
    upper = f(unravel(flat_x + step))
    lower = f(unravel(flat_x - step))
    grad[idx] = (upper - lower) / (2.0 * eps)
  return grad
|
|
|
|
|
|
def test_grad_vjp_odeint():
  """Compare numerical and exact differentiation of a simple odeint."""

  def f(y, t, arg1, arg2):
    return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)

  def onearg_odeint(args):
    solution = odeint(f, *args, atol=1e-8, rtol=1e-8)
    return np.sum(solution)

  # Ten-dimensional state integrated from t = 0.1 to t = 0.2, with
  # arg1 = 0.1 and arg2 = 0.2.
  y0 = np.linspace(0.1, 0.9, 10)
  times = np.array([0.1, 0.2])
  wrap_args = (y0, times, 0.1, 0.2)

  numerical_grad = nd(onearg_odeint, wrap_args)
  exact_grad, _ = ravel_pytree(my_odeint_grad(f)(*wrap_args))

  assert np.allclose(numerical_grad, exact_grad)
|
|
|
|
|
|
def plot_gradient_field(ax, func, xlimits, ylimits, numticks=30):
  """Plot the gradient field of `func` on `ax`.

  Args:
    ax: axes object providing `quiver`, `set_xlim` and `set_ylim`.
    func: function called as `func(y, x)` at every grid point.
    xlimits: pair of limits for the horizontal axis.
    ylimits: pair of limits for the vertical axis.
    numticks: number of grid points along each axis.
  """
  xs = np.linspace(*xlimits, num=numticks)
  ys = np.linspace(*ylimits, num=numticks)
  x_mesh, y_mesh = np.meshgrid(xs, ys)

  # Evaluate the slope at every grid point, then restore the grid shape.
  slopes = jax.vmap(func)(y_mesh.ravel(), x_mesh.ravel())
  slope_mesh = slopes.reshape(x_mesh.shape)

  # Each arrow has unit horizontal component and the slope as its vertical
  # component.
  ax.quiver(x_mesh, y_mesh, np.ones(slope_mesh.shape), slope_mesh)
  ax.set_xlim(xlimits)
  ax.set_ylim(ylimits)
|
|
|
|
|
|
def plot_demo():
  """Demo plot of simple ode integration and gradient field."""
  def f(y, t, arg1, arg2):
    return y - np.sin(t) - np.cos(t) * arg1 + arg2

  t0, t1 = 0., 5.0
  ts = np.linspace(t0, t1, 100)
  y0 = np.array([1.])
  fargs = (1.0, 0.0)

  ys = odeint(f, y0, ts, *fargs, atol=0.001, rtol=0.001)

  # Set up figure.
  fig = plt.figure(figsize=(8, 6), facecolor='white')
  ax = fig.add_subplot(111, frameon=False)

  # Draw the slope field first, then the integrated trajectory on top.
  f_no_args = lambda y, t: f(y, t, *fargs)
  plot_gradient_field(ax, f_no_args, xlimits=[t0, t1], ylimits=[-1.1, 1.1])
  ax.plot(ts, ys, 'g-')
  ax.set_xlabel('t')
  ax.set_ylabel('y')
  plt.show()
|
|
|
|
|
|
@jax.jit
def pend(y, t, arg1, arg2):
  """Simple pendulum system for odeint testing.

  Args:
    y: pair `(theta, omega)`.
    t: time (unused).
    arg1: coefficient multiplying `omega`.
    arg2: coefficient multiplying `sin(theta)`.

  Returns:
    Array `[omega, -arg1 * omega - arg2 * sin(theta)]`.
  """
  del t
  theta, omega = y
  omega_dot = -arg1*omega - arg2*np.sin(theta)
  return np.array([omega, omega_dot])
|
|
|
|
|
|
@jax.jit
def swoop(y, t, arg1, arg2):
  """Test system `dy/dt = y - sin(t) - cos(t) * arg1 + arg2`."""
  forcing = np.cos(t) * arg1
  dydt = y - np.sin(t) - forcing + arg2
  return np.array(dydt)
|
|
|
|
|
|
@jax.jit
def decay(y, t, arg1, arg2):
  """Test system `dy/dt = -sqrt(t) - y + arg1 - mean((y + arg2) ** 2)`."""
  mean_square = np.mean((y + arg2)**2)
  return -np.sqrt(t) - y + arg1 - mean_square
|
|
|
|
|
|
def benchmark_odeint(fun, y0, tspace, *args):
  """Time performance of JAX odeint method against scipy.integrate.odeint.

  Args:
    fun: derivative function called like `fun(y, t, *args)`.
    y0: initial state.
    tspace: times at which the solution is requested.
    *args: extra arguments forwarded to `fun`.

  Returns:
    Pair `(scipy_result, jax_result)` from the last trial of each solver.
  """
  n_trials = 5

  for trial in range(1, n_trials + 1):
    start = time.time()
    scipy_result = osp_integrate.odeint(fun, y0, tspace, args)
    end = time.time()
    print('scipy odeint elapsed time ({} of {}): {}'.format(
        trial, n_trials, end - start))

  for trial in range(1, n_trials + 1):
    start = time.time()
    jax_result = odeint(fun, np.array(y0), np.array(tspace), *args)
    # Wait for the result to be ready before stopping the clock.
    jax_result.block_until_ready()
    end = time.time()
    print('JAX odeint elapsed time ({} of {}): {}'.format(
        trial, n_trials, end - start))

  difference = np.asarray(scipy_result) - jax_result
  print('norm(scipy result-jax result): {}'.format(
      np.linalg.norm(difference)))

  return scipy_result, jax_result
|
|
|
|
|
|
def pend_benchmark_odeint():
  """Benchmark both odeint implementations on the `pend` system."""
  initial_state = (onp.pi - 0.1, 0.0)
  times = onp.linspace(0., 10., 101)
  # Results are discarded; only the printed timings matter here.
  benchmark_odeint(pend, initial_state, times, 0.25, 9.8)
|
|
|
|
|
|
def test_odeint_grad():
  """Test the gradient behavior of various ODE integrations."""
  def _test_odeint_grad(func, *args):
    def onearg_odeint(fargs):
      return np.sum(odeint(func, *fargs))

    numerical_grad = nd(onearg_odeint, args)
    exact_grad, _ = ravel_pytree(my_odeint_grad(func)(*args))
    assert np.allclose(numerical_grad, exact_grad)

  ts = np.array((0.1, 0.2))
  y0 = np.linspace(0.1, 0.9, 10)
  big_y0 = np.linspace(1.1, 10.9, 10)

  # Each entry pairs a system with the `(y0, t, arg1, arg2)` to check it at;
  # the order is pend, then swoop, then decay.
  cases = (
      (pend, (np.array((onp.pi - 0.1, 0.0)), ts, 0.25, 0.98)),
      (pend, (np.array((onp.pi * 0.1, 0.0)), ts, 0.1, 0.4)),
      (swoop, (y0, ts, 0.1, 0.2)),
      (swoop, (big_y0, ts, 0.1, 0.3)),
      (decay, (y0, ts, 0.1, 0.2)),
      (decay, (big_y0, ts, 0.1, 0.3)),
  )
  for func, cond in cases:
    _test_odeint_grad(func, *cond)
|
|
|
|
|
|
def test_odeint_vjp():
  """Use check_vjp to check odeint VJP calculations."""

  def _check(func, *wrap_args):
    # Compare the odeint primal against the hand-written VJP of `func`.
    odeint_wrap = lambda y, t, *args: odeint(func, y, t, *args)
    vjp_wrap = lambda y, t, *args: vjp_odeint(func, y, t, *args)
    check_vjp(odeint_wrap, vjp_wrap, wrap_args)

  # check pend()
  _check(pend, np.array([np.pi - 0.1, 0.0]), np.linspace(0., 10., 11),
         0.25, 9.8)

  # check swoop()
  _check(swoop, np.array([0.1]), np.linspace(0., 10., 11), 0.1, 0.2)

  # decay() check_vjp hangs!
|
|
|
|
|
|
def test_defvjp_all():
  """Use build_odeint to check odeint VJP calculations."""
  n_trials = 5
  jacswoop = jax.jit(jax.jacrev(build_odeint(swoop)))

  y = np.array([0.1])
  t = np.linspace(0., 2., 11)
  wrap_args = (y, t, 0.1, 0.2)

  for trial in range(1, n_trials + 1):
    start = time.time()
    rslt = jacswoop(*wrap_args)
    # Wait for the result to be ready before stopping the clock.
    rslt.block_until_ready()
    end = time.time()
    print('JAX jacrev elapsed time ({} of {}): {}'.format(
        trial, n_trials, end - start))
|
|
|
|
|
|
if __name__ == '__main__':
  # Run the gradient checks in order when executed as a script.
  for _check in (test_odeint_grad, test_odeint_vjp):
    _check()
|