Allow jax.custom_gradient to return vjp with singleton return value

This commit is contained in:
Sharad Vikram 2021-01-26 12:39:35 -08:00
parent f4cf710f53
commit 6061b0979a
2 changed files with 14 additions and 2 deletions

View File

@ -22,7 +22,7 @@ from . import core
from . import dtypes
from . import linear_util as lu
from .tree_util import (tree_flatten, tree_unflatten, tree_map, tree_multimap,
register_pytree_node_class)
treedef_is_leaf, register_pytree_node_class)
from ._src.util import cache, safe_zip, safe_map, split_list
from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
from .core import raise_to_shaped
@ -818,7 +818,10 @@ def custom_gradient(fun):
cts_flat, out_tree_ = tree_flatten((cts,))
if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
return tree_unflatten(in_tree, cts_out)
cts_out = tree_unflatten(in_tree, cts_out)
if treedef_is_leaf(in_tree):
cts_out = (cts_out,)
return cts_out
wrapped_fun.defvjp(fwd, bwd)
return wrapped_fun

View File

@ -4536,6 +4536,15 @@ class CustomVJPTest(jtu.JaxTestCase):
api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.array([3., 4., 5.]),
check_dtypes=False)
def test_custom_gradient_can_return_singleton_value_in_vjp(self):
@api.custom_gradient
def f(x):
return x ** 2, lambda g: g * x
self.assertAllClose(f(3.), 9., check_dtypes=False)
self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)
def test_closure_convert(self):
def minimize(objective_fn, x0):
converted_fn, aux_args = api.closure_convert(objective_fn, x0)