mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
generalize jacfwd and jacrev to handle pytrees
This commit is contained in:
parent
ad4322c5da
commit
0f7c7c4eab
@ -157,9 +157,14 @@ def make_shaped_array(x):
|
||||
dtype = xla_bridge.canonicalize_dtype(onp.result_type(x))
|
||||
return ShapedArray(onp.shape(x), dtype)
|
||||
|
||||
def zeros_like_array(x):
  """Build an all-zeros array with the same shape as `x`.

  The element type is `onp.result_type(x)` after it has been passed through
  `xla_bridge.canonicalize_dtype`. The result is produced by broadcasting a
  single zero scalar to `onp.shape(x)` with `onp.broadcast_to`.
  """
  zero_dtype = xla_bridge.canonicalize_dtype(onp.result_type(x))
  scalar_zero = onp.array(0, zero_dtype)
  return onp.broadcast_to(scalar_zero, onp.shape(x))
|
||||
|
||||
array_types = [onp.ndarray, onp.float64, onp.float32, onp.complex64,
|
||||
onp.int64, onp.int32, onp.bool_, onp.uint64, onp.uint32,
|
||||
complex, float, int, bool]
|
||||
|
||||
for t in array_types:
|
||||
core.pytype_aval_mappings[t] = ConcreteArray
|
||||
ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
||||
|
111
jax/api.py
111
jax/api.py
@ -25,6 +25,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import operator as op
|
||||
|
||||
import numpy as onp
|
||||
|
||||
@ -33,10 +34,12 @@ from . import linear_util as lu
|
||||
from .core import pack, eval_jaxpr
|
||||
from .api_util import (pytree_fun_to_jaxtupletree_fun, apply_jaxtree_fun,
|
||||
pytree_to_jaxtupletree, wraps)
|
||||
from .flatten_util import ravel_fun, ravel_pytree
|
||||
from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef, leaf,
|
||||
tree_map)
|
||||
from .util import unzip2, unzip3, curry, partial, safe_map, WrapHashably
|
||||
from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
|
||||
tree_map, tree_flatten, tree_unflatten, tree_structure,
|
||||
tree_transpose)
|
||||
from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
|
||||
WrapHashably, prod)
|
||||
from .lib.xla_bridge import canonicalize_dtype
|
||||
from .abstract_arrays import ShapedArray
|
||||
from .interpreters import partial_eval as pe
|
||||
from .interpreters import xla
|
||||
@ -44,6 +47,7 @@ from .interpreters import ad
|
||||
from .interpreters import batching
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
|
||||
def jit(fun, static_argnums=()):
|
||||
@ -98,6 +102,12 @@ def grad(fun, argnums=0):
|
||||
"""
|
||||
value_and_grad_f = value_and_grad(fun, argnums)
|
||||
|
||||
docstr = ("Gradient of {fun} with respect to positional argument(s) "
|
||||
"{argnums}. Takes the same arguments as {fun} but returns the "
|
||||
"gradient, which has the same shape as the arguments at "
|
||||
"positions {argnums}.")
|
||||
|
||||
@wraps(fun, docstr=docstr, argnums=argnums)
|
||||
def grad_f(*args, **kwargs):
|
||||
ans, g = value_and_grad_f(*args, **kwargs)
|
||||
return g
|
||||
@ -123,6 +133,14 @@ def value_and_grad(fun, argnums=0):
|
||||
integers, the gradient is a tuple of values with the same shapes and types
|
||||
as the corresponding arguments.
|
||||
"""
|
||||
|
||||
docstr = ("Value and gradient of {fun} with respect to positional "
|
||||
"argument(s) {argnums}. Takes the same arguments as {fun} but "
|
||||
"returns a two-element tuple where the first element is the value "
|
||||
"of {fun} and the second element is the gradient, which has the "
|
||||
"same shape as the arguments at positions {argnums}.")
|
||||
|
||||
@wraps(fun, docstr=docstr, argnums=argnums)
|
||||
def value_and_grad_f(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
@ -134,41 +152,61 @@ def value_and_grad(fun, argnums=0):
|
||||
|
||||
return value_and_grad_f
|
||||
|
||||
@curry
|
||||
def jacfwd(fun, x):
|
||||
"""Jacobian of `fun`, evaluated column-by-column using forward-mode AD"""
|
||||
if not isinstance(fun, lu.WrappedFun):
|
||||
fun = lu.wrap_init(fun)
|
||||
pushfwd = partial(jvp, fun, (x,))
|
||||
std_basis = onp.eye(onp.size(x)).reshape((-1,) + onp.shape(x)),
|
||||
y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis)
|
||||
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
|
||||
|
||||
@curry
|
||||
def jacrev(fun, x):
|
||||
"""Jacobian of `fun`, evaluated row-by-row using reverse-mode AD"""
|
||||
if not isinstance(fun, lu.WrappedFun):
|
||||
fun = lu.wrap_init(fun)
|
||||
y, pullback = vjp(fun, x)
|
||||
std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y))
|
||||
jac_flat, = vmap(pullback, out_axes=0)(std_basis)
|
||||
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
|
||||
def jacfwd(fun, argnums=0):
  """Jacobian of `fun` evaluated column-by-column using forward-mode AD.

  Args:
    fun: Function to differentiate.
    argnums: Which positional argument(s) of `fun` to differentiate with
      respect to. When it is an `int` the Jacobian is taken with respect to
      that single argument; otherwise it is taken with respect to the tuple
      of selected arguments (default 0).

  Returns:
    A function that takes the same arguments as `fun` and returns its
    Jacobian. The result has the pytree structure of the JVP tangent output,
    and the last axis of every leaf is unraveled into the pytree structure of
    the differentiated argument(s).
  """
  docstr = ("Jacobian of {fun} with respect to positional argument(s) "
            "{argnums}, evaluated column-by-column using forward-mode AD. "
            "Takes the same arguments as {fun}.")

  # Decorated like grad/value_and_grad/vmap so the returned function carries
  # the name, module and a derived docstring of `fun` instead of "jacfun".
  @wraps(fun, docstr=docstr, argnums=argnums)
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(f, argnums, args)
    pushfwd = partial(jvp, f_partial, dyn_args)
    # One JVP per standard-basis tangent. The primal output is not batched
    # (out_axes None) and is unused here; tangents are stacked on the last axis.
    _, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)

  return jacfun
