mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #7130 from google:custom-jvp-hotfix
PiperOrigin-RevId: 382665377
This commit is contained in:
commit
c97d63dec3
@ -242,8 +242,12 @@ def _flatten_jvp(in_tree, *args):
|
||||
msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
|
||||
"container (pytree) structures, but got {} and {} respectively.")
|
||||
raise TypeError(msg.format(out_tree, out_tree2))
|
||||
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
|
||||
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) for t in tangents_out]
|
||||
primal_avals_out = [
|
||||
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
|
||||
for x in primals_out]
|
||||
tangent_avals_out = [
|
||||
raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
|
||||
for t in tangents_out]
|
||||
if primal_avals_out != tangent_avals_out:
|
||||
if len(primal_avals_out) == 1:
|
||||
(av1,), (av2,) = primal_avals_out, tangent_avals_out
|
||||
|
Loading…
x
Reference in New Issue
Block a user