mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #25875 from jax-ml:issue-25517
PiperOrigin-RevId: 715364096
This commit is contained in:
commit
29a3dded3d
@ -1838,10 +1838,12 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
|
||||
def fun(*tangents):
|
||||
tangent_avals = list(map(core.get_aval, tangents))
|
||||
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
|
||||
if not core.typecompat(primal_aval.to_tangent_aval(), tangent_aval):
|
||||
expected_tangent_aval = primal_aval.to_tangent_aval()
|
||||
if not core.typecompat(expected_tangent_aval, tangent_aval):
|
||||
raise ValueError("linearized function called on tangent values inconsistent with "
|
||||
"the original primal values: "
|
||||
f"got {tangent_aval} for primal aval {primal_aval}")
|
||||
f"got tangent aval {tangent_aval} for primal aval {primal_aval} "
|
||||
f"but expected {expected_tangent_aval}")
|
||||
tangents_out = eval_jaxpr(jaxpr, consts, *tangents)
|
||||
tangents_out_ = iter(tangents_out)
|
||||
full_out = [pval.get_known() if pval.is_known() else next(tangents_out_)
|
||||
|
Loading…
x
Reference in New Issue
Block a user