custom_jvp named shape hotfix

When we added "avals with names", we intended to start by making the distinction between types with and without named axes load-bearing only in specific parts of the system, while (continuing to) ignore it elsewhere. This fixes a spot I missed, and that a user ran into.

Most likely, we'll want to restore something like this typecheck after vmap and pmap use avals with names; for now, the typecheck won't always be satisfied in those contexts and needs to be loosened.
This commit is contained in:
James Bradbury 2021-06-28 19:47:50 -07:00 committed by GitHub
parent 0d68dbd619
commit e5d84522b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -242,8 +242,12 @@ def _flatten_jvp(in_tree, *args):
msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
"container (pytree) structures, but got {} and {} respectively.")
raise TypeError(msg.format(out_tree, out_tree2))
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) for t in tangents_out]
primal_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
for x in primals_out]
tangent_avals_out = [
raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
for t in tangents_out]
if primal_avals_out != tangent_avals_out:
if len(primal_avals_out) == 1:
(av1,), (av2,) = primal_avals_out, tangent_avals_out