mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow jax.custom_gradient
to return vjp with singleton return value
This commit is contained in:
parent
f4cf710f53
commit
6061b0979a
@ -22,7 +22,7 @@ from . import core
|
||||
from . import dtypes
|
||||
from . import linear_util as lu
|
||||
from .tree_util import (tree_flatten, tree_unflatten, tree_map, tree_multimap,
|
||||
register_pytree_node_class)
|
||||
treedef_is_leaf, register_pytree_node_class)
|
||||
from ._src.util import cache, safe_zip, safe_map, split_list
|
||||
from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
|
||||
from .core import raise_to_shaped
|
||||
@ -818,7 +818,10 @@ def custom_gradient(fun):
|
||||
cts_flat, out_tree_ = tree_flatten((cts,))
|
||||
if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
|
||||
cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
|
||||
return tree_unflatten(in_tree, cts_out)
|
||||
cts_out = tree_unflatten(in_tree, cts_out)
|
||||
if treedef_is_leaf(in_tree):
|
||||
cts_out = (cts_out,)
|
||||
return cts_out
|
||||
|
||||
wrapped_fun.defvjp(fwd, bwd)
|
||||
return wrapped_fun
|
||||
|
@ -4536,6 +4536,15 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.array([3., 4., 5.]),
|
||||
check_dtypes=False)
|
||||
|
||||
def test_custom_gradient_can_return_singleton_value_in_vjp(self):
|
||||
@api.custom_gradient
|
||||
def f(x):
|
||||
return x ** 2, lambda g: g * x
|
||||
|
||||
self.assertAllClose(f(3.), 9., check_dtypes=False)
|
||||
self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
|
||||
self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)
|
||||
|
||||
def test_closure_convert(self):
|
||||
def minimize(objective_fn, x0):
|
||||
converted_fn, aux_args = api.closure_convert(objective_fn, x0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user