mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
allow closures for odeint dynamics functions (#3562)
* allow closures for odeint dynamics functions fixes #2718, #3557 * add tests for odeint dynamics closing over tracers
This commit is contained in:
parent
2a6fc316c3
commit
db80ca5dd8
@ -31,15 +31,39 @@ import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax.util import safe_map, safe_zip
|
||||
from jax.util import safe_map, safe_zip, cache, split_list
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.flatten_util import ravel_pytree
|
||||
from jax.tree_util import tree_map
|
||||
from jax.tree_util import tree_map, tree_flatten, tree_unflatten
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax import linear_util as lu
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
|
||||
@cache()
|
||||
def closure_convert(fun, in_tree, in_avals):
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
with core.initial_style_staging():
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
|
||||
out_tree = out_tree()
|
||||
num_consts = len(consts)
|
||||
|
||||
def converted_fun(y, t, *consts_args):
|
||||
consts, args = split_list(consts_args, [num_consts])
|
||||
all_args, in_tree2 = tree_flatten((y, t, *args))
|
||||
assert in_tree == in_tree2
|
||||
out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
return converted_fun, consts
|
||||
|
||||
def abstractify(x):
|
||||
return core.raise_to_shaped(core.get_aval(x))
|
||||
|
||||
def ravel_first_arg(f, unravel):
|
||||
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
|
||||
|
||||
@ -159,8 +183,12 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf):
|
||||
msg = ("The contents of odeint *args must be arrays or scalars, but got "
|
||||
"\n{}.")
|
||||
raise TypeError(msg.format(arg))
|
||||
tree_map(_check_arg, args)
|
||||
return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
|
||||
|
||||
flat_args, in_tree = tree_flatten((y0, t[0], *args))
|
||||
in_avals = tuple(map(abstractify, flat_args))
|
||||
converted, consts = closure_convert(func, in_tree, in_avals)
|
||||
|
||||
return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *consts, *args)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
|
||||
def _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args):
|
||||
|
@ -181,6 +181,37 @@ class ODETest(jtu.JaxTestCase):
|
||||
f = lambda x0: odeint(lambda x, _t: x, x0, t)
|
||||
jax.vmap(f)(x0_eval) # doesn't crash
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_grad_closure(self):
|
||||
# simplification of https://github.com/google/jax/issues/2718
|
||||
def experiment(x):
|
||||
def model(y, t):
|
||||
return -x * y
|
||||
history = odeint(model, 1., np.arange(0, 10, 0.1))
|
||||
return history[-1]
|
||||
jtu.check_grads(experiment, (0.01,), modes=["rev"], order=1)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_grad_closure_with_vmap(self):
|
||||
# https://github.com/google/jax/issues/2718
|
||||
@jax.jit
|
||||
def experiment(x):
|
||||
def model(y, t):
|
||||
return -x * y
|
||||
history = odeint(model, 1., np.arange(0, 10, 0.1))
|
||||
return history[-1]
|
||||
|
||||
gradfun = jax.value_and_grad(experiment)
|
||||
t = np.arange(0., 1., 0.01)
|
||||
h, g = jax.vmap(gradfun)(t) # doesn't crash
|
||||
ans = h[11], g[11]
|
||||
|
||||
expected_h = experiment(t[11])
|
||||
expected_g = (experiment(t[11] + 1e-5) - expected_h) / 1e-5
|
||||
expected = expected_h, expected_g
|
||||
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user