mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix diffrax with new remat, do for cond what #11773 did for while_loop
This commit is contained in:
parent
88636d2b64
commit
666f4f838f
@ -438,9 +438,20 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
|
||||
# TODO(mattjj): de-duplicate with _cond_partial_eval
|
||||
def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
index_uk, *ops_uk = unks_in
|
||||
assert not index_uk # only possible with old-style remat
|
||||
branches = eqn.params['branches']
|
||||
|
||||
# Instantiate all inputs (b/c jaxpr_staged will take all inputs).
|
||||
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
||||
if type(x) is core.Var and not inst]
|
||||
del inst_in
|
||||
|
||||
# NOTE(mattjj): I think it should be impossible for the index to be unknown,
|
||||
# but asserting that caused a test failure in diffrax. So we handle it: if it
|
||||
# is unknown, stage out the whole cond.
|
||||
if index_uk:
|
||||
all_true = [True] * len(branches[0].out_avals)
|
||||
return None, eqn, all_true, all_true, new_inst
|
||||
|
||||
# First, compute output unknowns (unks_out), where an output of the cond is
|
||||
# unknown if it would be unknown on any of the branches.
|
||||
unks_out: List[bool] = [False] * len(eqn.outvars)
|
||||
@ -477,12 +488,6 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
|
||||
for j in branches_known[1:])
|
||||
|
||||
# Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
|
||||
# passing in_inst argument to partial_eval_jaxpr_custom above).
|
||||
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
||||
if type(x) is core.Var and not inst]
|
||||
del inst_in
|
||||
|
||||
# Create residual variables.
|
||||
newvar = core.gensym()
|
||||
res_binders = map(newvar, all_res_avals)
|
||||
|
Loading…
x
Reference in New Issue
Block a user