Merge pull request #4826 from apaszke:fix-invertible
PiperOrigin-RevId: 341309335
This commit is contained in:
commit a7966f7d94
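The diff below teaches jax.invertible to handle pytree arguments and outputs: the forward rule now stores the flattened arguments and outputs as residuals, and the backward rule unpacks those flat residuals, hands them together with the flattened cotangents to inv_backward_pass, and rebuilds the result with the saved input tree. A minimal usage sketch, mirroring the test added at the end of this diff (values are illustrative; it assumes the jax.invertible entry point that the new test exercises):

import jax
import jax.numpy as jnp
import numpy as np

def f(x, y):
  # x is a pytree (a tuple of arrays), the case this change fixes.
  return jnp.exp(x[0]) * x[1] + y

finv = jax.invertible(f)   # gradients are computed by inverting f
o = np.ones((5,))

# The invertible version should agree with ordinary autodiff.
ref = jax.value_and_grad(lambda x: jnp.sum(f((x, x), x)[0]))(o)
inv = jax.value_and_grad(lambda x: jnp.sum(finv((x, x), x)[0]))(o)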
@@ -66,14 +66,14 @@ def invertible(fun):
                   "gradients computed correctly (their uses inside this function will be ignored)!")
     # TODO: This requires the body to be jittable, but this shouldn't be necessary.
     # Is there a way to trace a jaxpr while running it?
-    outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
-    return tree_unflatten(out_tree(), outs), (args, outs, consts, DontFlatten((jaxpr, in_tree)))
+    flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
+    return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts, DontFlatten((jaxpr, in_tree)))
 
   def bwd(res, cts):
-    args, outs, consts, aux = res
+    flat_args, flat_outs, consts, aux = res
     jaxpr, in_tree = aux.val
     flat_cts, _ = tree_flatten(cts)
-    return tree_unflatten(in_tree, inv_backward_pass(jaxpr, consts, args, outs, flat_cts))
+    return tree_unflatten(in_tree, inv_backward_pass(jaxpr, consts, flat_args, flat_outs, flat_cts))
 
   ifun.defvjp(fwd, bwd)
 
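In other words, everything bwd passes to inv_backward_pass is now a flat list of leaves, and the flat list it returns is rebuilt into the caller's pytree structure with in_tree. A minimal sketch of that flatten/unflatten round trip (the names params and flat_grads are illustrative, not from the JAX source):

import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat_params, treedef = tree_flatten(params)      # flat list of leaves + structure

# A backward pass that only understands flat lists produces one cotangent per
# leaf, in the same order as flat_params ...
flat_grads = [2.0 * leaf for leaf in flat_params]

# ... and the caller restores the original structure, as bwd does with in_tree.
grads = tree_unflatten(treedef, flat_grads)      # same dict layout as params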
@@ -192,7 +192,6 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
                          for primal in primals_out)
     should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
-    assert not eqn.primitive.call_primitive
     assert not (should_invert ^ should_vjp)  # Either both true or both false
 
     # Skip primals equations that are only jvp coefficients and don't affect
     # primal outputs.
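The surviving assertion uses XOR as an agreement check: on booleans, a ^ b is true exactly when the two values differ, so assert not (should_invert ^ should_vjp) insists the two flags are either both set or both clear. A tiny illustration with hypothetical flag values:

# Hypothetical values, only to show the invariant the assertion encodes.
for should_invert, should_vjp in [(True, True), (False, False)]:
  assert not (should_invert ^ should_vjp)   # flags agree, so this passes

# A mixed pair such as (True, False) makes the XOR true and would trip the assert.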
@@ -4356,6 +4356,16 @@ class InvertibleADTest(jtu.JaxTestCase):
                         jax.value_and_grad(lambda x: np.sum(finv(x, o)[0]))(o),
                         check_dtypes=True)
 
+  def test_invertible_pytree(self):
+    def f(x, y):
+      return jnp.exp(x[0]) * x[1] + y
+
+    finv = jax.invertible(f)
+    o = np.ones((5,))
+    self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f((x, x), x)[0]))(o),
+                        jax.value_and_grad(lambda x: np.sum(finv((x, x), x)[0]))(o),
+                        check_dtypes=True)
+
 
 
 class DeprecatedCustomTransformsTest(jtu.JaxTestCase):