rocm_jax/jax/api.py

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
User-facing transformations.

These mostly wrap internal transformations, providing convenience flags to
control behavior and handling Python containers (tuples/lists/dicts) of
arguments and outputs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import operator as op
import numpy as onp
from . import core
from . import linear_util as lu
from .core import pack, eval_jaxpr
from .api_util import (pytree_fun_to_jaxtupletree_fun, apply_jaxtree_fun,
                       pytree_to_jaxtupletree, wraps)
from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
                        tree_map, tree_flatten, tree_unflatten, tree_structure,
                        tree_transpose)
from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
                   WrapHashably, prod)
from .lib.xla_bridge import canonicalize_dtype
from .abstract_arrays import ShapedArray
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
map = safe_map
zip = safe_zip
def jit(fun, static_argnums=(), **params):
"""Sets up `fun` for just-in-time compilation with XLA.
Args:
fun: Function to be jitted. Should be a pure function, as side-effects may
2018-12-20 10:09:34 -08:00
only be executed once. Its positional arguments and return value should be
arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
Keyword arguments and positional arguments specified by `static_argnums`
can be anything at all. These are treated as static (see below).
static_argnums: A tuple of ints. Specifies which arguments to treat as
2018-12-20 10:09:34 -08:00
static (compile-time constant). Operations that only depend on static
arguments will be constant-folded. Calling the jitted function with
different values for these constants will trigger recompilation.
Returns:
A wrapped version of `fun`, set up for just-in-time compilation.
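
  For example (an illustrative usage sketch; `np` here stands for `jax.numpy`):

  >>> fast_tanh_dot = jit(lambda w, x: np.tanh(np.dot(w, x)))

  The first call traces and compiles the function with XLA; later calls with
  arguments of the same shapes and dtypes reuse the compiled computation.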
"""
  @wraps(fun)
  def f_jitted(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
    f, dyn_args = argnums_partial(f, dyn_argnums, args)
    args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, dyn_args))
    check_args(args_flat)
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
    out_flat = xla.xla_call(jaxtree_fun, *args_flat, **params)
    return build_tree(out_tree(), out_flat)

  f_jitted.__name__ = "jit({})".format(f_jitted.__name__)
  return f_jitted

def grad(fun, argnums=0):
"""Creates a function which evaluates the gradient of `fun`.
Args:
fun: Function to be differentiated. Its arguments at positions specified by
2018-12-20 10:09:34 -08:00
`argnums` should be arrays, scalars, or standard Python containers. It
should return a scalar (which includes arrays with shape `()` but not
arrays with shape `(1,)` etc.)
argnums: Optional, integer or tuple of integers. Specifies which positional
argument(s) to differentiate with respect to (default 0).
2018-12-20 10:09:34 -08:00
Returns:
A function with the same arguments as `fun`, that evaluates the gradient of
`fun`. If `argnums` is an integer then the gradient has the same shape and
type as the positional argument indicated by that integer. If argnums is a
tuple of integers, the gradient is a tuple of values with the same shapes
and types as the corresponding arguments.
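
  For example (an illustrative usage sketch; `np` here stands for `jax.numpy`):

  >>> grad_tanh = grad(np.tanh)
  >>> grad_tanh(0.2)  # evaluates the derivative of tanh at 0.2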
"""
  value_and_grad_f = value_and_grad(fun, argnums)

  docstr = ("Gradient of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "gradient, which has the same shape as the arguments at "
            "positions {argnums}.")

  @wraps(fun, docstr=docstr, argnums=argnums)
  def grad_f(*args, **kwargs):
    ans, g = value_and_grad_f(*args, **kwargs)
    return g

  return grad_f

def value_and_grad(fun, argnums=0):
"""Creates a function which evaluates both `fun` and the gradient of `fun`.
Args:
fun: Function to be differentiated. Its arguments at positions specified by
`argnums` should be arrays, scalars, or standard Python containers. It
should return a scalar (which includes arrays with shape `()` but not
arrays with shape `(1,)` etc.)
argnums: Optional, integer or tuple of integers. Specifies which positional
argument(s) to differentiate with respect to (default 0).
2018-12-20 10:09:34 -08:00
Returns:
A function with the same arguments as `fun` that evaluates both `fun` and
the gradient of `fun` and returns them as a pair (a two-element tuple). If
`argnums` is an integer then the gradient has the same shape and type as the
positional argument indicated by that integer. If argnums is a tuple of
integers, the gradient is a tuple of values with the same shapes and types
as the corresponding arguments.
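
  For example (an illustrative usage sketch; `np` here stands for `jax.numpy`):

  >>> sum_squares = lambda x: np.sum(x ** 2)
  >>> val, g = value_and_grad(sum_squares)(np.ones(3))  # val is 3.0, g is 2 * np.ones(3)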
"""
docstr = ("Value and gradient of {fun} with respect to positional "
"argument(s) {argnums}. Takes the same arguments as {fun} but "
"returns a two-element tuple where the first element is the value "
"of {fun} and the second element is the gradient, which has the "
"same shape as the arguments at positions {argnums}.")
@wraps(fun, docstr=docstr, argnums=argnums)
2018-12-20 10:09:34 -08:00
def value_and_grad_f(*args, **kwargs):
2018-11-17 18:03:33 -08:00
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
ans, vjp_py = vjp(f_partial, *dyn_args)
check_scalar(ans)
g = vjp_py(onp.ones((), onp.result_type(ans)))
2018-12-20 10:09:34 -08:00
g = g[0] if isinstance(argnums, int) else g
return (ans, g)
2018-11-17 18:03:33 -08:00
2018-12-20 10:09:34 -08:00
return value_and_grad_f
2018-11-17 18:03:33 -08:00
def jacfwd(fun, argnums=0):
"""Jacobian of `fun` evaluated column-by-column using forward-mode AD."""
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
pushfwd = partial(jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
return jacfun
def jacrev(fun, argnums=0):
"""Jacobian of `fun` evaluated row-by-row using reverse-mode AD."""
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
y, pullback = vjp(f_partial, *dyn_args)
jac = vmap(pullback)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
return jacfun
2019-01-07 08:56:19 -08:00
jacobian = jacrev
def hessian(fun):
  return jacfwd(jacrev(fun))
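
# For example (an illustrative usage sketch; `np` here stands for `jax.numpy`):
#
#   jacfwd(np.sin)(np.ones(3))   # 3x3 Jacobian assembled column-by-column from JVPs
#   jacrev(np.sin)(np.ones(3))   # the same Jacobian assembled row-by-row from VJPs
#   hessian(lambda x: np.sum(x ** 3))(np.ones(3))   # forward-over-reverse second derivatives
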
def _std_basis(pytree):
  leaves, _ = tree_flatten(pytree)
  ndim = sum(map(onp.size, leaves))
  return _unravel_array_into_pytree(pytree, 1, onp.eye(ndim))

def _unravel_array_into_pytree(pytree, axis, arr):
  leaves, treedef = tree_flatten(pytree)
  axis = axis % arr.ndim
  dtypes = map(_dtype, leaves)
  shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves]
  parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis)
  reshaped_parts = [onp.reshape(part.astype(dtype), shape)
                    for part, dtype, shape in zip(parts, dtypes, shapes)]
  return tree_unflatten(treedef, reshaped_parts)

def _split(x, indices, axis):
  if isinstance(x, onp.ndarray):
    return onp.split(x, indices, axis)
  else:
    return x.split(indices, axis)

def _dtype(x):
  return canonicalize_dtype(onp.result_type(x))

def vmap(fun, in_axes=0, out_axes=0):
"""Vectorizing map. Creates a function which maps `fun` over additional axes.
Args:
fun: Function to be mapped over additional axes.
in_axes, out_axes: Specifies which axes to map over. These may be integers,
2018-12-20 10:09:34 -08:00
None, or (possibly nested) tuples of integers or None.
Returns:
Batched/vectorized version of `fun` with arguments that correspond to those
of `fun`, but with extra array axes at positions indicated by `in_axes`, and
a return value that corresponds to that of `fun`, but with extra array axes
at positions indicated by `out_axes`.
For example, we can implement a matrix-matrix product using a vector dot
product:
>>> vv = lambda x, y: np.vdot(x, y) # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0) # ([a,b], [b]) -> [a]
>>> mm = vmap(mv, (None, 1), 1) # ([a,b], [b,c]) -> [a,c]
(`[a,b]` indicates an array with shape (a,b))
"""
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
"but with additional array axes over which {fun} is mapped.")
2019-01-18 07:03:12 -08:00
if (not isinstance(in_axes, (list, tuple, type(None), int))
or not isinstance(out_axes, (list, tuple, type(None), int))):
msg = ("vmap arguments in_axes and out_axes must each be an integer, None, "
"or a (nested) tuple of those types, got {} and {} respectively.")
raise TypeError(msg.format(type(in_axes), type(out_axes)))
  @wraps(fun, docstr=docstr)
  def batched_fun(*args, **kwargs):
    f = fun if isinstance(fun, lu.WrappedFun) else lu.wrap_init(fun, kwargs)
    in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
    in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
    out_flat = batching.batch(jaxtree_fun, in_flat, in_axes_, out_axes)
    return build_tree(out_tree(), out_flat)

  return batched_fun

def pmap(fun, axis_name, in_axes=0, out_axes=0):
"""Vectorizing pseudo-map for single-program multiple-data (SPMD) functions."""
def pmap_fun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
out_flat = parallel.pmap(jaxtree_fun, axis_name, args, in_axes_, out_axes)
return build_tree(out_tree(), out_flat)
return pmap_fun
2019-01-24 16:27:34 -08:00
def axisvar_split(fun, name, new_names):
  """Split axis variable names into new names in an SPMD function."""
  def split_fun(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
    out_flat = parallel.axisvar_split(jaxtree_fun, name, new_names).call_wrapped(*args)
    return build_tree(out_tree(), out_flat)
  return split_fun

def chunk(fun, name, chunksize, in_axes=0, out_axes=0):
"""Stage SPMD primitives to first operate on chunks, then use collectives."""
temp_name = object()
def chunked_fun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
f = parallel.axisvar_split(f, name, (temp_name, name))
reshape = partial(parallel.reshape_axis, chunksize)
reshaped_args = map(reshape, in_axes_, args)
out_flat = parallel.pmap(f, temp_name, reshaped_args, in_axes_, out_axes)
return build_tree(out_tree(), out_flat)
return chunked_fun
def papply(fun, in_axes=0):
"""Apply a function using parallel computation by sharding inputs."""
axis_name = parallel.newvar()
def papply_fun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
out_flat = parallel.papply(jaxtree_fun, axis_name, args_flat, in_axes_)
return build_tree(out_tree(), out_flat)
return papply_fun, axis_name
2018-11-17 18:03:33 -08:00
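
# For example (an illustrative usage sketch; `np` here stands for `jax.numpy`):
#
#   y, y_dot = jvp(np.sin, (0.1,), (1.0,))
#
# evaluates `np.sin` at 0.1 and pushes the tangent 1.0 forward through it, so
# `y_dot` is the directional derivative cos(0.1) * 1.0.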
def jvp(fun, primals, tangents):
  def trim_arg(primal, tangent):
    primal_jtuple, tree_def = pytree_to_jaxtupletree(primal)
    tangent_jtuple, tree_def_2 = pytree_to_jaxtupletree(tangent)
    assert tree_def == tree_def_2, (tree_def, tree_def_2)
    return primal_jtuple, tangent_jtuple, tree_def

  if not isinstance(fun, lu.WrappedFun):
    fun = lu.wrap_init(fun)
  ps_flat, ts_flat, in_trees = unzip3(map(trim_arg, primals, tangents))
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
  out_primal, out_tangent = ad.jvp(jaxtree_fun).call_wrapped(ps_flat, ts_flat)
  return (build_tree(out_tree(), out_primal), build_tree(out_tree(), out_tangent))

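
# For example (an illustrative usage sketch): `y, f_jvp = linearize(np.sin, 0.1)` returns
# the primal output together with a function evaluating the linearization, so that
# `f_jvp(t)` applies the Jacobian of `np.sin` at 0.1 to a tangent value `t`.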
def linearize(traceable, *primals):
  fun = lu.wrap_init(traceable)
  primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
  out_primal, out_pval, jaxpr, consts = ad.linearize(jaxtree_fun, *primals_flat)
  out_tree = out_tree()
  out_primal_py = build_tree(out_tree, out_primal)
  lifted_jvp = partial(lift_linearized, jaxpr, consts, (in_trees, out_tree), out_pval)
  return out_primal_py, lifted_jvp

def lift_linearized(jaxpr, consts, io_tree, out_pval, *py_args):
  def fun(*args):
    primals = pack(args)  # doesn't matter what these are, they'll be ignored
    tangents = pack(args)
    _, ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
    return pe.merge_pvals(ans, out_pval)

  return apply_jaxtree_fun(fun, io_tree, *py_args)

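
# For example (an illustrative usage sketch; `np` here stands for `jax.numpy`):
#
#   y, sin_vjp = vjp(np.sin, 0.1)
#   (x_bar,) = sin_vjp(1.0)   # pulls the cotangent 1.0 back, giving cos(0.1) * 1.0
#
# The returned pullback maps an output cotangent to a tuple with one cotangent per
# positional argument.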
def vjp(fun, *primals):
  if not isinstance(fun, lu.WrappedFun):
    fun = lu.wrap_init(fun)
  primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))
  check_args(primals_flat)
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
  out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
  out_tree = out_tree()
  out_primal_py = build_tree(out_tree, out_primal)
  ct_in_trees = [out_tree]
  ct_out_tree = PyTreeDef(node_types[tuple], None, in_trees)
  def out_vjp_packed(cotangent_in):
    return out_vjp(cotangent_in)
  vjp_py = partial(apply_jaxtree_fun, out_vjp_packed, (ct_in_trees, ct_out_tree))
  return out_primal_py, vjp_py

def trace_to_jaxpr(traceable, py_pvals, **kwargs):
  fun = lu.wrap_init(traceable)
  pvals, in_trees = unzip2(map(tree_to_pval_tuples, py_pvals))
  jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
  jaxpr, out_pval, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
  return jaxpr, consts, out_pval, (in_trees, out_tree())

def lift_jaxpr(jaxpr, consts, io_tree, pvals, py_args):
  def fun(*args):
    ans = eval_jaxpr(jaxpr, consts, (), *args)
    return pe.merge_pvals(ans, pvals)
  return apply_jaxtree_fun(fun, io_tree, *py_args)

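
# For example (an illustrative usage sketch; `np` here stands for `jax.numpy`):
#
#   print(make_jaxpr(lambda x: np.sin(np.cos(x)))(3.0))
#
# prints the jaxpr traced from the function at an argument of that shape and dtype.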
def make_jaxpr(f):
  def pv_like(x):
    aval = xla.abstractify(x)
    return pe.PartialVal((aval, core.unit))

  fun = lu.wrap_init(f)

  @wraps(f)
  def jaxpr_maker(*args, **kwargs):
    jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
    pvals = map(pv_like, jax_args)
    jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
    return jaxpr

  jaxpr_maker.__name__ = "make_jaxpr({})".format(jaxpr_maker.__name__)
  return jaxpr_maker

tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)
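
# `device_put` moves a value onto the accelerator by passing it through a trivial
# jitted identity function; `device_get` copies any `DeviceArray` leaves of a pytree
# back to host memory, returning other leaves unchanged.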
device_put = jit(lambda x: x)
device_get_array = lambda x: x.copy() if type(x) is xla.DeviceArray else x
device_get = partial(tree_map, device_get_array)
def argnums_partial(f, dyn_argnums, args):
  if isinstance(dyn_argnums, int):
    dyn_argnums = (dyn_argnums,)
  else:
    dyn_argnums = tuple(dyn_argnums)
  fixed_args = tuple([None if i in dyn_argnums else WrapHashably(arg)
                      for i, arg in enumerate(args)])
  dyn_args = tuple(args[i] for i in dyn_argnums)
  return argnums_partial_(f, dyn_argnums, fixed_args), dyn_args

@lu.transformation
def argnums_partial_(dyn_argnums, fixed_args, *dyn_args):
  args = [None if arg is None else arg.val for arg in fixed_args]
  for i, arg in zip(dyn_argnums, dyn_args):
    args[i] = arg
  ans = yield args
  yield ans

def check_args(args):
  for arg in args:
    if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
      raise TypeError("Argument '{}' of type {} is not a valid JAX type"
                      .format(arg, type(arg)))

def check_scalar(x):
msg = "Gradient only defined for scalar-output functions. Output was: {}".format
try:
aval = core.get_aval(x)
if not (isinstance(aval, ShapedArray) and aval.shape == ()):
raise TypeError(msg(x))
except TypeError:
raise TypeError(msg(x))