Merge pull request #11077 from DanPuzzuoli:ode_dt_max

PiperOrigin-RevId: 484639938
This commit is contained in:
jax authors 2022-10-28 16:05:25 -07:00
commit 9dac458e85
2 changed files with 29 additions and 13 deletions

View File

@ -146,7 +146,7 @@ def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
jnp.maximum(mean_error_ratio**(-1.0 / order) * safety, dfactor))
return jnp.where(mean_error_ratio == 0, last_step * ifactor, last_step * factor)
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf, hmax=jnp.inf):
"""Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.
Args:
@ -161,6 +161,7 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
rtol: float, relative local error tolerance for solver (optional).
atol: float, absolute local error tolerance for solver (optional).
mxstep: int, maximum number of steps to take for each timepoint (optional).
hmax: float, maximum step size allowed (optional).
Returns:
Values of the solution `y` (i.e. integrated system values) at each time
@ -175,17 +176,17 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
raise TypeError(f"t must be an array of floats, but got {t}.")
converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
def _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args):
y0, unravel = ravel_pytree(y0)
func = ravel_first_arg(func, unravel)
out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
return jax.vmap(unravel)(out)
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3, 4))
def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
func_ = lambda y, t: func(y, t, *args)
def scan_fun(carry, target_t):
@ -200,7 +201,7 @@ def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
next_t = t + dt
error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y)
new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
dt = optimal_step_size(dt, error_ratio)
dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax)
new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
old = [i + 1, y, f, t, dt, last_t, interp_coeff]
@ -213,17 +214,17 @@ def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
return carry, y_target
f0 = func_(y0, ts[0])
dt = initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0)
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax)
interp_coeff = jnp.array([y0] * 5)
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
_, ys = lax.scan(scan_fun, init_carry, ts[1:])
return jnp.concatenate((y0[None], ys))
def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
def _odeint_fwd(func, rtol, atol, mxstep, hmax, y0, ts, *args):
ys = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
return ys, (ys, ts, args)
def _odeint_rev(func, rtol, atol, mxstep, res, g):
def _odeint_rev(func, rtol, atol, mxstep, hmax, res, g):
ys, ts, args = res
def aug_dynamics(augmented_state, t, *args):
@ -248,7 +249,7 @@ def _odeint_rev(func, rtol, atol, mxstep, res, g):
_, y_bar, t0_bar, args_bar = odeint(
aug_dynamics, (ys[i], y_bar, t0_bar, args_bar),
jnp.array([-ts[i], -ts[i - 1]]),
*args, rtol=rtol, atol=atol, mxstep=mxstep)
*args, rtol=rtol, atol=atol, mxstep=mxstep, hmax=hmax)
y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
# Add gradient from current output
y_bar = y_bar + g[i - 1]

View File

@ -253,6 +253,21 @@ class ODETest(jtu.JaxTestCase):
with jax.numpy_dtype_promotion('standard'):
jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2, atol=tol, rtol=tol)
@jtu.skip_on_devices("tpu", "gpu")
def test_hmax(self):
"""Test max step size control."""
def rhs(y, t):
return jnp.piecewise(
t,
[t <= 2., (t >= 5.) & (t <= 7.)],
[lambda s: jnp.array(1.), lambda s: jnp.array(-1.), lambda s: jnp.array(0.)]
)
ys = odeint(func=rhs, y0=jnp.array(0.), t=jnp.array([0., 5., 10.]), hmax=1.)
self.assertTrue(jnp.abs(ys[1] - 2.) < 1e-4)
self.assertTrue(jnp.abs(ys[2]) < 1e-4)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())