fix defjvps with None as first rule

fixes #3389
This commit is contained in:
Matthew Johnson 2020-06-09 15:19:53 -07:00
parent 15bc62204e
commit 0011fd5e94
2 changed files with 13 additions and 2 deletions

View File

@ -75,7 +75,7 @@ def _initial_style_jaxpr(fun, in_avals):
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr
def sum_tangents(x, *xs):
def sum_tangents(_, x, *xs):
return reduce(ad.add_tangents, xs, x)
def zeros_like_pytree(x):
@ -196,7 +196,7 @@ class custom_jvp:
zeros = zeros_like_pytree(primal_out)
all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
for t, jvp in zip(tangents, jvps)]
tangent_out = tree_multimap(sum_tangents, *all_tangents_out)
tangent_out = tree_multimap(sum_tangents, primal_out, *all_tangents_out)
return primal_out, tangent_out
self.defjvp(jvp)

View File

@ -2476,6 +2476,17 @@ class CustomJVPTest(jtu.JaxTestCase):
expected = -1.
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_jvps_first_rule_is_none(self):
# https://github.com/google/jax/issues/3389
@api.custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot)
ans = grad(f, 1)(2., 3.) # doesn't crash
expected = 12.
self.assertAllClose(ans, expected, check_dtypes=False)
class CustomVJPTest(jtu.JaxTestCase):