|
||||
|
||||
def jacrev(fun, argnums=0):
  """Jacobian of `fun` evaluated row-by-row using reverse-mode AD.

  Args:
    fun: Function to differentiate.
    argnums: Which positional argument(s) of `fun` to differentiate with
      respect to. When it is an `int` the Jacobian is taken with respect to
      that single argument; otherwise it is taken with respect to the tuple
      of selected arguments (default 0).

  Returns:
    A function that takes the same arguments as `fun` and returns its
    Jacobian. The result has the pytree structure of the output of `fun` on
    the outside, with the structure of the differentiated argument(s) nested
    at each of its leaves.
  """
  docstr = ("Jacobian of {fun} with respect to positional argument(s) "
            "{argnums}, evaluated row-by-row using reverse-mode AD. "
            "Takes the same arguments as {fun}.")

  # Decorated like grad/value_and_grad/vmap so the returned function carries
  # the name, module and a derived docstring of `fun` instead of "jacfun".
  @wraps(fun, docstr=docstr, argnums=argnums)
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(f, argnums, args)
    y, pullback = vjp(f_partial, *dyn_args)
    # One VJP per standard-basis cotangent of the output.
    jac = vmap(pullback)(_std_basis(y))
    # For a single integer argnum, take the first entry of the pullback result.
    jac = jac[0] if isinstance(argnums, int) else jac
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    # The leading (batch) axis of every leaf enumerates output elements;
    # unravel it into the structure of `y`.
    jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
    # Swap nesting so the output structure is outermost.
    return tree_transpose(tree_structure(example_args), tree_structure(y), jac)

  return jacfun
