add jax.custom_gradient wrapper for jax.custom_vjp
There was a deprecated version of this wrapper implemented in terms of jax.custom_transforms (which is itself deprecated and hopefully soon to be removed), but this commit adds an implementation in terms of jax.custom_vjp. One drawback relative to using jax.custom_vjp directly is that the backward-pass function cannot contain Python control flow.
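For reference, a minimal usage sketch of the wrapper added here (the log1pexp function below is only an illustration, not part of this commit; the key point, per the docstring in the diff, is that the returned VJP callable may close over forward-pass intermediates but may not contain Python control flow, since it is traced eagerly):

import jax
import jax.numpy as jnp

# The decorated function returns (primal_output, vjp_function).
@jax.custom_gradient
def log1pexp(x):
  e = jnp.exp(x)                  # intermediate shared with the backward pass via closure
  def vjp(g):                     # receives the output cotangent g
    return (g * e / (1. + e),)    # one cotangent per positional argument, as a tuple
  return jnp.log1p(e), vjp

print(log1pexp(3.))               # ~3.0486
print(jax.grad(log1pexp)(3.))     # ~0.9526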
This commit is contained in:
  parent 9ba28d2634
  commit 3c6cdcfc8f

jax/api.py (12 changed lines)
@@ -65,7 +65,7 @@ from .interpreters import batching
 from .interpreters import masking
 from .interpreters import invertible_ad as iad
 from .interpreters.invertible_ad import custom_ivjp
-from .custom_derivatives import custom_jvp, custom_vjp
+from .custom_derivatives import custom_jvp, custom_vjp, custom_gradient
 from .config import flags, config, bool_env

 AxisName = Any
@@ -2472,16 +2472,6 @@ def defvjp(fun, *vjprules):
     return ans, vjpfun
   defvjp_all(fun, custom_vjp)

-def custom_gradient(fun):
-  """This API is deprecated. See :py:func:`jax.custom_jvp` and :py:func:`jax.custom_vjp` instead."""
-
-  def primal_fun(*args, **kwargs):
-    ans, _ = fun(*args, **kwargs)
-    return ans
-  primal_fun = custom_transforms(primal_fun)
-  defvjp_all(primal_fun, fun)
-  return primal_fun
-
 def _ensure_tuple(x: Union[int, Iterable[int]]) -> Tuple[int, ...]:
   return (x,) if isinstance(x, int) else tuple(x)

@@ -20,7 +20,8 @@ from typing import Callable, Sequence, Tuple, Any

 from . import core
 from . import linear_util as lu
-from .tree_util import tree_flatten, tree_unflatten, tree_map, tree_multimap
+from .tree_util import (tree_flatten, tree_unflatten, tree_map, tree_multimap,
+                        register_pytree_node_class)
 from .util import safe_zip, safe_map, split_list
 from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
 from .abstract_arrays import raise_to_shaped
@@ -721,3 +722,106 @@ def omnistaging_disabler() -> None:
         *consts, *args, fun_jaxpr=closed_fun_jaxpr,
         fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,
         num_consts=len(consts))
+
+
+def custom_gradient(fun):
+  """Convenience function for defining custom VJP rules (aka custom gradients).
+
+  While the canonical way to define custom VJP rules is via ``jax.custom_vjp``,
+  the ``custom_gradient`` convenience wrapper follows TensorFlow's
+  ``tf.custom_gradient`` API. The difference here is that ``custom_gradient``
+  can be used as a decorator on one function that returns both the primal value
+  (representing the output of the mathematical function to be differentiated)
+  and the VJP (gradient) function. See
+  https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
+
+  If the mathematical function to be differentiated has type signature ``a ->
+  b``, then the Python callable ``fun`` should have signature
+  ``a -> (b, CT b --o CT a)`` where we use ``CT x`` to denote a cotangent type
+  for ``x`` and the ``--o`` arrow to denote a linear function. See the example
+  below. That is, ``fun`` should return a pair where the first element
+  represents the value of the mathematical function to be differentiated and the
+  second element is a function to be called on the backward pass of reverse-mode
+  automatic differentiation (i.e. the "custom gradient" function).
+
+  The function returned as the second element of the output of ``fun`` can close
+  over intermediate values computed when evaluating the function to be
+  differentiated. That is, use lexical closure to share work between the forward
+  pass and the backward pass of reverse-mode automatic differentiation. However,
+  it cannot support Python control flow.
+
+  Args:
+    fun: a Python callable specifying both the mathematical function to be
+      differentiated and its reverse-mode differentiation rule. It should return
+      a pair consisting of an output value and a Python callable that represents
+      the custom gradient function.
+
+  Returns:
+    A Python callable that accepts the same arguments as ``fun`` and returns the
+    output value specified by the first element of ``fun``'s output pair.
+
+  For example:
+
+  >>> @jax.custom_gradient
+  ... def f(x):
+  ...   return x ** 2, lambda g: (g * x,)
+  ...
+  >>> print(f(3.))
+  9.0
+  >>> print(jax.grad(f)(3.))
+  3.0
+
+  An example with a function on two arguments, so that the VJP function must
+  return a tuple of length two:
+
+  >>> @jax.custom_gradient
+  ... def f(x, y):
+  ...   return x * y, lambda g: (y, x)
+  ...
+  >>> print(f(3., 4.))
+  12.0
+  >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
+  (4.0, 3.0)
+  """
+  @custom_vjp
+  def wrapped_fun(*args, **kwargs):
+    ans, _ = fun(*args, **kwargs)
+    return ans
+
+  def fwd(*args, **kwargs):
+    ans, rule = fun(*args, **kwargs)
+    ans_flat, out_tree = tree_flatten((ans,))
+    rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
+    ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
+    if config.omnistaging_enabled:
+      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
+    else:
+      ans_pvals = [pe.PartialVal.unknown(a) for a in ans_avals]
+      jaxpr, _, consts = pe.trace_to_jaxpr(rule, ans_pvals, instantiate=True)
+    return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
+
+  def bwd(res, cts):
+    jaxpr, in_tree, out_tree, consts = res
+    cts_flat, out_tree_ = tree_flatten((cts,))
+    if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
+    cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
+    return tree_unflatten(in_tree, cts_out)
+
+  wrapped_fun.defvjp(fwd, bwd)
+  return wrapped_fun
+
+
+@register_pytree_node_class
+class Residuals:
+  def __init__(self, jaxpr, in_tree, out_tree, consts):
+    self.jaxpr = jaxpr
+    self.in_tree = in_tree
+    self.out_tree = out_tree
+    self.consts = consts
+  def __iter__(self):
+    return iter((self.jaxpr, self.in_tree, self.out_tree, self.consts))
+  def tree_flatten(self):
+    return self.consts, (self.jaxpr, self.in_tree, self.out_tree)
+  @classmethod
+  def tree_unflatten(cls, aux, consts):
+    jaxpr, in_tree, out_tree = aux
+    return cls(jaxpr, in_tree, out_tree, consts)
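For comparison, here is a hedged sketch of how the first docstring example above would be spelled against jax.custom_vjp directly, which is the API the wrapper builds on (the function names below are illustrative, not part of the commit):

import jax

@jax.custom_vjp
def square(x):
  return x ** 2

def square_fwd(x):
  return x ** 2, x            # primal output plus residuals saved for the backward pass

def square_bwd(residual_x, g):
  return (g * residual_x,)    # same custom cotangent rule as the docstring example

square.defvjp(square_fwd, square_bwd)

print(square(3.))             # 9.0
print(jax.grad(square)(3.))   # 3.0, matching the doctest above

The wrapper saves writing fwd and bwd separately by tracing the returned VJP callable to a jaxpr at forward time and replaying it in bwd, which is also why that callable cannot use Python control flow.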
@@ -4203,6 +4203,37 @@ class CustomVJPTest(jtu.JaxTestCase):
     expected = 2 * jnp.cos(3.)
     self.assertAllClose(ans, expected, check_dtypes=False)

+  def test_custom_gradient(self):
+    @api.custom_gradient
+    def f(x):
+      return x ** 2, lambda g: (g * x,)
+
+    self.assertAllClose(f(3.), 9., check_dtypes=False)
+    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
+    self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)
+
+  def test_custom_gradient_2(self):
+    @api.custom_gradient
+    def f(x, y):
+      return x * y, lambda g: (y, x)
+
+    self.assertAllClose(f(3., 4.), 12., check_dtypes=False)
+    self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.),
+                        check_dtypes=False)
+
+  def test_custom_gradient_3(self):
+    @api.custom_gradient
+    def f(x):
+      vjp = lambda g: (jnp.cos(x) * jnp.array([3., 4., 5.]),)
+      return jnp.sum(jnp.sin(x)), vjp
+
+    self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))),
+                        check_dtypes=False)
+    self.assertAllClose(
+        api.grad(f)(jnp.arange(3.)),
+        api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.array([3., 4., 5.]),
+        check_dtypes=False)
+

 class InvertibleADTest(jtu.JaxTestCase):

@@ -4473,14 +4504,6 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
     expected = 2.
     self.assertAllClose(ans, expected, check_dtypes=False)

-  def test_custom_gradient(self):
-    @api.custom_gradient
-    def f(x):
-      return x ** 2, lambda g: (g * x,)
-
-    self.assertAllClose(f(3.), 9., check_dtypes=False)
-    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
-
   def test_custom_vjp_zeros(self):
     @api.custom_transforms
     def f(x, y):
@@ -4517,6 +4540,7 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
     b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
     print(gf(a, b))  # doesn't crash
+


 class BufferDonationTest(jtu.JaxTestCase):

   @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.