mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[remove-units] remove now-dead flax helper function
This commit is contained in:
parent
d0799acd6c
commit
0bf3241e93
@ -1856,14 +1856,6 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
|
||||
del fun, main
|
||||
return jaxpr, out_avals, consts
|
||||
|
||||
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
|
||||
# This function provides a partial evaluation behavior used by Flax. We can't
|
||||
# use trace_to_jaxpr directly because of an interaction with the curent
|
||||
# custom_derivatives.py, which we work around by adding the EvalTrace.
|
||||
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
|
||||
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
|
||||
return trace_to_jaxpr(fun, in_pvals)
|
||||
|
||||
|
||||
AbstractedAxisName = Hashable
|
||||
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], Tuple[AbstractedAxisName, ...]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user