remat: fix regression of broken calling convention

In #7631 we made `_partial_eval_jaxpr_custom` follow a convention: drop
unit outputs from the known jaxprs it returned:

3377abaef4/jax/interpreters/partial_eval.py (L937)

The caller needed to compensate for that, e.g. by dropping the
corresponding binders in an outer jaxpr eqn representing an application
of that inner jaxpr. But that logic was written in terms of checking for
dropped outputs in the outer jaxpr (since units are typically not
consumed downstream):

https://github.com/google/jax/pull/8227/files#diff-440d9df723b313bb263bc7704103cad1dcc886ff6553aa78c30188b0b323b686L981

That worked (or at least we never noticed a failure, though now it seems
sketchy...) with the 'classic' `jax.checkpoint` / `jax.remat`
implementation, before #8191, because of how that implementation relied
on tracing-based partial evaluation, which would detect and mark dropped
outputs in the outer jaxpr as part of jaxpr formation.

But then in #8191 we no longer marked dropvars in the same way. That led
to assertion failures, and #8227 attempted to fix those. That fix made
sense with the new remat implementation, but not the old one! (In the
intervening period I forgot about this unit-dropping convention...)

The fix here is not to rely on dropvars but to more directly encode the
convention that `_partial_eval_jaxpr_custom` drops unit outputs in the
known jaxpr it produces.
This commit is contained in:
Matthew Johnson 2021-10-15 16:51:37 -07:00
parent 3377abaef4
commit 89606c2c35
2 changed files with 32 additions and 5 deletions

View File

@ -18,6 +18,7 @@ import functools
from functools import partial
import inspect
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, cast)
from weakref import ref
@ -974,10 +975,14 @@ def call_partial_eval_custom_rule(
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
eqn: JaxprEqn
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = _partial_eval_jaxpr_custom(
eqn.params[jaxpr_param_name], unks_in, saveable)
jaxpr = eqn.params[jaxpr_param_name]
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
_partial_eval_jaxpr_custom(jaxpr, unks_in, saveable)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
# by convention, _partial_eval_jaxpr_custom drops units on known outputs
known_units_out = [v.aval is core.abstract_unit for v in jaxpr.outvars]
dropped_outs_known = map(op.or_, unks_out, known_units_out)
out_binders_known, _ = partition_list(dropped_outs_known, eqn.outvars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym([jaxpr_known, jaxpr_staged])
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]

View File

@ -3989,8 +3989,15 @@ class RematTest(jtu.JaxTestCase):
self.assertLen(res, 1)
self.assertEqual(res[0][0].shape, ())
def test_checkpoint_dropvars(self):
@new_checkpoint
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_checkpoint_dropvars(self, remat):
@remat
def f(x):
_, x = api.jit(lambda: (x, x))()
return x
@ -4005,6 +4012,21 @@ class RematTest(jtu.JaxTestCase):
_ = jax.grad(f)(3.) # doesn't crash
# Regression test for the unit-dropping convention described in the commit
# message: _partial_eval_jaxpr_custom drops unit outputs from the known
# jaxpr, and the caller must drop the matching outer binders. Run against
# both the classic and the new remat implementations.
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_unit_dropvar_consistency_regression(self, remat):
# Force everything to be rematerialized (policy saves nothing) so partial
# evaluation exercises the known/staged jaxpr split.
@partial(remat, policy=lambda *_, **__: False)
def f(u, x):
# The inner jit returns a unit as its second output, which the known
# jaxpr drops by convention; `_` discards it in the outer trace.
x, _ = jax.jit(lambda x: (x, u))(x)
return x
# Linearize with a unit argument; the test passes if this doesn't crash
# (no assertion on the value — presumably the crash was the original bug).
_ = api.linearize(partial(f, core.unit), 3.)
class JaxprTest(jtu.JaxTestCase):
def test_scalar_literals(self):