mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
turn lifted_jvp into a PyTree
This commit is contained in:
parent
df3cc0d980
commit
3041c18250
@ -1847,11 +1847,12 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
|
||||
out_tree = out_tree()
|
||||
out_primal_py = tree_unflatten(out_tree, out_primals)
|
||||
primal_avals = list(map(core.get_aval, primals_flat))
|
||||
lifted_jvp = partial(_lift_linearized, jaxpr, primal_avals, consts,
|
||||
(in_tree, out_tree), out_pvals)
|
||||
# Ensure that lifted_jvp is a PyTree
|
||||
lifted_jvp = Partial(partial(_lift_linearized, jaxpr, primal_avals,
|
||||
(in_tree, out_tree), out_pvals), consts)
|
||||
return out_primal_py, lifted_jvp
|
||||
|
||||
def _lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pvals, *py_args):
|
||||
def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
|
||||
def fun(*tangents):
|
||||
tangent_avals = list(map(core.get_aval, tangents))
|
||||
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
|
||||
|
Loading…
x
Reference in New Issue
Block a user