mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
keep dropvar binders in call_partial_eval_custom_rule
The dropvars indicate that these binders/outputs aren't used in the outer jaxpr, so they could be dropped; however, dropping the binders would also require editing the called jaxpr to keep it consistent. For completeness, that editing could involve DCE, which in turn can affect the jaxpr's inputs. Instead of doing that bookkeeping, we can just keep the dropvars. A DCE pass that follows in the remat primitive's partial eval rule will clean these up. (This commit also contains unrelated tweaks to comments and strings.)
This commit is contained in:
parent
1bafdb6d7e
commit
d1f0c60b7b
@ -254,7 +254,7 @@ def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
|
||||
for v in eqn.outvars:
|
||||
if v in res_vars:
|
||||
if eqn.primitive is name_p:
|
||||
results.append((v.aval, f'named {eqn.params["name"]} from {src}'))
|
||||
results.append((v.aval, f"named '{eqn.params['name']}' from {src}"))
|
||||
else:
|
||||
results.append((v.aval, f'from {src}'))
|
||||
|
||||
@ -316,7 +316,7 @@ ad.primitive_jvps[remat_p] = remat_jvp
|
||||
def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
assert not jaxpr.constvars
|
||||
policy = params['policy'] or (lambda *_, **__: False)
|
||||
# unzip into known and jaxpr_unknown
|
||||
# unzip into jaxpr_known and jaxpr_unknown
|
||||
in_unknowns = [not t.is_known() for t in tracers]
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
|
||||
pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
|
||||
|
@ -788,7 +788,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
in_unknowns = ([False] * len(consts) +
|
||||
[not t.is_known() for t in it.chain(env_tracers, tracers)])
|
||||
if params['policy']:
|
||||
# unzip into known and jaxpr_unknown
|
||||
# unzip into jaxpr_known and jaxpr_unknown
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = _partial_eval_jaxpr_custom(
|
||||
jaxpr, in_unknowns, params['policy'])
|
||||
jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
|
||||
@ -978,7 +978,6 @@ def call_partial_eval_custom_rule(
|
||||
eqn.params[jaxpr_param_name], unks_in, saveable)
|
||||
ins_known, _ = partition_list(unks_in, eqn.invars)
|
||||
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
||||
out_binders_known = [v for v in out_binders_known if v is not dropvar]
|
||||
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
||||
newvar = core.gensym([jaxpr_known, jaxpr_staged])
|
||||
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
|
||||
|
@ -3981,7 +3981,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertEqual(res[3][0].shape, ())
|
||||
self.assertEqual(res[3][1], "from the argument 'y'")
|
||||
self.assertEqual(res[4][0].shape, ())
|
||||
self.assertStartsWith(res[4][1], "named z")
|
||||
self.assertStartsWith(res[4][1], "named 'z'")
|
||||
self.assertEqual(res[5][0].shape, ())
|
||||
|
||||
def test_saved_residuals_utility_literals(self):
|
||||
@ -3989,6 +3989,14 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertLen(res, 1)
|
||||
self.assertEqual(res[0][0].shape, ())
|
||||
|
||||
def test_checkpoint_dropvars(self):
  """Check that grad of a checkpointed function with an ignored jit output runs.

  The inner jitted lambda returns two values and only the second is used.
  The test passes as long as differentiation completes without raising.
  """
  def f(x):
    # The first result of the jitted call is deliberately discarded.
    _, kept = api.jit(lambda: (x, x))()
    return kept

  f = new_checkpoint(f)
  api.grad(f)(3.)  # doesn't crash
|
||||
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user