turn lifted_jvp into a PyTree

This commit is contained in:
Clemens Giuliani 2021-06-24 23:54:57 +02:00
parent df3cc0d980
commit 3041c18250

View File

@ -1847,11 +1847,12 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
out_tree = out_tree()
out_primal_py = tree_unflatten(out_tree, out_primals)
primal_avals = list(map(core.get_aval, primals_flat))
lifted_jvp = partial(_lift_linearized, jaxpr, primal_avals, consts,
(in_tree, out_tree), out_pvals)
# Ensure that lifted_jvp is a PyTree
lifted_jvp = Partial(partial(_lift_linearized, jaxpr, primal_avals,
(in_tree, out_tree), out_pvals), consts)
return out_primal_py, lifted_jvp
def _lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pvals, *py_args):
def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
def fun(*tangents):
tangent_avals = list(map(core.get_aval, tangents))
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):