mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix custom_jvp_call_jaxpr transpose function (#3231)
* make custom_jvp_call_jaxpr handle multilinear funs see #3226 * remove old comment
This commit is contained in:
parent
c1ccbdf1a7
commit
572928dfa3
@ -343,15 +343,11 @@ xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
|
||||
xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl)
|
||||
|
||||
# If a (multi)linear function is defined with a custom jvp, then
|
||||
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. We transpose it
|
||||
# like a core.call.
|
||||
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
|
||||
avals):
|
||||
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
|
||||
# already been linearized, we can drop the jvp rule.
|
||||
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk):
|
||||
del jvp_jaxpr_thunk
|
||||
name = 'custom_jvp_call_jaxpr_linear'
|
||||
avals = [core.get_aval(l) for l in fun_jaxpr.literals] + avals
|
||||
return ad.call_transpose(core.call_p, dict(name=name), fun_jaxpr.jaxpr,
|
||||
tuple(fun_jaxpr.literals) + args, cts, avals)
|
||||
return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.literals, args, cts)
|
||||
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
|
||||
|
||||
|
||||
|
@ -2457,6 +2457,25 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
||||
grad(experiment)(1.) # doesn't crash
|
||||
|
||||
def test_linear_in_scan(self):
|
||||
@api.custom_jvp
|
||||
def f(x):
|
||||
return -x
|
||||
|
||||
@f.defjvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, = primals
|
||||
x_dot, = tangents
|
||||
return f(x), f(x_dot)
|
||||
|
||||
def foo(x):
|
||||
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
||||
return out
|
||||
|
||||
ans = api.grad(foo)(3.)
|
||||
expected = -1.
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user