fix remat with nontrivial env (#2136)

fixes #2030
This commit is contained in:
Matthew Johnson 2020-01-31 23:47:30 -08:00 committed by GitHub
parent efbdaf66bf
commit ae1d6b875f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 2 deletions

View File

@ -2027,7 +2027,8 @@ def checkpoint(fun, concrete=False):
def fun_remat(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete)
out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
concrete=concrete)
return tree_unflatten(out_tree(), out_flat)
return fun_remat
remat = checkpoint

View File

@ -526,7 +526,7 @@ def _remat_partial_eval(trace, f, tracers, params):
# Since we traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
in_avals = ([raise_to_shaped(partial_val_aval(*t.pval)) for t in env]
+ [raise_to_shaped(pv) for pv in in_pvs])
out_avals = [raise_to_shaped(pv if pv is not None
else abstract_unit if var is unitvar

View File

@ -1600,6 +1600,46 @@ class APITest(jtu.JaxTestCase):
api.grad(func)(5.0) # doesn't crash
def test_remat_jit2(self):
@api.jit
def f(x):
y = 2 * x
@api.remat
def g():
return y
return g()
self.assertAllClose(f(3), 6, check_dtypes=False)
def test_remat_nontrivial_env(self):
# simplified from https://github.com/google/jax/issues/2030
@api.remat
def foo(state, dt=0.5, c=1):
u, u_t = state
u_tt = c**2 * u
u_t = u_t + u_tt * dt
return (u, u_t)
@partial(api.jit, static_argnums=(1,))
def _multi_step(state, count, dt, c):
f = lambda s, _: (foo(s, dt, c), _)
return lax.scan(f, state, None, count)
def multi_step(state, count, dt=1/np.sqrt(2), c=1):
return _multi_step(state, count, dt, c)
def loss(u0, target, steps, dt=1/np.sqrt(2), c=1):
init = (u0, np.zeros_like(u0))
(uf, _), _ = multi_step(init, steps, dt, c)
return ((uf - target) ** 2).mean()
target = np.zeros((128, 128))
u0 = np.ones_like(target)
loss(u0, target, 10) # doesn't crash
def test_trivial_computations(self):
x = np.array([1, 2, 3])
y = api.jit(lambda x: x)(x)