mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
remat: fix regression that broke the calling convention
In #7631 we made `_partial_eval_jaxpr_custom` follow a convention of
dropping unit outputs from the known jaxpr it returns:
3377abaef4/jax/interpreters/partial_eval.py (L937)
The caller needed to compensate for that, e.g. by dropping the
corresponding binders in an outer jaxpr eqn representing an application
of that inner jaxpr. But that logic was written in terms of checking for
dropped outputs in the outer jaxpr (since units are typically not
consumed downstream):
https://github.com/google/jax/pull/8227/files#diff-440d9df723b313bb263bc7704103cad1dcc886ff6553aa78c30188b0b323b686L981
That worked (or at least we never noticed a failure, though now it seems
sketchy...) with the 'classic' `jax.checkpoint` / `jax.remat`
implementation, before #8191, because of how that implementation relied
on tracing-based partial evaluation, which would detect and mark dropped
outputs in the outer jaxpr as part of jaxpr formation.
But then in #8191 we no longer marked dropvars in the same way. That led
to assertion failures, and #8227 attempted to fix those. That fix made
sense with the new remat implementation, but not the old one! (In the
intervening period I forgot about this unit-dropping convention...)
The fix here is to stop relying on dropvars and instead directly encode
the convention that `_partial_eval_jaxpr_custom` drops unit outputs in
the known jaxpr it produces.
This commit is contained in:
parent
3377abaef4
commit
89606c2c35
@ -18,6 +18,7 @@ import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import itertools as it
|
||||
import operator as op
|
||||
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
|
||||
List, Union, cast)
|
||||
from weakref import ref
|
||||
@ -974,10 +975,14 @@ def call_partial_eval_custom_rule(
|
||||
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
|
||||
eqn: JaxprEqn
|
||||
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
|
||||
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = _partial_eval_jaxpr_custom(
|
||||
eqn.params[jaxpr_param_name], unks_in, saveable)
|
||||
jaxpr = eqn.params[jaxpr_param_name]
|
||||
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
|
||||
_partial_eval_jaxpr_custom(jaxpr, unks_in, saveable)
|
||||
ins_known, _ = partition_list(unks_in, eqn.invars)
|
||||
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
||||
# by convention, _partial_eval_jaxpr_custom drops units on known outputs
|
||||
known_units_out = [v.aval is core.abstract_unit for v in jaxpr.outvars]
|
||||
dropped_outs_known = map(op.or_, unks_out, known_units_out)
|
||||
out_binders_known, _ = partition_list(dropped_outs_known, eqn.outvars)
|
||||
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
||||
newvar = core.gensym([jaxpr_known, jaxpr_staged])
|
||||
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
|
||||
|
@ -3989,8 +3989,15 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertLen(res, 1)
|
||||
self.assertEqual(res[0][0].shape, ())
|
||||
|
||||
def test_checkpoint_dropvars(self):
|
||||
@new_checkpoint
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_checkpoint_dropvars(self, remat):
|
||||
@remat
|
||||
def f(x):
|
||||
_, x = api.jit(lambda: (x, x))()
|
||||
return x
|
||||
@ -4005,6 +4012,21 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
_ = jax.grad(f)(3.) # doesn't crash
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
for suffix, remat in [
|
||||
('', api.remat),
|
||||
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
def test_unit_dropvar_consistency_regression(self, remat):
|
||||
@partial(remat, policy=lambda *_, **__: False)
|
||||
def f(u, x):
|
||||
x, _ = jax.jit(lambda x: (x, u))(x)
|
||||
return x
|
||||
|
||||
_ = api.linearize(partial(f, core.unit), 3.)
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scalar_literals(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user