From 0bf3241e93acf734afc0ba0875f83ba505984057 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 29 Apr 2022 15:50:25 -0700 Subject: [PATCH] [remove-units] remove now-dead flax helper function --- jax/interpreters/partial_eval.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index c0e4a2f49..30730378a 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1856,14 +1856,6 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun, del fun, main return jaxpr, out_avals, consts -def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]): - # This function provides a partial evaluation behavior used by Flax. We can't - # use trace_to_jaxpr directly because of an interaction with the curent - # custom_derivatives.py, which we work around by adding the EvalTrace. - # TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py - with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore - return trace_to_jaxpr(fun, in_pvals) - AbstractedAxisName = Hashable AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], Tuple[AbstractedAxisName, ...]]