2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
2018-11-21 13:27:26 -08:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-08 13:17:55 -05:00
|
|
|
import functools
|
2019-02-23 20:34:14 -08:00
|
|
|
import itertools as it
|
2020-01-18 08:26:23 -05:00
|
|
|
from typing import Any
|
2019-02-23 20:34:14 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from . import partial_eval as pe
|
|
|
|
from .. import core as core
|
2019-07-26 16:48:17 -04:00
|
|
|
from ..core import Trace, Tracer, new_master, get_aval, call_p, Primitive, Literal
|
2019-05-07 08:52:08 -07:00
|
|
|
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
|
2020-01-05 04:35:34 +01:00
|
|
|
zeros_like_p, zero)
|
2019-04-23 17:47:28 -07:00
|
|
|
from ..abstract_arrays import raise_to_shaped
|
2020-01-05 04:35:34 +01:00
|
|
|
from ..util import unzip2, safe_map, safe_zip, partial, split_list
|
|
|
|
from ..tree_util import register_pytree_node
|
|
|
|
from .. import linear_util as lu
|
2019-07-27 15:46:14 -07:00
|
|
|
from ..api_util import flatten_fun, flatten_fun_nokwargs
|
2019-07-26 23:17:21 -04:00
|
|
|
from ..tree_util import tree_flatten, tree_unflatten
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Shadow the builtins `zip` and `map` for the rest of this module with the
# variants imported from ..util.
# NOTE(review): presumably these check that all arguments have equal length
# (and `map` returns a list) — confirm in util.safe_zip / util.safe_map.
zip, map = safe_zip, safe_map
|
2019-02-15 06:35:54 -08:00
|
|
|
def identity(x):
  """Return `x` unchanged (the same object)."""
  return x
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-18 08:26:23 -05:00
|
|
|
def jvp(fun, has_aux=False, instantiate=True) -> Any:
  """Transform `fun` for forward-mode differentiation.

  The result is a transformed function called with `(primals, tangents)` that
  produces `(out_primals, out_tangents)`.

  Args:
    fun: the function to transform (a linear_util wrapped function, as built
      with `lu.wrap_init` in `jvp_jaxpr`).
    has_aux: if True, `fun` returns `(outputs, aux)`. The return value is then
      the pair `(transformed_fun, aux_thunk)`, where calling `aux_thunk()`
      after tracing gives the primal values of `aux`.
    instantiate: forwarded to `jvpfun`. Either a bool or one bool per output;
      where True, symbolic-zero output tangents become concrete zeros.
  """
  if has_aux:
    subtraced, aux_thunk = jvp_subtrace_aux(fun)
    return jvpfun(subtraced, instantiate), aux_thunk
  return jvpfun(jvp_subtrace(fun), instantiate)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-18 08:26:23 -05:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation
def jvpfun(instantiate, primals, tangents):
  """Outer JVP transformation: run the wrapped function under a new JVPTrace master.

  Args:
    instantiate: a bool, or a sequence of bools with one entry per output.
      Where True, an output tangent that is the symbolic `zero` is replaced
      by concrete zeros via `instantiate_zeros`.
    primals: input primal values.
    tangents: input tangent values.

  The first yield hands `(master, primals, tangents)` to the next
  transformation in the stack (e.g. `jvp_subtrace`). The final yield produces
  `(out_primals, out_tangents)`.
  """
  with new_master(JVPTrace) as master:
    out_primals, out_tangents = yield (master, primals, tangents), {}
    # Release this frame's reference to `master` before the context exits.
    # NOTE(review): presumably new_master checks on exit that the master has
    # not leaked — confirm in core.new_master.
    del master
  if type(instantiate) is bool:
    # A single flag applies to every output.
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(x, t) if inst else t for x, t, inst
                  in zip(out_primals, out_tangents, instantiate)]
  yield out_primals, out_tangents
|
2019-04-01 16:03:56 -04:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation
def jvp_subtrace(master, primals, tangents):
  """Run the wrapped function on JVPTracers at a new sublevel of `master`.

  Inputs whose tangent is the symbolic `zero` are passed through unboxed.
  The others are boxed as JVPTracers. The outputs are raised into this trace
  and split into `(out_primals, out_tangents)`.
  """
  trace = JVPTrace(master, core.cur_sublevel())
  # Tracers already present in the inputs must sit at a lower trace level.
  for val in it.chain(primals, tangents):
    if isinstance(val, Tracer):
      assert val.trace.level < trace.level
  in_tracers = [x if t is zero else JVPTracer(trace, x, t)
                for x, t in zip(primals, tangents)]
  outs = yield in_tracers, {}
  raised = map(trace.full_raise, outs)
  yield unzip2([(tracer.primal, tracer.tangent) for tracer in raised])
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def jvp_subtrace_aux(master, primals, tangents):
  """Variant of `jvp_subtrace` for a wrapped function returning `(outputs, aux)`.

  Every input is boxed as a JVPTracer, including inputs whose tangent is the
  symbolic zero. The main outputs are split into `(out_primals, out_tangents)`.
  The auxiliary outputs go out through the aux channel as primal values only,
  lowered with `core.full_lower`; their tangents are discarded.
  """
  trace = JVPTrace(master, core.cur_sublevel())
  # Tracers already present in the inputs must sit at a lower trace level.
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x.trace.level < trace.level
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  ans_tracers = map(trace.full_raise, ans)
  aux_tracers = map(trace.full_raise, aux)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  # Only the primal half of aux is used; the tangent half is dropped.
  aux_primals, _ = unzip2((t.primal, t.tangent) for t in aux_tracers)
  aux_primals = map(core.full_lower, aux_primals)
  yield (out_primals, out_tangents), aux_primals
