mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #9228 from mattjj:cache-partial-eval-jaxpr
PiperOrigin-RevId: 422836999
This commit is contained in:
commit
98b5b3a2a9
@ -723,6 +723,11 @@ def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
|
||||
to obtain the full outputs once `jaxpr_unknown` is ran. Outputs known ahead of time will
|
||||
simply get passed as residual constants and returned immediately.
|
||||
"""
|
||||
instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
|
||||
return _partial_eval_jaxpr(jaxpr, tuple(unknowns), instantiate)
|
||||
|
||||
@cache()
|
||||
def _partial_eval_jaxpr(jaxpr, unknowns, instantiate):
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
|
||||
cell = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user