mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
412b9d5209
commit
503e5973ce
12
jax/api.py
12
jax/api.py
@ -45,7 +45,7 @@ from .api_util import (wraps, flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
|
||||
donation_vector, rebase_donate_argnums)
|
||||
from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
|
||||
tree_transpose, tree_leaves, tree_multimap,
|
||||
treedef_is_leaf)
|
||||
treedef_is_leaf, Partial)
|
||||
from .util import (unzip2, curry, partial, safe_map, safe_zip, prod,
|
||||
split_list, extend_name_stack, wrap_name)
|
||||
from .lib import xla_bridge as xb
|
||||
@ -1488,7 +1488,7 @@ def _check_inexact_input_vjp(x):
|
||||
"or complex type, got type {}")
|
||||
raise TypeError(msg.format(aval.dtype.name))
|
||||
|
||||
def _vjp_pullback_wrapper(fun, cotangent_dtypes, io_tree, py_args):
|
||||
def _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args):
|
||||
in_tree_expected, out_tree = io_tree
|
||||
args, in_tree = tree_flatten(py_args)
|
||||
if in_tree != in_tree_expected:
|
||||
@ -1563,8 +1563,12 @@ def _vjp(fun: lu.WrappedFun, *primals, **kwargs):
|
||||
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
|
||||
out_tree, aux_tree = out_aux_trees()
|
||||
out_primal_py = tree_unflatten(out_tree, out_primal)
|
||||
vjp_py = partial(_vjp_pullback_wrapper, out_vjp,
|
||||
[_dtype(x) for x in out_primal], (out_tree, in_tree))
|
||||
# Ensure that vjp_py is a PyTree so that we can pass it from the forward to the
|
||||
# backward pass in a custom VJP.
|
||||
vjp_py = Partial(partial(_vjp_pullback_wrapper,
|
||||
[_dtype(x) for x in out_primal],
|
||||
(out_tree, in_tree)),
|
||||
out_vjp)
|
||||
if not has_aux:
|
||||
return out_primal_py, vjp_py
|
||||
else:
|
||||
|
@ -27,7 +27,7 @@ from ..util import unzip2, safe_map, safe_zip, partial, split_list, wrap_name
|
||||
from ..tree_util import register_pytree_node
|
||||
from .. import linear_util as lu
|
||||
from ..api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from ..tree_util import tree_flatten, tree_unflatten
|
||||
from ..tree_util import tree_flatten, tree_unflatten, Partial
|
||||
from .. import source_info_util
|
||||
|
||||
zip = safe_zip
|
||||
@ -109,12 +109,16 @@ def vjp(traceable, primals, has_aux=False):
|
||||
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
|
||||
else:
|
||||
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
|
||||
def vjp_(*cts):
|
||||
|
||||
def unbound_vjp(pvals, jaxpr, consts, *cts):
|
||||
cts = tuple(map(ignore_consts, cts, pvals))
|
||||
dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
|
||||
arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
|
||||
return map(instantiate_zeros, arg_cts)
|
||||
|
||||
# Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
|
||||
# pass in a custom VJP.
|
||||
vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts)
|
||||
if not has_aux:
|
||||
return out_primals, vjp_
|
||||
else:
|
||||
|
@ -2863,6 +2863,38 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
jax.grad(clip_gradient)(1.) # doesn't crash
|
||||
|
||||
def test_nestable_vjp(self):
|
||||
# Verify that https://github.com/google/jax/issues/3667 is resolved.
|
||||
def f(x):
|
||||
return x ** 2
|
||||
|
||||
@api.custom_vjp
|
||||
def g(x):
|
||||
return f(x)
|
||||
|
||||
def g_fwd(x):
|
||||
y, f_vjp = api.vjp(f, x)
|
||||
return y, f_vjp
|
||||
|
||||
def g_bwd(f_vjp, y_bar):
|
||||
return f_vjp(y_bar)
|
||||
|
||||
g.defvjp(g_fwd, g_bwd)
|
||||
|
||||
# Check that VJP can be nested in simple situations. For this to pass,
|
||||
# vjp has to return a PyTree.
|
||||
_, g_vjp = api.vjp(g, 1.0)
|
||||
y, = g_vjp(1.0)
|
||||
self.assertAllClose(y, jnp.array(2.0))
|
||||
|
||||
# Check that VJP can be nested in complex situations. For this to pass,
|
||||
# vjp can't treat the closed-over tracer x as a static argument.
|
||||
@jit
|
||||
def z(x):
|
||||
_, g_vjp = api.vjp(g, x)
|
||||
return g_vjp
|
||||
y, = z(1.0)(3.0)
|
||||
self.assertAllClose(y, jnp.array(6.0))
|
||||
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user