add jax.custom_gradient wrapper for jax.custom_vjp

There was a deprecated version of this wrapper implemented in terms of
jax.custom_transforms (which is itself deprecated, and hopefully soon to
be removed), but this commit adds an implementation in terms of
jax.custom_vjp. One drawback relative to using jax.custom_vjp directly is
that it doesn't support Python control flow in the backward-pass function.
Matthew Johnson 2020-10-23 13:54:23 -07:00
parent 9ba28d2634
commit 3c6cdcfc8f
3 changed files with 138 additions and 20 deletions
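
A minimal usage sketch of the new wrapper, shown next to the jax.custom_vjp
spelling it builds on. The function log1pexp and its rule are illustrative
only (not part of this commit), and the snippet assumes the public jax and
jax.numpy APIs as of this change.

import jax
import jax.numpy as jnp

# New wrapper added by this commit: a single function returns both the
# primal output and the VJP (gradient) function, which closes over x.
@jax.custom_gradient
def log1pexp(x):
  y = jnp.log1p(jnp.exp(x))
  # Backward-pass rule: takes the output cotangent g and returns a tuple
  # of input cotangents, one per positional argument.
  return y, lambda g: (g * (1 - 1 / (1 + jnp.exp(x))),)

# The same rule spelled directly with jax.custom_vjp: separate forward and
# backward functions with explicitly threaded residuals.
@jax.custom_vjp
def log1pexp_ref(x):
  return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
  return log1pexp_ref(x), x  # save the input x as the residual

def log1pexp_bwd(x, g):
  return (g * (1 - 1 / (1 + jnp.exp(x))),)

log1pexp_ref.defvjp(log1pexp_fwd, log1pexp_bwd)

print(jax.grad(log1pexp)(3.0))      # uses the custom rule
print(jax.grad(log1pexp_ref)(3.0))  # same value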


@@ -65,7 +65,7 @@ from .interpreters import batching
from .interpreters import masking
from .interpreters import invertible_ad as iad
from .interpreters.invertible_ad import custom_ivjp
from .custom_derivatives import custom_jvp, custom_vjp
from .custom_derivatives import custom_jvp, custom_vjp, custom_gradient
from .config import flags, config, bool_env
AxisName = Any
@@ -2472,16 +2472,6 @@ def defvjp(fun, *vjprules):
return ans, vjpfun
defvjp_all(fun, custom_vjp)
def custom_gradient(fun):
"""This API is deprecated. See :py:func:`jax.custom_jvp` and :py:func:`jax.custom_vjp` instead."""
def primal_fun(*args, **kwargs):
ans, _ = fun(*args, **kwargs)
return ans
primal_fun = custom_transforms(primal_fun)
defvjp_all(primal_fun, fun)
return primal_fun
def _ensure_tuple(x: Union[int, Iterable[int]]) -> Tuple[int, ...]:
return (x,) if isinstance(x, int) else tuple(x)


@@ -20,7 +20,8 @@ from typing import Callable, Sequence, Tuple, Any
from . import core
from . import linear_util as lu
from .tree_util import tree_flatten, tree_unflatten, tree_map, tree_multimap
from .tree_util import (tree_flatten, tree_unflatten, tree_map, tree_multimap,
register_pytree_node_class)
from .util import safe_zip, safe_map, split_list
from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
from .abstract_arrays import raise_to_shaped
@@ -721,3 +722,106 @@ def omnistaging_disabler() -> None:
*consts, *args, fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,
num_consts=len(consts))
def custom_gradient(fun):
"""Convenience function for defining custom VJP rules (aka custom gradients).
While the canonical way to define custom VJP rules is via ``jax.custom_vjp``,
the ``custom_gradient`` convenience wrapper follows TensorFlow's
``tf.custom_gradient`` API. The difference from ``jax.custom_vjp`` is that ``custom_gradient``
can be used as a decorator on one function that returns both the primal value
(representing the output of the mathematical function to be differentiated)
and the VJP (gradient) function. See
https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
If the mathematical function to be differentiated has type signature ``a ->
b``, then the Python callable ``fun`` should have signature
``a -> (b, CT b --o CT a)`` where we use ``CT x`` to denote a cotangent type
for ``x`` and the ``--o`` arrow to denote a linear function. See the example
below. That is, ``fun`` should return a pair where the first element
represents the value of the mathematical function to be differentiated and the
second element is a function to be called on the backward pass of reverse-mode
automatic differentiation (i.e. the "custom gradient" function).
The function returned as the second element of the output of ``fun`` can close
over intermediate values computed when evaluating the function to be
differentiated. That is, use lexical closure to share work between the forward
pass and the backward pass of reverse-mode automatic differentiation. However,
the backward-pass function itself cannot contain Python control flow.
Args:
fun: a Python callable specifying both the mathematical function to be
differentiated and its reverse-mode differentiation rule. It should return
a pair consisting of an output value and a Python callable that represents
the custom gradient function.
Returns:
A Python callable that accepts the same arguments as ``fun`` and returns the
output value specified by the first element of ``fun``'s output pair.
For example:
>>> @jax.custom_gradient
... def f(x):
... return x ** 2, lambda g: (g * x,)
...
>>> print(f(3.))
9.0
>>> print(jax.grad(f)(3.))
3.0
An example with a function on two arguments, so that the VJP function must
return a tuple of length two:
>>> @jax.custom_gradient
... def f(x, y):
... return x * y, lambda g: (g * y, g * x)
...
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(4.0, 3.0)
"""
@custom_vjp
def wrapped_fun(*args, **kwargs):
ans, _ = fun(*args, **kwargs)
return ans
def fwd(*args, **kwargs):
ans, rule = fun(*args, **kwargs)
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
else:
ans_pvals = [pe.PartialVal.unknown(a) for a in ans_avals]
jaxpr, _, consts = pe.trace_to_jaxpr(rule, ans_pvals, instantiate=True)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
def bwd(res, cts):
jaxpr, in_tree, out_tree, consts = res
cts_flat, out_tree_ = tree_flatten((cts,))
if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
return tree_unflatten(in_tree, cts_out)
wrapped_fun.defvjp(fwd, bwd)
return wrapped_fun
@register_pytree_node_class
class Residuals:
def __init__(self, jaxpr, in_tree, out_tree, consts):
self.jaxpr = jaxpr
self.in_tree = in_tree
self.out_tree = out_tree
self.consts = consts
def __iter__(self):
return iter((self.jaxpr, self.in_tree, self.out_tree, self.consts))
def tree_flatten(self):
return self.consts, (self.jaxpr, self.in_tree, self.out_tree)
@classmethod
def tree_unflatten(cls, aux, consts):
jaxpr, in_tree, out_tree = aux
return cls(jaxpr, in_tree, out_tree, consts)
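
A note on the implementation above: fwd traces the user's backward-pass
callable to a jaxpr at forward time, and bwd only replays that jaxpr with
core.eval_jaxpr, which is why Python control flow in the backward-pass
function is not supported (as noted in the commit message). Residuals is
registered as a pytree so that only the traced rule's constants (the
closed-over values) are treated as leaves, while the jaxpr and tree
structures ride along as static auxiliary data. A stand-alone sketch of
that flatten/unflatten behavior, using placeholder strings in place of a
real jaxpr and tree definitions:

from jax.tree_util import (register_pytree_node_class, tree_flatten,
                           tree_unflatten)

@register_pytree_node_class
class ResidualsSketch:
  """Illustrative stand-in for the Residuals class in the diff."""
  def __init__(self, jaxpr, in_tree, out_tree, consts):
    self.jaxpr = jaxpr
    self.in_tree = in_tree
    self.out_tree = out_tree
    self.consts = consts

  def tree_flatten(self):
    # Only consts are dynamic children; the rest is static aux data kept
    # in the treedef.
    return self.consts, (self.jaxpr, self.in_tree, self.out_tree)

  @classmethod
  def tree_unflatten(cls, aux, consts):
    jaxpr, in_tree, out_tree = aux
    return cls(jaxpr, in_tree, out_tree, consts)

res = ResidualsSketch("jaxpr", "in_tree", "out_tree", [1.0, 2.0])
leaves, treedef = tree_flatten(res)
print(leaves)                          # [1.0, 2.0] -- only the consts
res2 = tree_unflatten(treedef, leaves)
print(res2.jaxpr, list(res2.consts))   # jaxpr [1.0, 2.0]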


@@ -4203,6 +4203,37 @@ class CustomVJPTest(jtu.JaxTestCase):
expected = 2 * jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_gradient(self):
@api.custom_gradient
def f(x):
return x ** 2, lambda g: (g * x,)
self.assertAllClose(f(3.), 9., check_dtypes=False)
self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)
def test_custom_gradient_2(self):
@api.custom_gradient
def f(x, y):
return x * y, lambda g: (g * y, g * x)
self.assertAllClose(f(3., 4.), 12., check_dtypes=False)
self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.),
check_dtypes=False)
def test_custom_gradient_3(self):
@api.custom_gradient
def f(x):
vjp = lambda g: (jnp.cos(x) * jnp.array([3., 4., 5.]),)
return jnp.sum(jnp.sin(x)), vjp
self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))),
check_dtypes=False)
self.assertAllClose(
api.grad(f)(jnp.arange(3.)),
api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.array([3., 4., 5.]),
check_dtypes=False)
class InvertibleADTest(jtu.JaxTestCase):
@@ -4473,14 +4504,6 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
expected = 2.
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_gradient(self):
@api.custom_gradient
def f(x):
return x ** 2, lambda g: (g * x,)
self.assertAllClose(f(3.), 9., check_dtypes=False)
self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
def test_custom_vjp_zeros(self):
@api.custom_transforms
def f(x, y):
@@ -4517,6 +4540,7 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
print(gf(a, b)) # doesn't crash
class BufferDonationTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.