mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #10506 from mattjj:remove-partial-eval-to-jaxpr-dynamic
PiperOrigin-RevId: 445532927
This commit is contained in:
commit
b90df4bf4d
@ -1856,14 +1856,6 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
|
||||
del fun, main
|
||||
return jaxpr, out_avals, consts
|
||||
|
||||
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
|
||||
# This function provides a partial evaluation behavior used by Flax. We can't
|
||||
# use trace_to_jaxpr directly because of an interaction with the curent
|
||||
# custom_derivatives.py, which we work around by adding the EvalTrace.
|
||||
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
|
||||
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
|
||||
return trace_to_jaxpr(fun, in_pvals)
|
||||
|
||||
|
||||
AbstractedAxisName = Hashable
|
||||
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], Tuple[AbstractedAxisName, ...]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user