Merge pull request #25875 from jax-ml:issue-25517

PiperOrigin-RevId: 715364096
This commit is contained in:
jax authors 2025-01-14 06:55:05 -08:00
commit 29a3dded3d

View File

@ -1838,10 +1838,12 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
def fun(*tangents):
tangent_avals = list(map(core.get_aval, tangents))
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
if not core.typecompat(primal_aval.to_tangent_aval(), tangent_aval):
expected_tangent_aval = primal_aval.to_tangent_aval()
if not core.typecompat(expected_tangent_aval, tangent_aval):
raise ValueError("linearized function called on tangent values inconsistent with "
"the original primal values: "
f"got {tangent_aval} for primal aval {primal_aval}")
f"got tangent aval {tangent_aval} for primal aval {primal_aval} "
f"but expected {expected_tangent_aval}")
tangents_out = eval_jaxpr(jaxpr, consts, *tangents)
tangents_out_ = iter(tangents_out)
full_out = [pval.get_known() if pval.is_known() else next(tangents_out_)