mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
handle linear custom_jvp functions
This commit is contained in:
parent
60e7044a33
commit
36154fac34
@ -381,17 +381,25 @@ ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpo
|
||||
def custom_jvp_jaxpr_custom_partial_eval_rule(
|
||||
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
|
||||
eqn: core.JaxprEqn
|
||||
) -> Tuple[core.JaxprEqn, core.JaxprEqn, List[bool], List[bool], List[core.Var]]:
|
||||
) -> Tuple[Optional[core.JaxprEqn], core.JaxprEqn, List[bool], List[bool], List[core.Var]]:
|
||||
# It doesn't make sense to unzip (i.e. break up) a custom_jvp function into
|
||||
# constituent parts, so we always perform full remat. An alternative would be
|
||||
# to allow the policy function to decide whether the value of a
|
||||
# custom_jvp-decorated function's application should be saved or not.
|
||||
if any(unks_in): raise NotImplementedError # TODO(mattjj): linear fn
|
||||
unks_out = [False] * len(eqn.outvars)
|
||||
# TODO(mattjj,jekbradbury): the user writing the custom_jvp-decorated function
|
||||
# probably has a better idea for what to do under remat (e.g. if the function
|
||||
# contains dots or not), so we should allow for more expressive interaction
|
||||
# (e.g. allow the policy to depend on which custom_jvp-decorated function is
|
||||
# being applied, or annotating the behavior where custom_vjp is called.)
|
||||
inst_out = [True] * len(eqn.outvars)
|
||||
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
||||
if type(x) is core.Var and not inst]
|
||||
inst_out = [True] * len(eqn.outvars)
|
||||
return eqn, eqn, unks_out, inst_out, new_inst
|
||||
if any(unks_in):
|
||||
unks_out = [True] * len(eqn.outvars)
|
||||
return None, eqn, unks_out, inst_out, new_inst
|
||||
else:
|
||||
unks_out = [False] * len(eqn.outvars)
|
||||
return eqn, eqn, unks_out, inst_out, new_inst
|
||||
pe.partial_eval_jaxpr_custom_rules[custom_jvp_call_jaxpr_p] = \
|
||||
custom_jvp_jaxpr_custom_partial_eval_rule # type: ignore
|
||||
|
||||
|
@ -3660,6 +3660,24 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
api.grad(g)(3.)
|
||||
|
||||
def test_remat_custom_jvp_linear_policy(self):
|
||||
@api.custom_jvp
|
||||
def sum(x):
|
||||
return jnp.sum(x, axis=0)
|
||||
@sum.defjvp
|
||||
def sum_jvp(primals, tangents):
|
||||
(x,), (xdot,) = primals, tangents
|
||||
return sum(x), sum(xdot)
|
||||
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
return sum(x)
|
||||
jtu.check_grads(f, (jnp.ones(3),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
def g(x):
|
||||
return lax.scan(lambda _, x: (None, f(x)), None, x)[1]
|
||||
jtu.check_grads(g, (jnp.ones((2, 3)),), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user