|
||||
|
||||
def hessian(fun):
  """Hessian of `fun`: the forward-mode Jacobian of its reverse-mode Jacobian."""
  first_derivative = jacrev(fun)
  return jacfwd(first_derivative)
|
||||
|
||||
def _std_basis(pytree):
  """Return the standard basis for `pytree`, arranged in its own structure.

  With N the total number of elements over all leaves, an N x N identity
  matrix is built and its second axis is unraveled into the structure of
  `pytree`, so each leaf of the result keeps a leading axis of length N.
  """
  leaves, _ = tree_flatten(pytree)
  total_size = sum(onp.size(elt) for elt in leaves)
  identity = onp.eye(total_size)
  return _unravel_array_into_pytree(pytree, 1, identity)
|
||||
|
||||
def general_jacobian(jacfun, fun):
|
||||
def jac_f(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
raveled_input, unravel_inputs = ravel_pytree(args)
|
||||
raveled_fun, unravel_outputs = ravel_fun(f, unravel_inputs)
|
||||
jacmat = jacfun(raveled_fun)(raveled_input)
|
||||
return tree_map(unravel_inputs, vmap(unravel_outputs(), in_axes=1)(jacmat))
|
||||
return jac_f
|
||||
jacfwd2 = partial(general_jacobian, jacfwd)
|
||||
jacrev2 = partial(general_jacobian, jacrev)
|
||||
hessian2 = partial(general_jacobian, hessian) # TODO(mattjj): doesn't work yet
|
||||
def _unravel_array_into_pytree(pytree, axis, arr):
  """Unravel one axis of `arr` into the structure of `pytree`.

  `arr` is split along `axis` into one piece per leaf of `pytree`, each piece
  taking as many entries as that leaf has elements. Every piece is cast to
  the leaf's dtype and reshaped so that `axis` is replaced by the leaf's
  shape. The pieces are then reassembled with the tree structure of `pytree`.

  Args:
    pytree: Example pytree supplying the leaf shapes, dtypes and structure.
    axis: Axis of `arr` to unravel; negative values count from the end.
    arr: Array to split.
  """
  leaves, treedef = tree_flatten(pytree)
  axis = axis % arr.ndim  # normalize negative axes
  lead, trail = arr.shape[:axis], arr.shape[axis+1:]

  # Cut points are the running element counts of every leaf but the last.
  cut_points = onp.cumsum([onp.size(l) for l in leaves[:-1]])
  pieces = _split(arr, cut_points, axis)

  leaf_dtypes = [_dtype(l) for l in leaves]
  leaf_shapes = [lead + onp.shape(l) + trail for l in leaves]
  rebuilt = [onp.reshape(piece.astype(dt), shp)
             for piece, dt, shp in zip(pieces, leaf_dtypes, leaf_shapes)]
  return tree_unflatten(treedef, rebuilt)
|
||||
|
||||
def _split(x, indices, axis):
  """Split `x` at `indices` along `axis`.

  NumPy arrays go through `onp.split`; any other value is expected to provide
  its own `split(indices, axis)` method.
  """
  if not isinstance(x, onp.ndarray):
    return x.split(indices, axis)
  return onp.split(x, indices, axis)
|
||||
|
||||
def _dtype(x):
  """Return `canonicalize_dtype` applied to NumPy's result type for `x`."""
  inferred = onp.result_type(x)
  return canonicalize_dtype(inferred)
|
||||
|
||||
|
||||
def vmap(fun, in_axes=0, out_axes=0):
|
||||
@ -194,6 +232,11 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
|
||||
(`[a,b]` indicates an array with shape (a,b))
|
||||
"""
|
||||
|
||||
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
|
||||
"but with additional array axes over which {fun} is mapped.")
|
||||
|
||||
@wraps(fun, docstr=docstr)
|
||||
def batched_fun(*args, **kwargs):
|
||||
if not isinstance(fun, lu.WrappedFun):
|
||||
f = lu.wrap_init(fun)
|
||||
@ -301,7 +344,7 @@ def argnums_partial(f, dyn_argnums, args):
|
||||
dyn_argnums = tuple(dyn_argnums)
|
||||
fixed_args = tuple([None if i in dyn_argnums else WrapHashably(arg)
|
||||
for i, arg in enumerate(args)])
|
||||
dyn_args = [args[i] for i in dyn_argnums]
|
||||
dyn_args = tuple(args[i] for i in dyn_argnums)
|
||||
return argnums_partial_(f, dyn_argnums, fixed_args), dyn_args
|
||||
|
||||
@lu.transformation
|
||||
|
@ -25,12 +25,17 @@ map = safe_map
|
||||
|
||||
|
||||
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
  """Copy name, module and docstring metadata from `wrapped` onto `fun`.

  Args:
    wrapped: The original callable whose metadata is read.
    fun: The wrapper callable that receives the metadata.
    namestr: Format template for `fun.__name__`; `{fun}` expands to the name
      of `wrapped`.
    docstr: Format template for `fun.__doc__`; `{fun}` expands to the name of
      `wrapped` and `{doc}` to its docstring.
    **kwargs: Extra fields made available to the `docstr` template.

  Returns:
    `fun`, always. Updating the metadata is best-effort: if an attribute
    cannot be set or a template cannot be formatted, the remaining updates
    are skipped and `fun` is returned as it stands.
  """
  wrapped_name = get_name(wrapped)
  try:
    fun.__name__ = namestr.format(fun=wrapped_name)
    fun.__module__ = get_module(wrapped)
    fun.__doc__ = docstr.format(fun=wrapped_name, doc=get_doc(wrapped),
                                **kwargs)
  except Exception:
    # Deliberately best-effort. Unlike `finally: return fun`, this does not
    # swallow KeyboardInterrupt or SystemExit.
    pass
  return fun
|
||||
|
||||
def get_name(fun):
  """Return `fun.__name__`, or "<unnamed function>" if it has no such attribute."""
  try:
    return fun.__name__
  except AttributeError:
    return "<unnamed function>"
|
||||
def get_module(fun):
  """Return `fun.__module__`, or "<unknown module>" if it has no such attribute."""
  try:
    return fun.__module__
  except AttributeError:
    return "<unknown module>"
|
||||
def get_doc(fun):
  """Return the docstring of `fun`, or "" when it has none.

  A missing `__doc__` attribute and a `__doc__` of None (the value on a
  function defined without a docstring) both give "". This keeps the literal
  text "None" out of docstrings built with a "{doc}" format template.
  """
  return getattr(fun, "__doc__", None) or ""
|
||||
|
||||
|
||||
@transformation_with_aux
|
||||
|
@ -18,17 +18,21 @@ from __future__ import print_function
|
||||
|
||||
from .tree_util import tree_flatten, tree_unflatten
|
||||
from .linear_util import transformation_with_aux
|
||||
from .util import safe_zip
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.api import vjp
|
||||
|
||||
zip = safe_zip
|
||||
|
||||
|
||||
def ravel_pytree(pytree):
|
||||
from jax.api import vjp # TODO(mattjj): fix circular imports
|
||||
leaves, treedef = tree_flatten(pytree)
|
||||
flat, unravel_list = vjp(_ravel_list, *leaves)
|
||||
flat, unravel_list = vjp(ravel_list, *leaves)
|
||||
unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
|
||||
return flat, unravel_pytree
|
||||
|
||||
def _ravel_list(*lst):
|
||||
import jax.numpy as np # TODO(mattjj): fix circular imports
|
||||
def ravel_list(*lst):
|
||||
return np.concatenate([np.ravel(elt) for elt in lst]) if lst else np.array([])
|
||||
|
||||
|
||||
|
@ -295,7 +295,7 @@ def moveaxis(sz, dst, src, x):
|
||||
return x
|
||||
else:
|
||||
if src is None:
|
||||
x = broadcast(x, sz)
|
||||
x = broadcast(x, sz, force_broadcast=True)
|
||||
src = 0
|
||||
if src == dst:
|
||||
return x
|
||||
@ -306,21 +306,20 @@ def moveaxis(sz, dst, src, x):
|
||||
else:
|
||||
raise TypeError(type(aval))
|
||||
|
||||
def broadcast(x, sz):
|
||||
def broadcast(x, sz, force_broadcast=False):
|
||||
aval = get_aval(x)
|
||||
if type(aval) is AbstractTuple:
|
||||
return pack(map(partial(broadcast, sz=sz), x))
|
||||
elif isinstance(aval, ShapedArray):
|
||||
# for scalars, don't actually broadcast
|
||||
if not onp.ndim(x):
|
||||
# for scalars, maybe don't actually broadcast
|
||||
if not onp.ndim(x) and not force_broadcast:
|
||||
return x
|
||||
|
||||
# See comment at the top of this section about this try/except.
|
||||
try:
|
||||
return x.broadcast((sz,))
|
||||
except AttributeError:
|
||||
assert not isinstance(x, Tracer)
|
||||
# see comment at the top of this section
|
||||
if isinstance(x, onp.ndarray) or onp.isscalar(x):
|
||||
return onp.broadcast_to(x, (sz,) + onp.shape(x))
|
||||
else:
|
||||
return x.broadcast((sz,)) # should be a JAX arraylike
|
||||
else:
|
||||
raise TypeError(type(x))
|
||||
|
||||
|
@ -17,9 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple, defaultdict
|
||||
from distutils.util import strtobool
|
||||
import itertools as it
|
||||
import numpy as onp
|
||||
import operator as op
|
||||
import os
|
||||
|
||||
import numpy as onp
|
||||
import six
|
||||
from six.moves import xrange
|
||||
|
||||
@ -34,7 +37,9 @@ from ..lib import xla_bridge as xb
|
||||
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_device_values', True, 'Enable device-persistent values.')
|
||||
flags.DEFINE_bool('jax_device_values',
|
||||
strtobool(os.getenv('JAX_DEVICE_VALUES', "True")),
|
||||
'Enable device-persistent values.')
|
||||
|
||||
map = safe_map
|
||||
|
||||
|
@ -679,7 +679,7 @@ def zeros_like_array(x):
|
||||
|
||||
for t in itertools.chain(array_types, [xla.DeviceArray]):
|
||||
ad_util.jaxval_adders[t] = add
|
||||
ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
||||
ad_util.jaxval_zeros_likers[xla.DeviceArray] = zeros_like_array
|
||||
|
||||
batching.pytype_aval_mappings[xla.DeviceArray] = make_shaped_array
|
||||
|
||||
|
@ -1724,3 +1724,4 @@ setattr(DeviceArray, "astype", lax.convert_element_type)
|
||||
|
||||
# Extra methods that are handy
|
||||
setattr(DeviceArray, "broadcast", lax.broadcast)
|
||||
setattr(DeviceArray, "split", split)
|
||||
|
@ -102,8 +102,34 @@ def _tree_unflatten(xs, treedef):
|
||||
return treedef.node_type.from_iterable(treedef.node_data, children)
|
||||
|
||||
|
||||
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
  """Swap the two nesting levels of a pytree.

  `pytree_to_transpose` must have the structure `outer_treedef` with a
  subtree of structure `inner_treedef` at each of its leaves. The result has
  structure `inner_treedef` with a subtree of structure `outer_treedef` at
  each of its leaves.

  Raises:
    TypeError: if the structure of `pytree_to_transpose` is not the expected
      nesting of the two tree definitions.
  """
  flat, treedef = tree_flatten(pytree_to_transpose)
  expected_treedef = _nested_treedef(inner_treedef, outer_treedef)
  if treedef != expected_treedef:
    raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef))

  inner_size = _num_leaves(inner_treedef)
  outer_size = _num_leaves(outer_treedef)

  # Leaves arrive outer-major: consecutive runs of `inner_size` leaves belong
  # to one outer leaf. Group them into rows, one row per outer leaf.
  flat = list(flat)
  rows = [flat[i * inner_size:(i + 1) * inner_size] for i in range(outer_size)]

  # Each column collects, for one inner leaf, its value under every outer leaf.
  columns = zip(*rows)
  subtrees = map(partial(tree_unflatten, outer_treedef), columns)
  return tree_unflatten(inner_treedef, subtrees)
|
||||
|
||||
def _num_leaves(treedef):
  """Count the leaves under `treedef`; the `leaf` sentinel itself counts as one."""
  if treedef is leaf:
    return 1
  return sum(_num_leaves(child) for child in treedef.children)
|
||||
|
||||
def _nested_treedef(inner, outer):
  """Return `outer` with every leaf position replaced by `inner`."""
  # just used in tree_transpose error checking
  if outer is leaf:
    return inner
  nested_children = tuple(_nested_treedef(inner, child)
                          for child in outer.children)
  return PyTreeDef(outer.node_type, outer.node_data, nested_children)
|
||||
|
||||
|
||||
def tree_structure(tree):
  """Return the tree definition describing the structure of `tree`.

  `process_pytree` is called with a callback that maps every leaf to None.
  The first element of its result is discarded and the second element, the
  structure, is returned.
  """
  # Only the corrected call is kept: the callback goes first, the tree second,
  # and the structure is the second element of the result.
  _, spec = process_pytree(lambda _: None, tree)
  return spec
|
||||
|
||||
|
||||
@ -126,9 +152,12 @@ class PyTreeDef(object):
|
||||
return hash((self.node_type, self.node_data, tuple(self.children)))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.node_type == other.node_type and
|
||||
self.node_data == other.node_data and
|
||||
self.children == other.children)
|
||||
if other is leaf:
|
||||
return False
|
||||
else:
|
||||
return (self.node_type == other.node_type and
|
||||
self.node_data == other.node_data and
|
||||
self.children == other.children)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
|
@ -24,6 +24,7 @@ from jax import test_util as jtu
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad, device_get, device_put, jacfwd, jacrev
|
||||
from jax import api
|
||||
from jax.core import Primitive
|
||||
from jax.interpreters.partial_eval import def_abstract_eval
|
||||
from jax.interpreters.ad import defjvp
|
||||
@ -259,6 +260,22 @@ class APITest(jtu.JaxTestCase):
|
||||
f = lambda x: np.tanh(np.dot(A, x))
|
||||
assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))
|
||||
|
||||
def test_std_basis(self):
|
||||
basis = api._std_basis(np.zeros(3))
|
||||
assert getattr(basis, "shape", None) == (3, 3)
|
||||
assert onp.allclose(basis, onp.eye(3))
|
||||
|
||||
basis = api._std_basis(np.zeros((3, 3)))
|
||||
assert getattr(basis, "shape", None) == (9, 3, 3)
|
||||
assert onp.allclose(basis, onp.eye(9).reshape(9, 3, 3))
|
||||
|
||||
basis = api._std_basis([0., (np.zeros(3), np.zeros((3, 4)))])
|
||||
assert isinstance(basis, list) and len(basis) == 2
|
||||
assert getattr(basis[0], "shape", None) == (16,)
|
||||
assert isinstance(basis[1], tuple) and len(basis[1]) == 2
|
||||
assert getattr(basis[1][0], "shape", None) == (16, 3)
|
||||
assert getattr(basis[1][1], "shape", None) == (16, 3, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user