keep dropvar binders in call_partial_eval_custom_rule

The dropvars indicate that these binders/outputs aren't used in the
outer jaxpr and so they could be dropped, but to drop the binders would
require also editing the called jaxpr to be consistent. For completeness
that editing could involve DCE, which in turn can affect the jaxpr's
inputs.

Instead of doing that bookkeeping, we can just keep the dropvars.
There's a DCE pass to follow in the remat primitive's partial eval rule
which will clean these up.

(This commit also contains unrelated tweaks to comments and strings.)
This commit is contained in:
Matthew Johnson 2021-10-14 18:49:56 -07:00
parent 1bafdb6d7e
commit d1f0c60b7b
3 changed files with 12 additions and 5 deletions

View File

@ -254,7 +254,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
for v in eqn.outvars:
if v in res_vars:
if eqn.primitive is name_p:
results.append((v.aval, f'named {eqn.params["name"]} from {src}'))
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
else:
results.append((v.aval, f'from {src}'))
@ -316,7 +316,7 @@ ad.primitive_jvps[remat_p] = remat_jvp
def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars
policy = params['policy'] or (lambda *_, **__: False)
# unzip into known and jaxpr_unknown
# unzip into jaxpr_known and jaxpr_unknown
in_unknowns = [not t.is_known() for t in tracers]
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)

View File

@ -788,7 +788,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
if params['policy']:
# unzip into known and jaxpr_unknown
# unzip into jaxpr_known and jaxpr_unknown
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = _partial_eval_jaxpr_custom(
jaxpr, in_unknowns, params['policy'])
jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
@ -978,7 +978,6 @@ def call_partial_eval_custom_rule(
eqn.params[jaxpr_param_name], unks_in, saveable)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
out_binders_known = [v for v in out_binders_known if v is not dropvar]
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym([jaxpr_known, jaxpr_staged])
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]

View File

@ -3981,7 +3981,7 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(res[3][0].shape, ())
self.assertEqual(res[3][1], "from the argument 'y'")
self.assertEqual(res[4][0].shape, ())
self.assertStartsWith(res[4][1], "named z")
self.assertStartsWith(res[4][1], "named 'z'")
self.assertEqual(res[5][0].shape, ())
def test_saved_residuals_utility_literals(self):
@ -3989,6 +3989,14 @@ class RematTest(jtu.JaxTestCase):
self.assertLen(res, 1)
self.assertEqual(res[0][0].shape, ())
def test_checkpoint_dropvars(self):
@new_checkpoint
def f(x):
_, x = api.jit(lambda: (x, x))()
return x
_ = api.grad(f)(3.) # doesn't crash
class JaxprTest(jtu.JaxTestCase):