Merge pull request #9228 from mattjj:cache-partial-eval-jaxpr

PiperOrigin-RevId: 422836999
This commit is contained in:
jax authors 2022-01-19 09:46:28 -08:00
commit 98b5b3a2a9

View File

@ -723,6 +723,11 @@ def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
to obtain the full outputs once `jaxpr_unknown` is ran. Outputs known ahead of time will
simply get passed as residual constants and returned immediately.
"""
instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _partial_eval_jaxpr(jaxpr, tuple(unknowns), instantiate)
@cache()
def _partial_eval_jaxpr(jaxpr, unknowns, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
cell = []