mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
64a7d17239
commit
0e49133e12
@ -252,6 +252,7 @@ def _flatten_jvp(in_tree, *args):
|
||||
yield primals_out + tangents_out, out_tree
|
||||
|
||||
def _custom_jvp_call_bind(prim, fun, jvp, *args):
|
||||
args = map(core.full_lower, args)
|
||||
top_trace = core.find_top_trace(args)
|
||||
level = (core.trace_state.trace_stack.next_level(True)
|
||||
if top_trace is None else top_trace.level)
|
||||
@ -490,6 +491,7 @@ def _flatten_bwd(in_tree, out_trees, *args):
|
||||
yield cts_in
|
||||
|
||||
def _custom_vjp_call_bind(prim, fun, fwd, bwd, *args, out_trees):
|
||||
args = map(core.full_lower, args)
|
||||
top_trace = core.find_top_trace(args)
|
||||
level = (core.trace_state.trace_stack.next_level(True)
|
||||
if top_trace is None else top_trace.level)
|
||||
|
@ -2781,6 +2781,19 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
expected = jax.grad(f, 0)(2., 0.1) + jax.grad(f, 0)(2., 0.2)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_lowering_out_of_traces(self):
|
||||
# https://github.com/google/jax/issues/2578
|
||||
|
||||
class F(collections.namedtuple("F", ["a"])):
|
||||
def __call__(self, x):
|
||||
return jax.nn.relu(self.a) * x
|
||||
|
||||
@jax.jit
|
||||
def g(f, x):
|
||||
return f(x)
|
||||
|
||||
jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash
|
||||
|
||||
|
||||
class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user