fix diffrax with new remat, do for cond what #11773 did for while_loop

This commit is contained in:
Matthew Johnson 2022-08-09 12:22:10 -07:00
parent 88636d2b64
commit 666f4f838f

View File

@ -438,9 +438,20 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
# TODO(mattjj): de-duplicate with _cond_partial_eval
def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
index_uk, *ops_uk = unks_in
assert not index_uk # only possible with old-style remat
branches = eqn.params['branches']
# Instantiate all inputs (b/c jaxpr_staged will take all inputs).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
del inst_in
# NOTE(mattjj): I think it should be impossible for the index to be unknown,
# but asserting that caused a test failure in diffrax. So we handle it: if it
# is unknown, stage out the whole cond.
if index_uk:
all_true = [True] * len(branches[0].out_avals)
return None, eqn, all_true, all_true, new_inst
# First, compute output unknowns (unks_out), where an output of the cond is
# unknown if it would be unknown on any of the branches.
unks_out: List[bool] = [False] * len(eqn.outvars)
@ -477,12 +488,6 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
for j in branches_known[1:])
# Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
# passing in_inst argument to partial_eval_jaxpr_custom above).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
del inst_in
# Create residual variables.
newvar = core.gensym()
res_binders = map(newvar, all_res_avals)