generalize jacfwd and jacrev to handle pytrees

This commit is contained in:
Matthew Johnson 2019-01-06 11:59:33 -08:00
parent ad4322c5da
commit 0f7c7c4eab
10 changed files with 168 additions and 60 deletions

View File

@ -157,9 +157,14 @@ def make_shaped_array(x):
dtype = xla_bridge.canonicalize_dtype(onp.result_type(x))
return ShapedArray(onp.shape(x), dtype)
def zeros_like_array(x):
dtype = xla_bridge.canonicalize_dtype(onp.result_type(x))
return onp.broadcast_to(onp.array(0, dtype), onp.shape(x))
array_types = [onp.ndarray, onp.float64, onp.float32, onp.complex64,
onp.int64, onp.int32, onp.bool_, onp.uint64, onp.uint32,
complex, float, int, bool]
for t in array_types:
core.pytype_aval_mappings[t] = ConcreteArray
ad_util.jaxval_zeros_likers[t] = zeros_like_array

View File

@ -25,6 +25,7 @@ from __future__ import division
from __future__ import print_function
import itertools
import operator as op
import numpy as onp
@ -33,10 +34,12 @@ from . import linear_util as lu
from .core import pack, eval_jaxpr
from .api_util import (pytree_fun_to_jaxtupletree_fun, apply_jaxtree_fun,
pytree_to_jaxtupletree, wraps)
from .flatten_util import ravel_fun, ravel_pytree
from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef, leaf,
tree_map)
from .util import unzip2, unzip3, curry, partial, safe_map, WrapHashably
from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
tree_map, tree_flatten, tree_unflatten, tree_structure,
tree_transpose)
from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
WrapHashably, prod)
from .lib.xla_bridge import canonicalize_dtype
from .abstract_arrays import ShapedArray
from .interpreters import partial_eval as pe
from .interpreters import xla
@ -44,6 +47,7 @@ from .interpreters import ad
from .interpreters import batching
map = safe_map
zip = safe_zip
def jit(fun, static_argnums=()):
@ -98,6 +102,12 @@ def grad(fun, argnums=0):
"""
value_and_grad_f = value_and_grad(fun, argnums)
docstr = ("Gradient of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"gradient, which has the same shape as the arguments at "
"positions {argnums}.")
@wraps(fun, docstr=docstr, argnums=argnums)
def grad_f(*args, **kwargs):
ans, g = value_and_grad_f(*args, **kwargs)
return g
@ -123,6 +133,14 @@ def value_and_grad(fun, argnums=0):
integers, the gradient is a tuple of values with the same shapes and types
as the corresponding arguments.
"""
docstr = ("Value and gradient of {fun} with respect to positional "
"argument(s) {argnums}. Takes the same arguments as {fun} but "
"returns a two-element tuple where the first element is the value "
"of {fun} and the second element is the gradient, which has the "
"same shape as the arguments at positions {argnums}.")
@wraps(fun, docstr=docstr, argnums=argnums)
def value_and_grad_f(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
@ -134,41 +152,61 @@ def value_and_grad(fun, argnums=0):
return value_and_grad_f
@curry
def jacfwd(fun, x):
"""Jacobian of `fun`, evaluated column-by-column using forward-mode AD"""
if not isinstance(fun, lu.WrappedFun):
fun = lu.wrap_init(fun)
pushfwd = partial(jvp, fun, (x,))
std_basis = onp.eye(onp.size(x)).reshape((-1,) + onp.shape(x)),
y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis)
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
@curry
def jacrev(fun, x):
"""Jacobian of `fun`, evaluated row-by-row using reverse-mode AD"""
if not isinstance(fun, lu.WrappedFun):
fun = lu.wrap_init(fun)
y, pullback = vjp(fun, x)
std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y))
jac_flat, = vmap(pullback, out_axes=0)(std_basis)
return jac_flat.reshape(onp.shape(y) + onp.shape(x))
def jacfwd(fun, argnums=0):
"""Jacobian of `fun` evaluated column-by-column using forward-mode AD."""
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
pushfwd = partial(jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
return jacfun
def jacrev(fun, argnums=0):
"""Jacobian of `fun` evaluated row-by-row using reverse-mode AD."""
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
y, pullback = vjp(f_partial, *dyn_args)
jac = vmap(pullback)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
return jacfun
def hessian(fun):
return jacfwd(jacrev(fun))
def _std_basis(pytree):
leaves, _ = tree_flatten(pytree)
ndim = sum(map(onp.size, leaves))
return _unravel_array_into_pytree(pytree, 1, onp.eye(ndim))
def general_jacobian(jacfun, fun):
def jac_f(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
raveled_input, unravel_inputs = ravel_pytree(args)
raveled_fun, unravel_outputs = ravel_fun(f, unravel_inputs)
jacmat = jacfun(raveled_fun)(raveled_input)
return tree_map(unravel_inputs, vmap(unravel_outputs(), in_axes=1)(jacmat))
return jac_f
jacfwd2 = partial(general_jacobian, jacfwd)
jacrev2 = partial(general_jacobian, jacrev)
hessian2 = partial(general_jacobian, hessian) # TODO(mattjj): doesn't work yet
def _unravel_array_into_pytree(pytree, axis, arr):
leaves, treedef = tree_flatten(pytree)
axis = axis % arr.ndim
dtypes = map(_dtype, leaves)
shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves]
parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis)
reshaped_parts = [onp.reshape(part.astype(dtype), shape)
for part, dtype, shape in zip(parts, dtypes, shapes)]
return tree_unflatten(treedef, reshaped_parts)
def _split(x, indices, axis):
if isinstance(x, onp.ndarray):
return onp.split(x, indices, axis)
else:
return x.split(indices, axis)
def _dtype(x):
return canonicalize_dtype(onp.result_type(x))
def vmap(fun, in_axes=0, out_axes=0):
@ -194,6 +232,11 @@ def vmap(fun, in_axes=0, out_axes=0):
(`[a,b]` indicates an array with shape (a,b))
"""
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
"but with additional array axes over which {fun} is mapped.")
@wraps(fun, docstr=docstr)
def batched_fun(*args, **kwargs):
if not isinstance(fun, lu.WrappedFun):
f = lu.wrap_init(fun)
@ -301,7 +344,7 @@ def argnums_partial(f, dyn_argnums, args):
dyn_argnums = tuple(dyn_argnums)
fixed_args = tuple([None if i in dyn_argnums else WrapHashably(arg)
for i, arg in enumerate(args)])
dyn_args = [args[i] for i in dyn_argnums]
dyn_args = tuple(args[i] for i in dyn_argnums)
return argnums_partial_(f, dyn_argnums, fixed_args), dyn_args
@lu.transformation

View File

@ -25,12 +25,17 @@ map = safe_map
@curry
def wraps(wrapped, wrapper):
wrapper.__name__ = getattr(wrapped, "__name__", "<unnamed function>")
wrapper.__module__ = getattr(wrapped, "__module__", "<unknown module>")
if hasattr(wrapped, "__doc__"):
wrapper.__doc__ = getattr(wrapped, "__doc__")
return wrapper
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
try:
fun.__name__ = namestr.format(fun=get_name(wrapped))
fun.__module__ = get_module(wrapped)
fun.__doc__ = docstr.format(fun=get_name(wrapped), doc=get_doc(wrapped), **kwargs)
finally:
return fun
def get_name(fun): return getattr(fun, "__name__", "<unnamed function>")
def get_module(fun): return getattr(fun, "__module__", "<unknown module>")
def get_doc(fun): return getattr(fun, "__doc__", "")
@transformation_with_aux

View File

@ -18,17 +18,21 @@ from __future__ import print_function
from .tree_util import tree_flatten, tree_unflatten
from .linear_util import transformation_with_aux
from .util import safe_zip
import jax.numpy as np
from jax.api import vjp
zip = safe_zip
def ravel_pytree(pytree):
from jax.api import vjp # TODO(mattjj): fix circular imports
leaves, treedef = tree_flatten(pytree)
flat, unravel_list = vjp(_ravel_list, *leaves)
flat, unravel_list = vjp(ravel_list, *leaves)
unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
return flat, unravel_pytree
def _ravel_list(*lst):
import jax.numpy as np # TODO(mattjj): fix circular imports
def ravel_list(*lst):
return np.concatenate([np.ravel(elt) for elt in lst]) if lst else np.array([])

View File

@ -295,7 +295,7 @@ def moveaxis(sz, dst, src, x):
return x
else:
if src is None:
x = broadcast(x, sz)
x = broadcast(x, sz, force_broadcast=True)
src = 0
if src == dst:
return x
@ -306,21 +306,20 @@ def moveaxis(sz, dst, src, x):
else:
raise TypeError(type(aval))
def broadcast(x, sz):
def broadcast(x, sz, force_broadcast=False):
aval = get_aval(x)
if type(aval) is AbstractTuple:
return pack(map(partial(broadcast, sz=sz), x))
elif isinstance(aval, ShapedArray):
# for scalars, don't actually broadcast
if not onp.ndim(x):
# for scalars, maybe don't actually broadcast
if not onp.ndim(x) and not force_broadcast:
return x
# See comment at the top of this section about this try/except.
try:
return x.broadcast((sz,))
except AttributeError:
assert not isinstance(x, Tracer)
# see comment at the top of this section
if isinstance(x, onp.ndarray) or onp.isscalar(x):
return onp.broadcast_to(x, (sz,) + onp.shape(x))
else:
return x.broadcast((sz,)) # should be a JAX arraylike
else:
raise TypeError(type(x))

View File

@ -17,9 +17,12 @@ from __future__ import division
from __future__ import print_function
from collections import namedtuple, defaultdict
from distutils.util import strtobool
import itertools as it
import numpy as onp
import operator as op
import os
import numpy as onp
import six
from six.moves import xrange
@ -34,7 +37,9 @@ from ..lib import xla_bridge as xb
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_device_values', True, 'Enable device-persistent values.')
flags.DEFINE_bool('jax_device_values',
strtobool(os.getenv('JAX_DEVICE_VALUES', "True")),
'Enable device-persistent values.')
map = safe_map

View File

@ -679,7 +679,7 @@ def zeros_like_array(x):
for t in itertools.chain(array_types, [xla.DeviceArray]):
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[t] = zeros_like_array
ad_util.jaxval_zeros_likers[xla.DeviceArray] = zeros_like_array
batching.pytype_aval_mappings[xla.DeviceArray] = make_shaped_array

View File

@ -1724,3 +1724,4 @@ setattr(DeviceArray, "astype", lax.convert_element_type)
# Extra methods that are handy
setattr(DeviceArray, "broadcast", lax.broadcast)
setattr(DeviceArray, "split", split)

View File

@ -102,8 +102,34 @@ def _tree_unflatten(xs, treedef):
return treedef.node_type.from_iterable(treedef.node_data, children)
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
flat, treedef = tree_flatten(pytree_to_transpose)
expected_treedef = _nested_treedef(inner_treedef, outer_treedef)
if treedef != expected_treedef:
raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef))
inner_size = _num_leaves(inner_treedef)
outer_size = _num_leaves(outer_treedef)
flat = iter(flat)
lol = [[next(flat) for _ in range(inner_size)] for __ in range(outer_size)]
transposed_lol = zip(*lol)
subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
return tree_unflatten(inner_treedef, subtrees)
def _num_leaves(treedef):
return 1 if treedef is leaf else sum(map(_num_leaves, treedef.children))
def _nested_treedef(inner, outer):
# just used in tree_transpose error checking
if outer is leaf:
return inner
else:
children = map(partial(_nested_treedef, inner), outer.children)
return PyTreeDef(outer.node_type, outer.node_data, tuple(children))
def tree_structure(tree):
spec, _ = process_pytree(tree, lambda _: None)
_, spec = process_pytree(lambda _: None, tree)
return spec
@ -126,9 +152,12 @@ class PyTreeDef(object):
return hash((self.node_type, self.node_data, tuple(self.children)))
def __eq__(self, other):
return (self.node_type == other.node_type and
self.node_data == other.node_data and
self.children == other.children)
if other is leaf:
return False
else:
return (self.node_type == other.node_type and
self.node_data == other.node_data and
self.children == other.children)
def __ne__(self, other):
return not self == other

View File

@ -24,6 +24,7 @@ from jax import test_util as jtu
import jax.numpy as np
from jax import jit, grad, device_get, device_put, jacfwd, jacrev
from jax import api
from jax.core import Primitive
from jax.interpreters.partial_eval import def_abstract_eval
from jax.interpreters.ad import defjvp
@ -259,6 +260,22 @@ class APITest(jtu.JaxTestCase):
f = lambda x: np.tanh(np.dot(A, x))
assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))
def test_std_basis(self):
basis = api._std_basis(np.zeros(3))
assert getattr(basis, "shape", None) == (3, 3)
assert onp.allclose(basis, onp.eye(3))
basis = api._std_basis(np.zeros((3, 3)))
assert getattr(basis, "shape", None) == (9, 3, 3)
assert onp.allclose(basis, onp.eye(9).reshape(9, 3, 3))
basis = api._std_basis([0., (np.zeros(3), np.zeros((3, 4)))])
assert isinstance(basis, list) and len(basis) == 2
assert getattr(basis[0], "shape", None) == (16,)
assert isinstance(basis[1], tuple) and len(basis[1]) == 2
assert getattr(basis[1][0], "shape", None) == (16, 3)
assert getattr(basis[1][1], "shape", None) == (16, 3, 4)
if __name__ == '__main__':
absltest.main()