Merge pull request #4826 from apaszke:fix-invertible

PiperOrigin-RevId: 341309335
This commit is contained in:
jax authors 2020-11-08 14:31:40 -08:00
commit a7966f7d94
2 changed files with 14 additions and 5 deletions

View File

@ -66,14 +66,14 @@ def invertible(fun):
"gradients computed correctly (their uses inside this function will be ignored)!")
# TODO: This requires the body to be jittable, but this shouldn't be necessary.
# Is there a way to trace a jaxpr while running it?
outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
return tree_unflatten(out_tree(), outs), (args, outs, consts, DontFlatten((jaxpr, in_tree)))
flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts, DontFlatten((jaxpr, in_tree)))
def bwd(res, cts):
args, outs, consts, aux = res
flat_args, flat_outs, consts, aux = res
jaxpr, in_tree = aux.val
flat_cts, _ = tree_flatten(cts)
return tree_unflatten(in_tree, inv_backward_pass(jaxpr, consts, args, outs, flat_cts))
return tree_unflatten(in_tree, inv_backward_pass(jaxpr, consts, flat_args, flat_outs, flat_cts))
ifun.defvjp(fwd, bwd)
@ -192,7 +192,6 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
for primal in primals_out)
should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
assert not eqn.primitive.call_primitive
assert not (should_invert ^ should_vjp) # Either both true or both false
# Skip primals equations that are only jvp coefficients and don't affect
# primal outputs.

View File

@ -4356,6 +4356,16 @@ class InvertibleADTest(jtu.JaxTestCase):
jax.value_and_grad(lambda x: np.sum(finv(x, o)[0]))(o),
check_dtypes=True)
def test_invertible_pytree(self):
def f(x, y):
return jnp.exp(x[0]) * x[1] + y
finv = jax.invertible(f)
o = np.ones((5,))
self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f((x, x), x)[0]))(o),
jax.value_and_grad(lambda x: np.sum(finv((x, x), x)[0]))(o),
check_dtypes=True)
class DeprecatedCustomTransformsTest(jtu.JaxTestCase):