|
2019-03-07 14:08:02 -08:00
|
|
|
|
|
|
|
def linearize(traceable, *primals, **kwargs):
  """Linearize `traceable` at `primals` by partially evaluating its JVP.

  The JVP is traced with the primal inputs known and the tangent inputs
  unknown (described only by `get_aval(p).at_least_vspace()`). The primal
  outputs therefore come back as known constants, while the tangent part of
  the computation is staged out into a jaxpr.

  Args:
    traceable: the function to linearize; it is wrapped with `jvp`.
    *primals: the point at which to linearize.
    has_aux: keyword only. If True, `traceable` returns `(outputs, aux)` and
      the aux primal values are appended to the return tuple.

  Returns:
    `(out_primals, out_tangent_pvals, jaxpr, consts)`, plus `aux` as a fifth
    element when `has_aux` is True. `jaxpr` and `consts` are the staged-out
    computation produced by `pe.trace_to_jaxpr`.

  Raises:
    TypeError: if a keyword argument other than `has_aux` is passed.
  """
  has_aux = kwargs.pop('has_aux', False)
  if kwargs:
    # Reject unknown keywords instead of silently ignoring them.
    raise TypeError("linearize() got unexpected keyword argument(s): {}"
                    .format(", ".join(sorted(kwargs))))
  if has_aux:
    jvp_fun, aux = jvp(traceable, has_aux=True)
  else:
    jvp_fun = jvp(traceable)

  # Primals are known values; tangents are unknown and described by aval only.
  in_pvals = (tuple(pe.PartialVal((None, p)) for p in primals)
              + tuple(pe.PartialVal((get_aval(p).at_least_vspace(), core.unit))
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvp_fun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
  aval_primals, const_primals = unzip2(pval_primals)
  # Primal outputs depend only on known inputs, so every one must be known.
  assert all(aval_primal is None for aval_primal in aval_primals)
  if has_aux:
    return const_primals, pval_tangents, jaxpr, consts, aux()
  return const_primals, pval_tangents, jaxpr, consts
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-03-07 14:08:02 -08:00
|
|
|
def vjp(traceable, primals, has_aux=False):
  """Linearize `traceable` at `primals` and return a vector-Jacobian product.

  Returns `(out_primals, vjp_)`, or `(out_primals, vjp_, aux)` when `has_aux`
  is True. `vjp_(*cts)` takes one cotangent per linearized output and returns
  one cotangent per primal input, with symbolic zeros instantiated.
  """
  if has_aux:
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  else:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)

  def vjp_(*cts):
    # Outputs whose partial value is known (aval None) get `core.unit`.
    tangent_cts = tuple(map(ignore_consts, cts, pvals))
    # The jaxpr's outputs are the primal outputs (cotangent: unit) followed
    # by the tangent outputs.
    cts_in = (core.unit,) * len(tangent_cts) + tangent_cts
    # Every jaxpr input is treated as linear (unknown) during transposition.
    placeholders = (undefined_primal,) * len(jaxpr.invars)
    _, all_arg_cts = backward_pass(jaxpr, consts, (), placeholders, cts_in)
    # Keep only the cotangents for the tangent inputs, which follow the primals.
    return map(instantiate_zeros, primals, all_arg_cts[len(primals):])

  return (out_primals, vjp_, aux) if has_aux else (out_primals, vjp_)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-03 22:24:46 -05:00
|
|
|
def ignore_consts(ct, pval):
  """Select the cotangent to feed back for one output, given its partial value.

  `pval` is `(aval, const)`. If `aval` is an abstract value, the output was
  staged out and `ct` is returned. If `aval` is None, the output is a known
  constant and `core.unit` is returned instead.

  Raises:
    TypeError: if `aval` is neither None nor a `core.AbstractValue`.
  """
  aval, _ = pval
  if aval is None:
    return core.unit
  if isinstance(aval, core.AbstractValue):
    return ct
  raise TypeError(aval)
|
|
|
|
|
|
|
|
def unpair_pval(pval):
  """Split a partial value holding pairs into two partial values.

  `pval` is `(aval, (const_1, const_2))`, where `aval` is either None or a
  pair `(aval_1, aval_2)`. Returns `((aval_1, const_1), (aval_2, const_2))`,
  or `((None, const_1), (None, const_2))` when `aval` is None.
  """
  aval, (const_1, const_2) = pval
  if aval is None:
    return (None, const_1), (None, const_2)
  aval_1, aval_2 = aval
  return (aval_1, const_1), (aval_2, const_2)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
  """Transpose `jaxpr`, propagating output cotangents back to its inputs.

  Args:
    jaxpr: the jaxpr to transpose.
    consts: values for `jaxpr.constvars`.
    freevar_vals: values for `jaxpr.freevars`.
    args: values for `jaxpr.invars`. An `undefined_primal` entry marks a
      linear (unknown) input whose cotangent is wanted.
    cotangents_in: one cotangent per entry of `jaxpr.outvars`; `zero` means
      no cotangent.

  Returns:
    `(freevar_cts, cotangents_out)`: cotangents for `jaxpr.freevars` and for
    `jaxpr.invars`, with `zero` wherever nothing was accumulated.

  The pass first sweeps forward, evaluating every equation whose inputs are
  all known and collecting the rest. It then walks the collected linear
  equations in reverse and applies each primitive's transpose rule.
  """
  # Fast path: no cotangent to propagate.
  if all(ct is zero for ct in cotangents_in):
    return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    # A `None` from a transpose rule means "no cotangent for this input".
    # Multiple contributions to the same variable are summed.
    if ct is not None:
      ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, zero)

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, undefined_primal)

  def write_primal(v, val):
    # Unknown values are never stored, so absence from primal_env means linear.
    if val is not undefined_primal:
      primal_env[v] = val

  primal_env = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  map(write_primal, jaxpr.freevars, freevar_vals)
  map(write_primal, jaxpr.invars, args)

  def is_linear(var):
    # A variable is linear iff its primal value is unknown; literals are known.
    if type(var) is Literal:
      return False
    else:
      return primal_env.get(var, undefined_primal) is undefined_primal

  # Forward sweep: evaluate non-linear equations and collect the linear ones.
  linear_eqns = []
  for eqn in jaxpr.eqns:
    if not eqn.bound_subjaxprs:
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      else:
        in_vals = map(read_primal, eqn.invars)
        ans = eqn.primitive.bind(*in_vals, **eqn.params)
        if eqn.primitive.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      (subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
      assert not any(is_linear(v) for v in const_vars)
      if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
        linear_eqns.append(eqn)
      elif eqn.primitive is not pe.remat_call_p:
        ans = _eval_subjaxpr_primals(
            eqn.primitive, subjaxpr, map(read_primal, const_vars),
            map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
        map(write_primal, eqn.outvars, ans)

      # we special-case remat_call here because it can be mixed linear /
      # nonlinear, so we always evaluate it even if it has a linear part
      if eqn.primitive is pe.remat_call_p:
        ans = _eval_subjaxpr_primals(
            eqn.primitive, subjaxpr, map(read_primal, const_vars),
            map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
        map(write_primal, eqn.outvars, ans)

  # Backward sweep: seed output cotangents, then transpose linear eqns in reverse.
  ct_env = {}
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in linear_eqns[::-1]:
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    if eqn.bound_subjaxprs:
      # Call-like primitives use a transpose rule taking the sub-jaxpr and
      # also return cotangents for the sub-jaxpr's bound (free) variables.
      (subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
      sub_consts = map(read_primal, const_vars)
      sub_freevar_vals = map(read_primal, bound_vars)
      ct_free_vars_out, cts_out = get_primitive_transpose(eqn.primitive)(
          eqn.params, subjaxpr, sub_consts, sub_freevar_vals, invals, cts_in)
      map(write_cotangent, bound_vars, ct_free_vars_out)
    else:
      cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
    # A transpose rule may return a single `zero` to mean "all inputs zero".
    cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
    map(write_cotangent, eqn.invars, cts_out)

  freevar_cts = map(read_cotangent, jaxpr.freevars)
  cotangents_out = map(read_cotangent, jaxpr.invars)
  return freevar_cts, cotangents_out
|
|
|
|
|
2019-11-27 15:25:49 -08:00
|
|
|
def _eval_subjaxpr_primals(prim, jaxpr, consts, freevar_vals, in_vals, params):
  """Bind call primitive `prim` on `_eval_primals` over the sub-jaxpr `jaxpr`.

  The arguments are flattened as a pytree before binding. `undefined_primal`
  flattens to no leaves, so such placeholders are not passed to `prim.bind`;
  they are restored inside the call when the arguments are unflattened.
  Returns the unflattened outputs of `_eval_primals`.
  """
  flat_args, args_tree = tree_flatten((consts, freevar_vals, in_vals))
  evaluator = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
  evaluator, out_tree = flatten_fun_nokwargs(evaluator, args_tree)
  out_flat = prim.bind(evaluator, *flat_args, **params)
  return tree_unflatten(out_tree(), out_flat)
|
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
def _eval_primals(jaxpr, consts, freevar_vals, args):
  """Evaluate only the non-linear (known) part of `jaxpr`.

  Equations whose inputs are all known are evaluated. Equations touching a
  linear (unknown) input are skipped, except `remat_call`, which is always
  evaluated. Returns the values of `jaxpr.outvars`, with `undefined_primal`
  for outputs that were not computed.
  """
  primal_env = {}

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, undefined_primal)

  def write_primal(v, val):
    # Unknown values are never stored, so absence from primal_env means linear.
    if val is not undefined_primal:
      primal_env[v] = val

  def is_linear(var):
    # A variable is linear iff its primal value is unknown; literals are known.
    if type(var) is Literal:
      return False
    else:
      return primal_env.get(var, undefined_primal) is undefined_primal

  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  map(write_primal, jaxpr.freevars, freevar_vals)
  map(write_primal, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    if not eqn.bound_subjaxprs:
      if not any(is_linear(v) for v in eqn.invars):
        in_vals = map(read_primal, eqn.invars)
        ans = eqn.primitive.bind(*in_vals, **eqn.params)
        if eqn.primitive.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      (subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
      assert not any(is_linear(v) for v in const_vars)
      # remat_call may be mixed linear/non-linear, so it is always evaluated
      # (recursively, so only its known part is computed).
      if (eqn.primitive is pe.remat_call_p or
          not any(is_linear(v) for v in it.chain(eqn.invars, bound_vars))):
        ans = _eval_subjaxpr_primals(
            eqn.primitive, subjaxpr, map(read_primal, const_vars),
            map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
        map(write_primal, eqn.outvars, ans)
  return map(read_primal, jaxpr.outvars)
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
class UndefinedPrimal(object):
  """Type of the `undefined_primal` placeholder for an unknown primal value."""

  def __repr__(self):
    return '_'


# The single placeholder instance; compared by identity throughout this module.
undefined_primal = UndefinedPrimal()

# Register the placeholder as a pytree with no leaves that always unflattens
# back to the same `undefined_primal` instance.
register_pytree_node(UndefinedPrimal,
                     lambda _: ((), None),
                     lambda *_: undefined_primal)
|
2019-05-07 08:52:08 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def get_primitive_transpose(p):
  """Return the transpose rule registered for primitive `p`.

  Raises:
    NotImplementedError: if `p` has no entry in `primitive_transposes`.
  """
  try:
    return primitive_transposes[p]
  except KeyError:
    # The KeyError is an internal lookup detail. Suppress it so the traceback
    # shows only the user-facing error instead of a chained KeyError.
    raise NotImplementedError(
        "Reverse-mode differentiation rule for '{}' not implemented".format(p)) from None
|
|
|
|
|
|
|
|
class JVPTrace(Trace):
  """Forward-mode trace whose tracers (JVPTracer) carry a primal and a tangent.

  Tangents may be the symbolic `zero`. Per-primitive rules are looked up in
  `primitive_jvps`.
  """

  def pure(self, val):
    # Constants enter the trace with a symbolic-zero tangent.
    return JVPTracer(self, val, zero)

  def lift(self, val):
    # Lifted values also get a symbolic-zero tangent.
    return JVPTracer(self, val, zero)

  def sublift(self, val):
    # Re-box a JVPTracer into this trace, keeping its primal and tangent.
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    """Apply the JVP rule registered for `primitive` to the tracers' values.

    Raises:
      NotImplementedError: if `primitive` has no entry in `primitive_jvps`.
    """
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    try:
      jvp = primitive_jvps[primitive]
    except KeyError:
      raise NotImplementedError(
          "Forward-mode differentiation rule for '{}' not implemented"
          .format(primitive))
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return JVPTracer(self, primal_out, tangent_out)

  def process_call(self, call_primitive, f, tracers, params):
    """Move the JVP inside a call: bind `call_primitive` on the transformed `f`.

    The primals and flattened tangents are passed as one flat argument list.
    The outputs are split back into primals and tangents and re-boxed.
    """
    assert call_primitive.multiple_results
    primals = [t.primal for t in tracers]
    tangents = [t.tangent for t in tracers]
    # NOTE(review): presumably symbolic `zero` tangents flatten to no leaves,
    # so only non-zero tangents are passed; `in_tree_def` restores the rest.
    nonzero_tangents, in_tree_def = tree_flatten(tangents)
    f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), len(primals), in_tree_def)
    result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Unbox tracers leaving a call; the returned `todo` re-boxes them later.

    `out` is all primals followed by all tangents. `todo(x)` splits such a
    flat sequence in half and rebuilds JVPTracers at a fresh sublevel of the
    same master.
    """
    primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
    out = primals + tangents
    # NOTE(review): presumably deleted so this frame holds no extra references
    # once `out` is built; `todo` binds its own `primals`/`tangents`.
    del primals, tangents
    master = self.master
    def todo(x):
      n = len(x) // 2
      primals, tangents = x[:n], x[n:]
      trace = JVPTrace(master, core.cur_sublevel())
      return map(partial(JVPTracer, trace), primals, tangents)
    return out, todo

  def join(self, xt, yt):
    """Make two tangents agree on whether they are the symbolic `zero`.

    If exactly one of them is `zero`, it is replaced by concrete zeros shaped
    like the other.
    """
    xz, yz = xt is zero, yt is zero
    if xz == yz:
      return xt, yt
    elif yz and not xz:
      return xt, zeros_like_jaxval(xt)
    elif xz and not yz:
      return zeros_like_jaxval(yt), yt
    else:
      # Unreachable: the branches above cover every combination.
      raise TypeError((xt, yt))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
class JVPTracer(Tracer):
  """Tracer holding a primal value and its tangent (possibly the symbolic `zero`)."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    # Shape/dtype consistency check, skipped when core.skip_checks is set.
    if not core.skip_checks:
      _primal_tangent_shapes_match(primal, tangent)
    self.trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def full_lower(self):
    # With a symbolic-zero tangent there is no derivative information left,
    # so the tracer can be replaced by its (fully lowered) primal.
    return self if self.tangent is not zero else core.full_lower(self.primal)
|
|
|
|
|
2019-05-10 15:52:12 -07:00
|
|
|
def _primal_tangent_shapes_match(primal, tangent):
  """Assert that a non-zero `tangent` has the same shaped aval as `primal`.

  Both values are raised with `raise_to_shaped(get_aval(...))` before the
  comparison. A symbolic `zero` tangent always passes.
  """
  if tangent is zero:
    return
  primal_aval = raise_to_shaped(get_aval(primal))
  tangent_aval = raise_to_shaped(get_aval(tangent))
  # Report both avals so a mismatch can be diagnosed from the failure alone.
  assert primal_aval == tangent_aval, (primal_aval, tangent_aval)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# -------------------- Primitives --------------------
|
|
|
|
|
|
|
|
|
|
|
|
# Per-primitive rule tables, keyed by Primitive:
#   primitive_jvps[p](primals, tangents, **params) -> (primal_out, tangent_out),
#     read by JVPTrace.process_primitive.
#   primitive_transposes[p]: transpose rule, fetched via get_primitive_transpose.
# composite_jvps is not read or written by the other code visible here.
# NOTE(review): confirm whether other modules use composite_jvps.
primitive_jvps, composite_jvps, primitive_transposes = {}, {}, {}
|
|
|
|
|
|
|
|
|
|
|
|
def deflinear(primitive, transpose_rule):
  """Register `primitive` as linear in all of its inputs.

  Its JVP rule binds the primitive on the tangents (`linear_jvp`). Its
  transpose rule calls `transpose_rule(cotangent, **params)` through
  `linear_transpose`.
  """
  jvp_rule = partial(linear_jvp, primitive)
  transpose = partial(linear_transpose, transpose_rule)
  primitive_jvps[primitive] = jvp_rule
  primitive_transposes[primitive] = transpose
|
|
|
|
|
|
|
|
def linear_jvp(primitive, primals, tangents, **params):
  """JVP rule for a primitive that is linear in all inputs.

  Binds the primitive on the primals and then on the tangents. If every
  input tangent is the symbolic `zero`, the output tangent is `zero` as well.
  Otherwise symbolic zeros are instantiated first.
  """
  val_out = primitive.bind(*primals, **params)
  if any(t is not zero for t in tangents):
    dense_tangents = map(instantiate_zeros, primals, tangents)
    return val_out, primitive.bind(*dense_tangents, **params)
  return val_out, zero
|
|
|
|
|
|
|
|
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Transpose rule for primitives registered with `deflinear`.

  The primal `args` are ignored. A symbolic-zero cotangent transposes to `zero`.
  """
  if cotangent is zero:
    return zero
  return transpose_rule(cotangent, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def defjvp(primitive, *jvprules):
  """Register one JVP rule per input of `primitive`.

  `jvprules[i](g, *primals, **params)` gives the output-tangent contribution
  of input `i` with tangent `g`. A `None` rule means input `i` contributes
  nothing. `standard_jvp` sums the contributions.
  """
  assert isinstance(primitive, Primitive)
  rule = partial(standard_jvp, jvprules, primitive)
  primitive_jvps[primitive] = rule
|
|
|
|
|
|
|
|
|
|
|
|
def standard_jvp(jvprules, primitive, primals, tangents, **params):
  """Bind `primitive` and sum the per-input tangent contributions.

  Inputs whose rule is None or whose tangent is the symbolic `zero` are
  skipped. If nothing contributes, the output tangent is `zero`. All
  contributions are computed before any of them are added.
  """
  val_out = primitive.bind(*primals, **params)
  contributions = [rule(t, *primals, **params)
                   for rule, t in zip(jvprules, tangents)
                   if rule is not None and t is not zero]
  tangent_out = zero
  for contribution in contributions:
    tangent_out = add_tangents(tangent_out, contribution)
  return val_out, tangent_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defjvp2(primitive, *jvprules):
  """Like `defjvp`, but each rule also receives the primal output.

  `jvprules[i](g, ans, *primals, **params)` gives input `i`'s contribution;
  `standard_jvp2` sums the contributions.
  """
  assert isinstance(primitive, Primitive)
  rule = partial(standard_jvp2, jvprules, primitive)
  primitive_jvps[primitive] = rule
|
|
|
|
|
|
|
|
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """Bind `primitive` and accumulate tangent contributions that may use its output.

  Inputs whose rule is None or whose tangent is the symbolic `zero` are
  skipped. Each contribution is added to the running total as soon as it is
  computed.
  """
  val_out = primitive.bind(*primals, **params)
  tangent_out = zero
  for rule, t in zip(jvprules, tangents):
    if rule is not None and t is not zero:
      tangent_out = add_tangents(tangent_out, rule(t, val_out, *primals, **params))
  return val_out, tangent_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def add_tangents(x, y):
  """Add two tangents, treating the symbolic `zero` as the additive identity."""
  if y is zero:
    return x
  if x is zero:
    return y
  return add_jaxvals(x, y)
|
|
|
|
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def defvjp_all(prim, custom_vjp):
  """Define differentiation rules for `prim` from a custom VJP function.

  `custom_vjp(*primals, **params)` must return `(primals_out, vjp_py)`, where
  `vjp_py` maps output cotangents to input cotangents. This registers three
  rules:
    * a JVP rule for `prim` that binds a new primitive `<name>_jvp` on the
      primals and instantiated tangents;
    * a partial-eval rule for `<name>_jvp` that runs `custom_vjp`, traces
      `vjp_py` into a jaxpr, and emits a `<name>_lin` primitive for the
      tangents;
    * a transpose rule for `<name>_lin` that evaluates that jaxpr on the
      cotangents.
  NOTE(review): `fun_jvp_p` gets no impl or abstract-eval rule here, so it
  appears usable only where the partial-eval rule fires — confirm.
  """
  # see https://github.com/google/jax/pull/636
  name = prim.name

  def fun_jvp(xs, ts, **params):
    # Symbolic zeros are instantiated so `<name>_jvp` always sees concrete tangents.
    ts = map(instantiate_zeros, xs, ts)
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    # Outputs are all primals followed by all tangents.
    primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
    if prim.multiple_results:
      return primals, tangents
    else:
      primal, = primals
      tangent, = tangents
      return primal, tangent
  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  fun_jvp_p.multiple_results = True
  def fun_jvp_partial_eval(trace, *tracers, **params):
    # Inputs are all primals followed by all tangents.
    primals, tangents = split_list(tracers, [len(tracers) // 2])
    primals_out, vjp_py = custom_vjp(*primals, **params)
    if not prim.multiple_results:
      primals_out = [primals_out]
    # Trace the user VJP with abstract (unknown) cotangents; the known values
    # it closes over come back as `res` (residuals).
    out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    ct_pvals = [pe.PartialVal((aval, core.unit)) for aval in out_avals]
    jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
    tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
                                  num_res=len(res), out_avals=out_avals)
    # NOTE(review): assumes `primals_out` is a list when prim.multiple_results
    # (a tuple + list concatenation would fail) — confirm against callers.
    return primals_out + tangents_out
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.multiple_results = True
  fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
  def fun_lin_transpose(cts, *args, **kwargs):
    # args are the residuals followed by the (linear) tangent inputs.
    num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
    res, _ = split_list(args, [num_res])
    cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
    outs = core.eval_jaxpr(trans_jaxpr, res, (), *cts)
    # Residuals are not linear inputs, so they get no cotangent (None).
    return [None] * num_res + outs
  primitive_transposes[fun_lin_p] = fun_lin_transpose
|
|
|
|
|
|
|
|
def defvjp(prim, *vjps):
  """Define reverse-mode rules for `prim` from one VJP function per input.

  `vjps[i](ct, *primals)` returns the cotangent for input `i`. A falsy entry
  (e.g. None) yields `zeros_like_jaxval` of that input instead. The rules are
  installed through `defvjp_all`.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      return [rule(ct, *primals) if rule else zeros_like_jaxval(x)
              for x, rule in zip(primals, vjps)]

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
|
|
|
|
|
|
|
|
def defvjp2(prim, *vjps):
  """Like `defvjp`, but each VJP also receives the primal output.

  `vjps[i](ct, ans, *primals)` returns the cotangent for input `i`. A falsy
  entry yields `zeros_like_jaxval` of that input instead.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      return [rule(ct, ans, *primals) if rule else zeros_like_jaxval(x)
              for x, rule in zip(primals, vjps)]

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
  """Register JVP and transpose rules for a primitive bilinear in `(x, y)`.

  Args:
    bcast: `bcast(g, other)` adapts tangent `g` before it is bound against the
      other operand.
    prim: the bilinear primitive.
    lhs_rule: `lhs_rule(ct, y, **params)`, the transpose with respect to `x`.
    rhs_rule: `rhs_rule(ct, x, **params)`, the transpose with respect to `y`.
  """
  assert isinstance(prim, Primitive)

  def lhs_jvp(g, x, y, **kwargs):
    return prim.bind(bcast(g, y), y, **kwargs)

  def rhs_jvp(g, x, y, **kwargs):
    return prim.bind(x, bcast(g, x), **kwargs)

  defjvp(prim, lhs_jvp, rhs_jvp)
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)


# Bilinear registration whose `bcast` passes the tangent through unchanged.
defbilinear = partial(defbilinear_broadcasting, lambda g, _: g)
|
|
|
|
|
|
|
|
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Transpose a bilinear primitive with respect to its one linear operand.

  Exactly one of `x`, `y` must be `undefined_primal`. Returns `(ct_x, ct_y)`
  with `None` for the known operand. The linear operand's cotangent is
  `zero` when `cotangent` is `zero`.
  """
  x_linear = x is undefined_primal
  assert x_linear ^ (y is undefined_primal)
  if cotangent is zero:
    result = zero
  elif x_linear:
    result = lhs_rule(cotangent, y, **kwargs)
  else:
    result = rhs_rule(cotangent, x, **kwargs)
  return (result, None) if x_linear else (None, result)
|
|
|
|
|
|
|
|
|
|
|
|
def defjvp_zero(primitive):
  """Register `primitive` as having an identically zero derivative."""
  assert isinstance(primitive, Primitive)
  rule = partial(zero_jvp, primitive)
  primitive_jvps[primitive] = rule
|
|
|
|
|
|
|
|
def zero_jvp(primitive, primals, tangents, **params):
  """JVP rule that binds `primitive` and always yields a symbolic-zero tangent."""
  out = primitive.bind(*primals, **params)
  return out, zero
|
|
|
|
|
|
|
|
|
2018-12-11 09:18:38 -08:00
|
|
|
# Linear primitives from ad_util / core and the transpose of each one.
deflinear(zeros_like_p, lambda ct: [zero])
deflinear(core.identity_p, lambda ct: (ct,))
deflinear(add_jaxvals_p, lambda ct: (ct, ct))
|
|
|
|
|
|
|
|
def instantiate_zeros(example, tangent):
  """Return `tangent`, or concrete zeros like `example` if it is the symbolic `zero`."""
  return zeros_like_jaxval(example) if tangent is zero else tangent
|
|
|
|
|
2019-05-07 08:52:08 -07:00
|
|
|
def instantiate_zeros_aval(aval, tangent):
  """Return `tangent`, or zeros for abstract value `aval` if it is the symbolic `zero`."""
  return zeros_like_aval(aval) if tangent is zero else tangent
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def traceable(num_primals, in_tree_def, *primals_and_tangents):
  """Adapt a `(primals, tangents)` function to a flat positional signature.

  The first `num_primals` positional arguments are the primals. The rest are
  the leaves of the tangents pytree described by `in_tree_def`. The outputs
  `(primal_out, tangent_out)` are flattened, and their treedef is reported
  through the aux channel.
  """
  new_primals = primals_and_tangents[:num_primals]
  new_tangents = tree_unflatten(in_tree_def, primals_and_tangents[num_primals:])
  primal_out, tangent_out = yield (new_primals, new_tangents), {}
  flat_out, out_tree_def = tree_flatten((primal_out, tangent_out))
  yield flat_out, out_tree_def
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-11-26 07:56:48 -08:00
|
|
|
|
2019-02-21 18:37:51 -08:00
|
|
|
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
  """Transpose a call primitive by binding it on `backward_pass` over `jaxpr`.

  The arguments are flattened as a pytree for the bind. Returns the
  unflattened result of `backward_pass`, i.e. `(freevar_cts, arg_cts)`.
  """
  flat_args, in_tree = tree_flatten((consts, freevar_vals, args, ct))
  transposed = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
  transposed, out_tree = flatten_fun_nokwargs(transposed, in_tree)
  out_flat = primitive.bind(transposed, *flat_args, **params)
  return tree_unflatten(out_tree(), out_flat)
|
|
|
|
# Call-like primitives are transposed by re-binding them on backward_pass.
primitive_transposes.update({
    core.call_p: partial(call_transpose, call_p),
    pe.remat_call_p: partial(call_transpose, pe.remat_call_p),
})
|
2019-02-23 20:34:14 -08:00
|
|
|
|
|
|
|
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
  """Transpose a map-style call primitive by binding it on `backward_pass`.

  Works like `call_transpose`, except that each non-zero free-variable
  cotangent is summed over axis 0 (presumably the mapped axis) before it is
  returned. Returns `(freevar_cts, arg_cts)`.
  """
  flat_args, in_tree = tree_flatten((consts, freevar_vals, args, ct))
  transposed = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
  transposed, out_tree = flatten_fun_nokwargs(transposed, in_tree)
  out_flat = primitive.bind(transposed, *flat_args, **params)
  freevar_cts, arg_cts = tree_unflatten(out_tree(), out_flat)
  summed_freevar_cts = [zero if x is zero else x.sum(0) for x in freevar_cts]
  return summed_freevar_cts, arg_cts
|
2019-04-01 16:03:56 -04:00
|
|
|
|
2019-04-10 09:42:17 -07:00
|
|
|
|
2019-05-10 08:58:05 -07:00
|
|
|
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Build a TypedJaxpr that computes the JVP of the TypedJaxpr `jaxpr`.

  Args:
    jaxpr: a TypedJaxpr.
    nonzeros: one bool per input; a tangent input is taken only where True.
    instantiate: forwarded to `jvp` (a bool or one bool per output).

  Returns:
    `(jaxpr_out, out_nonzeros)`. `jaxpr_out` takes all primals followed by the
    non-zero tangents and returns all primal outputs followed by the non-zero
    output tangents. `out_nonzeros` flags which output tangents are present.
  """
  assert len(jaxpr.in_avals) == len(nonzeros)
  fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  fun_jvp, out_nonzeros = f_jvp_traceable(jvp(fun, instantiate=instantiate), nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = list(jaxpr.in_avals) + tangent_avals
  in_pvals = [pe.PartialVal((aval, core.unit)) for aval in avals_in]
  jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(fun_jvp, in_pvals, instantiate=True)
  avals_out, _ = unzip2(pvals_out)
  typed_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
  return typed_out, out_nonzeros()
|
2019-04-10 09:42:17 -07:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  """Expand flat `(primals..., nonzero_tangents...)` arguments into JVP form.

  Tangents that are absent (where `nonzeros[i]` is False) become the symbolic
  `zero`. The outputs are flattened to the primal outputs followed by the
  non-zero output tangents. The aux value is the list of flags saying which
  output tangents are non-zero.
  """
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  supplied = iter(primals_and_nztangents[num_primals:])
  tangents = [next(supplied) if nz else zero for nz in nonzeros]
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [t is not zero for t in tangents_out]
  kept_tangents = [t for t, nz in zip(tangents_out, out_nonzeros) if nz]
  yield list(primals_out) + kept_tangents, out_nonzeros
|
|
|
|
|
|
|
|
def rearrange_binders(jaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Permute a TypedJaxpr's inputs and outputs using per-group counts.

  Binders laid out as all primal groups followed by all tangent groups are
  reordered so that each primal group is followed by its tangent group (see
  `_perm`). `primals_in`/`tangents_in` and `primals_out`/`tangents_out` are
  the group sizes for inputs and outputs, respectively.
  """
  inner = jaxpr.jaxpr
  new_invars = _perm(primals_in, tangents_in, inner.invars)
  new_outvars = _perm(primals_out, tangents_out, inner.outvars)
  new_jaxpr = core.Jaxpr(inner.constvars, inner.freevars,
                         new_invars, new_outvars, inner.eqns)
  new_in_avals = _perm(primals_in, tangents_in, jaxpr.in_avals)
  new_out_avals = _perm(primals_out, tangents_out, jaxpr.out_avals)
  return core.TypedJaxpr(new_jaxpr, jaxpr.literals, new_in_avals, new_out_avals)
|
|
|
|
|
|
|
|
def _perm(primal_counts, tangent_counts, lst):
  """Reorder `lst` from [primal groups..., tangent groups...] to interleaved groups.

  The first `sum(primal_counts)` items of `lst` are split into groups sized by
  `primal_counts`, and the remainder into groups sized by `tangent_counts`.
  The groups are then interleaved: primal group 0, tangent group 0, and so on.
  """
  num_primals = sum(primal_counts)
  primal_groups = split_list(lst[:num_primals], primal_counts[:-1])
  tangent_groups = split_list(lst[num_primals:], tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)
|
|
|
|
|
|
|
|
def _interleave(xs, ys):
|
|
|
|
assert len(xs) == len(ys)
|
|
|
|
return [e for pair in zip(xs, ys) for l in pair for e in l]
|