Merge pull request #10506 from mattjj:remove-partial-eval-to-jaxpr-dynamic

PiperOrigin-RevId: 445532927
This commit is contained in:
jax authors 2022-04-29 16:38:17 -07:00
commit b90df4bf4d

View File

@ -1856,14 +1856,6 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
del fun, main
return jaxpr, out_avals, consts
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
# This function provides a partial evaluation behavior used by Flax. We can't
# use trace_to_jaxpr directly because of an interaction with the curent
# custom_derivatives.py, which we work around by adding the EvalTrace.
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
return trace_to_jaxpr(fun, in_pvals)
AbstractedAxisName = Hashable
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], Tuple[AbstractedAxisName, ...]]