diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 1d0769e64..20d330571 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -242,8 +242,12 @@ def _flatten_jvp(in_tree, *args): msg = ("Custom JVP rule must produce primal and tangent outputs with equal " "container (pytree) structures, but got {} and {} respectively.") raise TypeError(msg.format(out_tree, out_tree2)) - primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] - tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) for t in tangents_out] + primal_avals_out = [ + raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape() + for x in primals_out] + tangent_avals_out = [ + raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape() + for t in tangents_out] if primal_avals_out != tangent_avals_out: if len(primal_avals_out) == 1: (av1,), (av2,) = primal_avals_out, tangent_avals_out