handle linear custom_jvp functions

This commit is contained in:
Matthew Johnson 2021-09-03 16:43:57 -07:00
parent 60e7044a33
commit 36154fac34
2 changed files with 31 additions and 5 deletions

View File

@ -381,17 +381,25 @@ ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpo
def custom_jvp_jaxpr_custom_partial_eval_rule(
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
eqn: core.JaxprEqn
) -> Tuple[core.JaxprEqn, core.JaxprEqn, List[bool], List[bool], List[core.Var]]:
) -> Tuple[Optional[core.JaxprEqn], core.JaxprEqn, List[bool], List[bool], List[core.Var]]:
# It doesn't make sense to unzip (i.e. break up) a custom_jvp function into
# constituent parts, so we always perform full remat. An alternative would be
# to allow the policy function to decide whether the value of a
# custom_jvp-decorated function's application should be saved or not.
if any(unks_in): raise NotImplementedError # TODO(mattjj): linear fn
unks_out = [False] * len(eqn.outvars)
# TODO(mattjj,jekbradbury): the user writing the custom_jvp-decorated function
# probably has a better idea for what to do under remat (e.g. if the function
# contains dots or not), so we should allow for more expressive interaction
# (e.g. allow the policy to depend on which custom_jvp-decorated function is
# being applied, or annotating the behavior where custom_vjp is called.)
inst_out = [True] * len(eqn.outvars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
inst_out = [True] * len(eqn.outvars)
return eqn, eqn, unks_out, inst_out, new_inst
if any(unks_in):
unks_out = [True] * len(eqn.outvars)
return None, eqn, unks_out, inst_out, new_inst
else:
unks_out = [False] * len(eqn.outvars)
return eqn, eqn, unks_out, inst_out, new_inst
pe.partial_eval_jaxpr_custom_rules[custom_jvp_call_jaxpr_p] = \
custom_jvp_jaxpr_custom_partial_eval_rule # type: ignore

View File

@ -3660,6 +3660,24 @@ class RematTest(jtu.JaxTestCase):
api.grad(g)(3.)
def test_remat_custom_jvp_linear_policy(self):
@api.custom_jvp
def sum(x):
return jnp.sum(x, axis=0)
@sum.defjvp
def sum_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return sum(x), sum(xdot)
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
return sum(x)
jtu.check_grads(f, (jnp.ones(3),), order=2, modes=['fwd', 'rev'])
def g(x):
return lax.scan(lambda _, x: (None, f(x)), None, x)[1]
jtu.check_grads(g, (jnp.ones((2, 3)),), order=2, modes=['fwd', 'rev'])
class JaxprTest(jtu.JaxTestCase):