mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Merge pull request #11077 from DanPuzzuoli:ode_dt_max
PiperOrigin-RevId: 484639938
This commit is contained in:
commit
9dac458e85
@ -146,7 +146,7 @@ def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
|
||||
jnp.maximum(mean_error_ratio**(-1.0 / order) * safety, dfactor))
|
||||
return jnp.where(mean_error_ratio == 0, last_step * ifactor, last_step * factor)
|
||||
|
||||
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
|
||||
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf, hmax=jnp.inf):
|
||||
"""Adaptive stepsize (Dormand-Prince) Runge-Kutta odeint implementation.
|
||||
|
||||
Args:
|
||||
@ -161,6 +161,7 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
|
||||
rtol: float, relative local error tolerance for solver (optional).
|
||||
atol: float, absolute local error tolerance for solver (optional).
|
||||
mxstep: int, maximum number of steps to take for each timepoint (optional).
|
||||
hmax: float, maximum step size allowed (optional).
|
||||
|
||||
Returns:
|
||||
Values of the solution `y` (i.e. integrated system values) at each time
|
||||
@ -175,17 +176,17 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
|
||||
raise TypeError(f"t must be an array of floats, but got {t}.")
|
||||
|
||||
converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
|
||||
return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
|
||||
return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
|
||||
def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
|
||||
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
|
||||
def _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args):
|
||||
y0, unravel = ravel_pytree(y0)
|
||||
func = ravel_first_arg(func, unravel)
|
||||
out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
|
||||
out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
|
||||
return jax.vmap(unravel)(out)
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3))
|
||||
def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3, 4))
|
||||
def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
|
||||
func_ = lambda y, t: func(y, t, *args)
|
||||
|
||||
def scan_fun(carry, target_t):
|
||||
@ -200,7 +201,7 @@ def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
|
||||
next_t = t + dt
|
||||
error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y)
|
||||
new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
|
||||
dt = optimal_step_size(dt, error_ratio)
|
||||
dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax)
|
||||
|
||||
new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
|
||||
old = [i + 1, y, f, t, dt, last_t, interp_coeff]
|
||||
@ -213,17 +214,17 @@ def _odeint(func, rtol, atol, mxstep, y0, ts, *args):
|
||||
return carry, y_target
|
||||
|
||||
f0 = func_(y0, ts[0])
|
||||
dt = initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0)
|
||||
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax)
|
||||
interp_coeff = jnp.array([y0] * 5)
|
||||
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
|
||||
_, ys = lax.scan(scan_fun, init_carry, ts[1:])
|
||||
return jnp.concatenate((y0[None], ys))
|
||||
|
||||
def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
|
||||
ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
|
||||
def _odeint_fwd(func, rtol, atol, mxstep, hmax, y0, ts, *args):
|
||||
ys = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
|
||||
return ys, (ys, ts, args)
|
||||
|
||||
def _odeint_rev(func, rtol, atol, mxstep, res, g):
|
||||
def _odeint_rev(func, rtol, atol, mxstep, hmax, res, g):
|
||||
ys, ts, args = res
|
||||
|
||||
def aug_dynamics(augmented_state, t, *args):
|
||||
@ -248,7 +249,7 @@ def _odeint_rev(func, rtol, atol, mxstep, res, g):
|
||||
_, y_bar, t0_bar, args_bar = odeint(
|
||||
aug_dynamics, (ys[i], y_bar, t0_bar, args_bar),
|
||||
jnp.array([-ts[i], -ts[i - 1]]),
|
||||
*args, rtol=rtol, atol=atol, mxstep=mxstep)
|
||||
*args, rtol=rtol, atol=atol, mxstep=mxstep, hmax=hmax)
|
||||
y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
|
||||
# Add gradient from current output
|
||||
y_bar = y_bar + g[i - 1]
|
||||
|
@ -253,6 +253,21 @@ class ODETest(jtu.JaxTestCase):
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2, atol=tol, rtol=tol)
|
||||
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_hmax(self):
|
||||
"""Test max step size control."""
|
||||
|
||||
def rhs(y, t):
|
||||
return jnp.piecewise(
|
||||
t,
|
||||
[t <= 2., (t >= 5.) & (t <= 7.)],
|
||||
[lambda s: jnp.array(1.), lambda s: jnp.array(-1.), lambda s: jnp.array(0.)]
|
||||
)
|
||||
ys = odeint(func=rhs, y0=jnp.array(0.), t=jnp.array([0., 5., 10.]), hmax=1.)
|
||||
|
||||
self.assertTrue(jnp.abs(ys[1] - 2.) < 1e-4)
|
||||
self.assertTrue(jnp.abs(ys[2]) < 1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user