fix custom_jvp_call_jaxpr transpose function (#3231)

* make custom_jvp_call_jaxpr handle multilinear funs

see #3226

* remove old comment
This commit is contained in:
Matthew Johnson 2020-05-28 10:20:36 -07:00 committed by GitHub
parent c1ccbdf1a7
commit 572928dfa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 8 deletions

View File

@ -343,15 +343,11 @@ xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl)
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. We transpose it
# like a core.call.
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
avals):
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk):
del jvp_jaxpr_thunk
name = 'custom_jvp_call_jaxpr_linear'
avals = [core.get_aval(l) for l in fun_jaxpr.literals] + avals
return ad.call_transpose(core.call_p, dict(name=name), fun_jaxpr.jaxpr,
tuple(fun_jaxpr.literals) + args, cts, avals)
return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.literals, args, cts)
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose

View File

@ -2457,6 +2457,25 @@ class CustomJVPTest(jtu.JaxTestCase):
grad(experiment)(1.) # doesn't crash
def test_linear_in_scan(self):
@api.custom_jvp
def f(x):
return -x
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
return f(x), f(x_dot)
def foo(x):
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
return out
ans = api.grad(foo)(3.)
expected = -1.
self.assertAllClose(ans, expected, check_dtypes=False)
class CustomVJPTest(jtu.JaxTestCase):