mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
custom_jvp named shape hotfix
When we added "avals with names", we intended to start by making the distinction between types with and without named axes load-bearing only in specific parts of the system, while (continuing to) ignore it elsewhere. This fixes a spot I missed, and that a user ran into. Most likely, we'll want to restore something like this typecheck after vmap and pmap use avals with names; for now, the typecheck won't always be satisfied in those contexts and needs to be loosened.
This commit is contained in:
parent
0d68dbd619
commit
e5d84522b7
@ -242,8 +242,12 @@ def _flatten_jvp(in_tree, *args):
|
||||
msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
|
||||
"container (pytree) structures, but got {} and {} respectively.")
|
||||
raise TypeError(msg.format(out_tree, out_tree2))
|
||||
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
|
||||
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) for t in tangents_out]
|
||||
primal_avals_out = [
|
||||
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
|
||||
for x in primals_out]
|
||||
tangent_avals_out = [
|
||||
raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
|
||||
for t in tangents_out]
|
||||
if primal_avals_out != tangent_avals_out:
|
||||
if len(primal_avals_out) == 1:
|
||||
(av1,), (av2,) = primal_avals_out, tangent_avals_out
|
||||
|
Loading…
x
Reference in New Issue
Block a user