Merge pull request #7130 from google:custom-jvp-hotfix

PiperOrigin-RevId: 382665377
This commit is contained in:
jax authors 2021-07-01 21:23:17 -07:00
commit c97d63dec3

View File

@ -242,8 +242,12 @@ def _flatten_jvp(in_tree, *args):
msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
"container (pytree) structures, but got {} and {} respectively.")
raise TypeError(msg.format(out_tree, out_tree2))
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) for t in tangents_out]
primal_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
for x in primals_out]
tangent_avals_out = [
raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
for t in tangents_out]
if primal_avals_out != tangent_avals_out:
if len(primal_avals_out) == 1:
(av1,), (av2,) = primal_avals_out, tangent_avals_out