allow closures for odeint dynamics functions (#3562)

* allow closures for odeint dynamics functions

fixes #2718, #3557

* add tests for odeint dynamics closing over tracers
This commit is contained in:
Matthew Johnson 2020-06-25 17:36:17 -07:00 committed by GitHub
parent 2a6fc316c3
commit db80ca5dd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 4 deletions

View File

@ -31,15 +31,39 @@ import jax.numpy as jnp
from jax import core
from jax import lax
from jax import ops
from jax.util import safe_map, safe_zip
from jax.util import safe_map, safe_zip, cache, split_list
from jax.api_util import flatten_fun_nokwargs
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
from jax.tree_util import tree_map, tree_flatten, tree_unflatten
from jax.interpreters import partial_eval as pe
from jax import linear_util as lu
map = safe_map
zip = safe_zip
@cache()
def closure_convert(fun, in_tree, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
with core.initial_style_staging():
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
out_tree = out_tree()
num_consts = len(consts)
def converted_fun(y, t, *consts_args):
consts, args = split_list(consts_args, [num_consts])
all_args, in_tree2 = tree_flatten((y, t, *args))
assert in_tree == in_tree2
out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
return tree_unflatten(out_tree, out_flat)
return converted_fun, consts
def abstractify(x):
return core.raise_to_shaped(core.get_aval(x))
def ravel_first_arg(f, unravel):
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
@ -159,8 +183,12 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
msg = ("The contents of odeint *args must be arrays or scalars, but got "
"\n{}.")
raise TypeError(msg.format(arg))
tree_map(_check_arg, args)
return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
flat_args, in_tree = tree_flatten((y0, t[0], *args))
in_avals = tuple(map(abstractify, flat_args))
converted, consts = closure_convert(func, in_tree, in_avals)
return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *consts, *args)
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):

View File

@ -181,6 +181,37 @@ class ODETest(jtu.JaxTestCase):
f = lambda x0: odeint(lambda x, _t: x, x0, t)
jax.vmap(f)(x0_eval) # doesn't crash
@jtu.skip_on_devices("tpu")
def test_grad_closure(self):
# simplification of https://github.com/google/jax/issues/2718
def experiment(x):
def model(y, t):
return -x * y
history = odeint(model, 1., np.arange(0, 10, 0.1))
return history[-1]
jtu.check_grads(experiment, (0.01,), modes=["rev"], order=1)
@jtu.skip_on_devices("tpu")
def test_grad_closure_with_vmap(self):
# https://github.com/google/jax/issues/2718
@jax.jit
def experiment(x):
def model(y, t):
return -x * y
history = odeint(model, 1., np.arange(0, 10, 0.1))
return history[-1]
gradfun = jax.value_and_grad(experiment)
t = np.arange(0., 1., 0.01)
h, g = jax.vmap(gradfun)(t) # doesn't crash
ans = h[11], g[11]
expected_h = experiment(t[11])
expected_g = (experiment(t[11] + 1e-5) - expected_h) / 1e-5
expected = expected_h, expected_g
self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)
if __name__ == '__main__':
absltest.main()