add full lower to custom_jvp/vjp call bind

fixes #2578
This commit is contained in:
Matthew Johnson 2020-04-02 22:52:07 -07:00
parent 64a7d17239
commit 0e49133e12
2 changed files with 15 additions and 0 deletions

View File

@ -252,6 +252,7 @@ def _flatten_jvp(in_tree, *args):
yield primals_out + tangents_out, out_tree
def _custom_jvp_call_bind(prim, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
level = (core.trace_state.trace_stack.next_level(True)
if top_trace is None else top_trace.level)
@ -490,6 +491,7 @@ def _flatten_bwd(in_tree, out_trees, *args):
yield cts_in
def _custom_vjp_call_bind(prim, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
level = (core.trace_state.trace_stack.next_level(True)
if top_trace is None else top_trace.level)

View File

@ -2781,6 +2781,19 @@ class CustomVJPTest(jtu.JaxTestCase):
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_lowering_out_of_traces(self):
# https://github.com/google/jax/issues/2578
class F(collections.namedtuple("F", ["a"])):
def __call__(self, x):
return jax.nn.relu(self.a) * x
@jax.jit
def g(f, x):
return f(x)
jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash
class DeprecatedCustomTransformsTest(jtu.JaxTestCase):