mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
15bc62204e
commit
0011fd5e94
@ -75,7 +75,7 @@ def _initial_style_jaxpr(fun, in_avals):
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
return typed_jaxpr
|
||||
|
||||
def sum_tangents(x, *xs):
|
||||
def sum_tangents(_, x, *xs):
|
||||
return reduce(ad.add_tangents, xs, x)
|
||||
|
||||
def zeros_like_pytree(x):
|
||||
@ -196,7 +196,7 @@ class custom_jvp:
|
||||
zeros = zeros_like_pytree(primal_out)
|
||||
all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
|
||||
for t, jvp in zip(tangents, jvps)]
|
||||
tangent_out = tree_multimap(sum_tangents, *all_tangents_out)
|
||||
tangent_out = tree_multimap(sum_tangents, primal_out, *all_tangents_out)
|
||||
return primal_out, tangent_out
|
||||
|
||||
self.defjvp(jvp)
|
||||
|
@ -2476,6 +2476,17 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
expected = -1.
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_custom_jvps_first_rule_is_none(self):
|
||||
# https://github.com/google/jax/issues/3389
|
||||
@api.custom_jvp
|
||||
def f(x, y):
|
||||
return x ** 2 * y
|
||||
|
||||
f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot)
|
||||
ans = grad(f, 1)(2., 3.) # doesn't crash
|
||||
expected = 12.
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user