mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
test closure conversion, following docstring example
This commit is contained in:
parent
0ad5d2f0c2
commit
77d7339bb3
@ -4333,6 +4333,42 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.array([3., 4., 5.]),
|
||||
check_dtypes=False)
|
||||
|
||||
def test_closure_convert(self):
|
||||
def minimize(objective_fn, x0):
|
||||
converted_fn, aux_args = api.closure_convert(objective_fn, x0)
|
||||
return _minimize(converted_fn, x0, *aux_args)
|
||||
|
||||
@partial(api.custom_vjp, nondiff_argnums=(0,))
|
||||
def _minimize(objective_fn, x0, *args):
|
||||
_ = objective_fn(x0, *args)
|
||||
return jnp.cos(x0)
|
||||
|
||||
def fwd(objective_fn, x0, *args):
|
||||
y = _minimize(objective_fn, x0, *args)
|
||||
return y, (y, args)
|
||||
|
||||
def rev(objective_fn, res, g):
|
||||
y, args = res
|
||||
x0_bar = 17. * y
|
||||
args_bars = [42. * a for a in args]
|
||||
return (x0_bar, *args_bars)
|
||||
|
||||
_minimize.defvjp(fwd, rev)
|
||||
|
||||
def obj(c, x):
|
||||
return jnp.sum((x - c) ** 2.)
|
||||
|
||||
def solve(c, x):
|
||||
def closure(x):
|
||||
return obj(c, x)
|
||||
return jnp.sum(minimize(closure, x))
|
||||
|
||||
c, x = jnp.ones(2), jnp.zeros(2)
|
||||
self.assertAllClose(solve(c, x), 2.0, check_dtypes=False)
|
||||
g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x)
|
||||
self.assertAllClose(g_c, 42. * jnp.ones(2), check_dtypes=False)
|
||||
self.assertAllClose(g_x, 17. * jnp.ones(2), check_dtypes=False)
|
||||
|
||||
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user