2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
2018-11-21 13:27:26 -08:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-23 20:34:14 -08:00
|
|
|
import itertools as it
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from . import partial_eval as pe
|
|
|
|
from .. import core as core
|
2018-11-22 17:38:07 +01:00
|
|
|
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
|
2019-05-07 08:52:08 -07:00
|
|
|
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
|
2018-11-17 18:03:33 -08:00
|
|
|
zeros_like_p, zero, Zero)
|
2019-04-23 17:47:28 -07:00
|
|
|
from ..abstract_arrays import raise_to_shaped
|
2018-11-21 13:20:44 -08:00
|
|
|
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
|
2019-03-21 19:08:19 -07:00
|
|
|
from ..tree_util import process_pytree, build_tree, register_pytree_node, tree_map
|
2018-11-17 18:03:33 -08:00
|
|
|
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
|
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
from six.moves import builtins, reduce
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# Rebind the module-level name `zip` to `safe_zip` (imported from ..util);
# every later use of `zip` in this module resolves to that variant.
zip = safe_zip
|
2018-11-21 13:20:44 -08:00
|
|
|
# Rebind the module-level name `map` to `safe_map` (imported from ..util);
# every later use of `map` in this module resolves to that variant.
map = safe_map
|
2019-02-15 06:35:54 -08:00
|
|
|
def identity(x):
  """Return `x` unchanged."""
  return x
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-04-01 16:03:56 -04:00
|
|
|
def jvp(fun, has_aux=False, instantiate=True):
  """Wrap `fun` so it is evaluated together with its forward-mode tangent.

  Args:
    fun: the wrapped function to differentiate; it is passed through the
      `jvp_subtrace` (or `jvp_subtrace_aux`) and `jvpfun` transformations.
    has_aux: when true, `fun` is treated as producing an `(output, aux)` pair
      and the auxiliary value is exposed separately.
    instantiate: forwarded to `jvpfun`; when true, the output tangent is
      passed through `instantiate_zeros` before being returned.

  Returns:
    The transformed function, or a `(transformed function, aux)` pair when
    `has_aux` is true, where `aux` is whatever `jvp_subtrace_aux` returns as
    its auxiliary-output handle.
  """
  if has_aux:
    # `jvp_subtrace_aux` takes exactly (master, primals, tangents), all of
    # which `jvpfun` supplies when the function is called. It has no
    # `instantiate` parameter, so that flag must be given to `jvpfun` only;
    # binding it here would shift every positional argument by one and fail
    # at call time.
    traced, aux = jvp_subtrace_aux(fun)
    return jvpfun(traced, instantiate), aux
  return jvpfun(jvp_subtrace(fun), instantiate)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@transformation
def jvpfun(instantiate, primals, tangents):
  """Run the wrapped function under a fresh `JVPTrace` master.

  The wrapped function receives `(master, primals, tangents)` and must produce
  an `(out_primal, out_tangent)` pair. When `instantiate` is true the tangent
  is passed through `instantiate_zeros` before being yielded.
  """
  with new_master(JVPTrace) as master:
    primal_out, tangent_out = yield (master, primals, tangents), {}
    # Drop the local reference before leaving the `new_master` context.
    del master
  tangent_out = (instantiate_zeros(primal_out, tangent_out)
                 if instantiate else tangent_out)
  yield (primal_out, tangent_out)
|
|
|
|
|
2019-04-01 16:03:56 -04:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
@transformation
def jvp_subtrace(master, primals, tangents):
  """Run the wrapped function on `JVPTracer`s built from primal/tangent pairs.

  A `JVPTrace` is created for `master` at the current sublevel, each
  `(primal, tangent)` pair is boxed in a `JVPTracer`, and the primal and
  tangent of the (raised) result are yielded as a pair.
  """
  trace = JVPTrace(master, core.cur_sublevel())
  # Any incoming tracer must belong to a strictly lower-level trace.
  assert all(val.trace.level < trace.level
             for val in list(primals) + list(tangents)
             if isinstance(val, Tracer))
  boxed = [JVPTracer(trace, p, t) for p, t in zip(primals, tangents)]
  ans = yield boxed, {}
  raised = trace.full_raise(ans)
  yield (raised.primal, raised.tangent)
|
|
|
|
|
2019-03-07 14:08:02 -08:00
|
|
|
@transformation_with_aux
def jvp_subtrace_aux(master, primals, tangents):
  """Variant of `jvp_subtrace` for functions returning `(output, aux)`.

  Yields the `(primal, tangent)` pair of the output as the result and the
  primal part of `aux` as the auxiliary value; the tangent of `aux` is
  discarded.
  """
  trace = JVPTrace(master, core.cur_sublevel())
  # Any incoming tracer must belong to a strictly lower-level trace.
  assert all(val.trace.level < trace.level
             for val in list(primals) + list(tangents)
             if isinstance(val, Tracer))
  boxed = [JVPTracer(trace, p, t) for p, t in zip(primals, tangents)]
  ans, aux = yield boxed, {}
  raised_out = trace.full_raise(ans)
  raised_aux = trace.full_raise(aux)
  # Only the primal of the auxiliary output is kept (aux tangent is ignored).
  yield (raised_out.primal, raised_out.tangent), raised_aux.primal
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
@transformation
def pack_output(*args):
  """Pass the arguments through unchanged and `pack` the function's result."""
  result = yield args, {}
  yield pack(result)
|
|
|
|
|
2019-03-07 14:08:02 -08:00
|
|
|
def linearize(traceable, *primals, **kwargs):
  """Trace the JVP of `traceable` at `primals` into a jaxpr.

  The primals are fed in as `PartialVal((None, pack(primals)))` and the
  tangents as `PartialVal((AbstractTuple(tangent_avals), core.unit))`, where
  each tangent aval is `get_aval(p).at_least_vspace()`.

  Args:
    traceable: the wrapped function to linearize.
    *primals: the primal input values.
    **kwargs: only `has_aux` (default False) is read.

  Returns:
    `(out_primal, out_tangent_pval, jaxpr, consts)`, with the evaluated
    auxiliary output appended as a fifth element when `has_aux` is true.
  """
  has_aux = kwargs.pop('has_aux', False)
  if has_aux:
    jvp_with_aux, aux = jvp(traceable, has_aux=True)
    packed_jvp = pack_output(jvp_with_aux)
  else:
    packed_jvp = pack_output(jvp(traceable))

  tangent_avals = [get_aval(p).at_least_vspace() for p in primals]
  primal_pval = pe.PartialVal((None, pack(primals)))
  tangent_pval = pe.PartialVal((core.AbstractTuple(tangent_avals), core.unit))
  jaxpr, out_pval, consts = pe.trace_to_jaxpr(
      packed_jvp, (primal_pval, tangent_pval))

  (aval_primal, const_primal), pval_tangent = unpair_pval(out_pval)
  # The primal half of the output is expected to carry no abstract value.
  assert aval_primal is None

  if has_aux:
    return const_primal, pval_tangent, jaxpr, consts, aux()
  return const_primal, pval_tangent, jaxpr, consts
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-03-07 14:08:02 -08:00
|
|
|
def vjp(traceable, primals, has_aux=False):
  """Linearize `traceable` at `primals` and return a pullback function.

  Args:
    traceable: the wrapped function to differentiate.
    primals: sequence of primal inputs (unpacked into `linearize`).
    has_aux: when true, `traceable` also produces an auxiliary output.

  Returns:
    `(out_primal, vjp_)`, or `(out_primal, vjp_, aux)` when `has_aux` is true.
    `vjp_(ct)` runs `backward_pass` over the linearized jaxpr and returns the
    cotangents for the inputs with symbolic zeros instantiated.
  """
  if has_aux:
    out_primal, pval, jaxpr, consts, aux = linearize(
        traceable, *primals, has_aux=True)
  else:
    out_primal, pval, jaxpr, consts = linearize(traceable, *primals)

  def vjp_(ct):
    ct = ignore_consts(ct, pval)
    # The jaxpr output is a (primal, tangent) pair; only the tangent slot
    # receives a real cotangent.
    paired_ct = pack((core.unit, ct))
    placeholder_args = (None,) * len(jaxpr.invars)
    _, arg_cts = backward_pass(jaxpr, consts, (), placeholder_args, paired_ct)
    return instantiate_zeros(pack(primals), arg_cts[1])

  if has_aux:
    return out_primal, vjp_, aux
  return out_primal, vjp_
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-03 22:24:46 -05:00
|
|
|
def ignore_consts(ct, pval):
  """Filter cotangent `ct` according to the partial value `pval`.

  Returns `ct` unchanged when the aval of `pval` is an `AbstractValue`,
  recurses element-wise when it is a `JaxprTracerTuple`, and returns
  `core.unit` when the aval is None.

  Raises:
    TypeError: if the aval is none of the above.
  """
  aval, const = pval
  if isinstance(aval, core.AbstractValue):
    return ct
  if isinstance(aval, pe.JaxprTracerTuple):
    sub_pvals = zip(aval, const)
    return pack([ignore_consts(sub_ct, sub_pval)
                 for sub_ct, sub_pval in zip(ct, sub_pvals)])
  if aval is None:
    return core.unit
  raise TypeError(aval)
|
|
|
|
|
|
|
|
def unpair_pval(pval):
  """Split an `(aval, const)` pair whose parts are pairs into two pairs.

  `const` must unpack into two elements. When `aval` is None both results get
  a None aval; otherwise `aval` must also unpack into two elements.

  Returns:
    `((aval_1, const_1), (aval_2, const_2))`.
  """
  aval, (first_const, second_const) = pval
  first_aval, second_aval = (None, None) if aval is None else aval
  return (first_aval, first_const), (second_aval, second_const)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-21 18:37:51 -08:00
|
|
|
def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
  """Propagate `cotangent_in` backwards through the equations of `jaxpr`.

  The equations are visited in reverse order. For each one the cotangents of
  its outputs are read, the primitive's transpose rule is applied, and the
  resulting cotangents are accumulated onto the equation's inputs.

  Args:
    jaxpr: the jaxpr to transpose.
    consts: values for `jaxpr.constvars`.
    freevar_vals: values for `jaxpr.freevars`.
    args: values for `jaxpr.invars`; entries that are None are not recorded
      as primals.
    cotangent_in: cotangent for `jaxpr.outvar`.

  Returns:
    `(freevar_cts, cotangents_out)`: cotangents for the free variables and
    for the inputs, the latter shaped by `pack_cotangents_like_caller`.
  """
  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is None:
      return
    if v in ct_env:
      ct_env[v] = add_tangents(ct_env[v], ct)
    else:
      ct_env[v] = ct

  def read_cotangent(v):
    # Variables that never received a cotangent read as the symbolic zero.
    return ct_env.get(v, zero)

  def read_primal(v):
    # Unknown primals read as None.
    return primal_env.get(v)

  def write_primal(v, val):
    if val is None:
      return
    primal_env[v] = val

  primal_env = {}
  core.pat_fmap(write_primal, jaxpr.constvars, consts)
  core.pat_fmap(write_primal, jaxpr.freevars, freevar_vals)
  core.pat_fmap(write_primal, jaxpr.invars, args)

  ct_env = {jaxpr.outvar: cotangent_in}
  for eqn in reversed(jaxpr.eqns):
    out_cts = map(read_cotangent, eqn.outvars)
    if eqn.destructure:
      ct_in = TangentTuple(out_cts)
    else:
      ct_in = out_cts[0]

    if eqn.restructure:
      # Restructured equations may have tuples of variables as inputs.
      invals = [tuple(map(read_primal, v)) if type(v) is tuple
                else read_primal(v) for v in eqn.invars]
    else:
      invals = map(read_primal, eqn.invars)

    if eqn.bound_subjaxprs:
      bound = [(sub, map(read_primal, sub_const_vars),
                map(read_primal, sub_bound_vars))
               for sub, sub_const_vars, sub_bound_vars in eqn.bound_subjaxprs]
      subjaxprs, sub_consts, sub_freevar_vals = unzip3(bound)
      transpose = get_primitive_transpose(eqn.primitive)
      cts_out, ct_free_vars_out = transpose(
          eqn.params, subjaxprs, sub_consts, sub_freevar_vals, invals, ct_in)
      # TODO(dougalm): support cases != 1
      assert(len(eqn.bound_subjaxprs) == 1)
      _, _, bound_vars = eqn.bound_subjaxprs[0]
      for var, free_ct in zip(bound_vars, ct_free_vars_out):
        write_cotangent(var, free_ct)
    else:
      transpose = get_primitive_transpose(eqn.primitive)
      cts_out = transpose(ct_in, *invals, **eqn.params)

    if cts_out is zero:
      # A single symbolic zero stands for "zero for every input".
      cts_out = [zero for _ in eqn.invars]
    for var, in_ct in zip(eqn.invars, cts_out):
      write_cotangent(var, in_ct)

  freevar_cts = core.pat_fmap(read_cotangent, jaxpr.freevars)
  cotangents_out = core.pat_fmap(lambda v, _: read_cotangent(v), jaxpr.invars, None)
  cotangents_out = tuple(pack_cotangents_like_caller(arg, ct)
                         for arg, ct in zip(args, cotangents_out))
  return freevar_cts, cotangents_out
|
|
|
|
|
2019-05-07 08:52:08 -07:00
|
|
|
def pack_cotangents_like_caller(arg, ct):
  """Shape the cotangent `ct` to mirror the structure of the caller's `arg`.

  A plain-tuple `arg` is handled element-wise, a None `arg` has its cotangent
  packed via `recursively_pack`, and any other `arg` yields None.
  """
  if arg is None:
    return recursively_pack(ct)
  # Exact type check on purpose: tuple subclasses are not treated as tuples.
  if type(arg) is tuple:
    return tuple(pack_cotangents_like_caller(sub_arg, sub_ct)
                 for sub_arg, sub_ct in zip(arg, ct))
  return None
|
|
|
|
|
|
|
|
def recursively_pack(ct):
  """Pack nested plain tuples of cotangents, bottom-up.

  Anything that is not exactly a `tuple` is returned as is. A tuple becomes a
  `TangentTuple` if any packed element is `zero` or a `TangentTuple`, and is
  otherwise combined with `pack`.
  """
  # Exact type check on purpose: tuple subclasses are returned untouched.
  if type(ct) is not tuple:
    return ct
  packed = tuple(map(recursively_pack, ct))
  has_symbolic = any(
      elt is zero or isinstance(elt, TangentTuple) for elt in packed)
  if has_symbolic:
    return TangentTuple(packed)
  return pack(packed)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def get_primitive_transpose(p):
  """Return the transpose rule registered for primitive `p`.

  Raises:
    NotImplementedError: if `p` has no entry in `primitive_transposes`.
  """
  if p in primitive_transposes:
    return primitive_transposes[p]
  message = "Reverse-mode differentiation rule for '{}' not implemented".format(p)
  raise NotImplementedError(message)
|
|
|
|
|
|
|
|
class TangentTuple(tuple):
  """Tuple subclass used in this module to mark tuples of tangent values."""
|
|
|
|
|
|
|
|
# Register TangentTuple as a pytree node: flattening yields the elements with
# no auxiliary data; unflattening rebuilds a TangentTuple from the children.
register_pytree_node(
    TangentTuple,
    lambda tangents: (tangents, None),
    lambda _, children: TangentTuple(children))
|
|
|
|
|
|
|
|
class JVPTrace(Trace):
  """Trace whose tracers carry a primal value together with its tangent."""

  def pure(self, val):
    """Box a constant: its tangent is the symbolic `zero`."""
    return JVPTracer(self, val, zero)

  def lift(self, val):
    """Box a value coming from a lower trace: its tangent is `zero`."""
    return JVPTracer(self, val, zero)

  def sublift(self, val):
    """Re-box a tracer's primal and tangent in this trace."""
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    """Apply the registered JVP rule of `primitive` to the tracers.

    Raises:
      NotImplementedError: if `primitive` has no entry in `primitive_jvps`.
    """
    primals_in = [t.primal for t in tracers]
    tangents_in = [t.tangent for t in tracers]
    try:
      jvp_rule = primitive_jvps[primitive]
    except KeyError:
      raise NotImplementedError(
          "Forward-mode differentiation rule for '{}' not implemented"
          .format(primitive))
    primal_out, tangent_out = jvp_rule(primals_in, tangents_in, **params)
    return JVPTracer(self, primal_out, tangent_out)

  def process_call(self, call_primitive, f, tracers, params):
    """Bind `call_primitive` on the JVP-transformed version of `f`."""
    primals = [t.primal for t in tracers]
    tangents = [t.tangent for t in tracers]
    # Flatten the tangents to jaxtuples; `in_tree_def` lets `traceable`
    # rebuild the original tangent structure inside the call.
    nonzero_tangents, in_tree_def = tree_to_jaxtuples(tangents)
    f, out_tree_def = traceable(jvp_subtrace(f, self.master), in_tree_def)
    result = call_primitive.bind(f, pack(primals), nonzero_tangents, **params)
    primal_out, tangent_out = build_tree(out_tree_def(), result)
    return JVPTracer(self, primal_out, tangent_out)

  def post_process_call(self, call_primitive, out_tracer, params):
    """Return the flattened (primal, tangent) output and a re-boxing function."""
    out_jtuple, tree_def = tree_to_jaxtuples((out_tracer.primal, out_tracer.tangent))
    master = self.master

    def todo(x):
      # Rebuild the (primal, tangent) pair and box it in a trace created at
      # the sublevel current when `todo` is invoked.
      trace = JVPTrace(master, core.cur_sublevel())
      return JVPTracer(trace, *build_tree(tree_def, x))

    return out_jtuple, todo

  def join(self, xt, yt):
    """Bring two tangents to a common representation.

    A tangent is "full" when it is neither the symbolic `zero` nor a
    `TangentTuple`.

    Returns:
      A pair `(xt, yt)` of joined tangents.

    Raises:
      TypeError: if the two tangents cannot be joined.
    """
    isfull = lambda t: t is not zero and not isinstance(t, TangentTuple)
    if isfull(xt) and isfull(yt):
      return xt, yt
    elif isfull(xt):
      if yt is zero:
        return xt, zeros_like_jaxval(xt)
      elif isinstance(yt, TangentTuple):
        # Mirrors the `isfull(yt)` branch below. The original tested `xt`
        # here, which can never be a TangentTuple in this branch, so this
        # case always fell through to the TypeError.
        # NOTE(review): like the mirrored branch, this replaces the
        # TangentTuple side with zeros shaped like the full side — confirm
        # that discarding its non-zero entries is intended.
        return xt, JaxTuple(map(zeros_like_jaxval, xt))
      else:
        raise TypeError((xt, yt))
    elif isfull(yt):
      if xt is zero:
        return zeros_like_jaxval(yt), yt
      elif isinstance(xt, TangentTuple):
        return JaxTuple(map(zeros_like_jaxval, yt)), yt
      else:
        raise TypeError((xt, yt))
    elif isinstance(xt, TangentTuple) or isinstance(yt, TangentTuple):
      if xt is zero:
        xt = TangentTuple((zero,) * len(yt))
      elif yt is zero:
        yt = TangentTuple((zero,) * len(xt))
      # Join element-wise. The original `TangentTuple(map(self.join), xt, yt)`
      # had a misplaced parenthesis and could never succeed; each element join
      # returns a pair, so regroup into a pair of TangentTuples to match the
      # return shape of every other branch.
      joined_x, joined_y = unzip2(map(self.join, xt, yt))
      return TangentTuple(joined_x), TangentTuple(joined_y)
    elif xt is zero and yt is zero:
      return xt, yt
    else:
      raise TypeError((xt, yt))

  def pack(self, tracers):
    """Combine tracers into one tracer holding packed primals and a TangentTuple."""
    # `pack` here is the module-level function, not this method.
    primals = pack(t.primal for t in tracers)
    tangents = TangentTuple([t.tangent for t in tracers])
    return JVPTracer(self, primals, tangents)
|
|
|
|
|
|
|
|
|
|
|
|
class JVPTracer(Tracer):
  """Tracer holding a primal value and its tangent for a `JVPTrace`."""

  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    self.trace = trace
    self.primal, self.tangent = primal, tangent
    # TODO(mattjj,dougalm): behind skip_checks, check primal/tangent shapes and
    # dtypes agree (up to jax_enable_x64 flag)

  @property
  def aval(self):
    """Abstract value of the primal."""
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def unpack(self):
    """Unpack into per-element tracers, or lower fully if the tangent is zero."""
    if self.tangent is zero:
      return self.full_lower()
    return [JVPTracer(self.trace, primal, tangent)
            for primal, tangent in zip(self.primal, self.tangent)]

  def full_lower(self):
    """Return the lowered primal when the tangent is zero, else this tracer."""
    if self.tangent is not zero:
      return self
    return core.full_lower(self.primal)
|
|
|
|
|
|
|
|
# -------------------- Primitives --------------------
|
|
|
|
|
|
|
|
|
|
|
|
# Registry mapping a primitive to its forward-mode (JVP) rule; consulted by
# JVPTrace.process_primitive and filled by deflinear/defjvp/defjvp2/etc.
primitive_jvps = {}
|
|
|
|
# NOTE(review): not read or written anywhere else in this module as shown —
# presumably used by other modules; confirm before removing.
composite_jvps = {}
|
|
|
|
|
|
|
|
# Registry mapping a primitive to its transpose rule; consulted through
# get_primitive_transpose during backward_pass.
primitive_transposes = {}
|
|
|
|
|
|
|
|
|
|
|
|
def deflinear(primitive, transpose_rule):
  """Register `primitive` as linear.

  Its JVP rule becomes `linear_jvp` bound to the primitive, and its transpose
  rule becomes `linear_transpose` bound to `transpose_rule`.
  """
  jvp_rule = partial(linear_jvp, primitive)
  transpose = partial(linear_transpose, transpose_rule)
  primitive_jvps[primitive] = jvp_rule
  primitive_transposes[primitive] = transpose
|
|
|
|
|
|
|
|
def linear_jvp(primitive, primals, tangents, **params):
  """JVP rule for a linear primitive.

  The primal output is the primitive applied to the primals. If every tangent
  is the symbolic `zero` the output tangent is `zero`; otherwise the primitive
  is applied to the tangents with zeros instantiated.
  """
  val_out = primitive.bind(*primals, **params)
  if any(tangent is not zero for tangent in tangents):
    dense_tangents = [instantiate_zeros(primal, tangent)
                      for primal, tangent in zip(primals, tangents)]
    return val_out, primitive.bind(*dense_tangents, **params)
  return val_out, zero
|
|
|
|
|
|
|
|
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Apply `transpose_rule` to the cotangent, short-circuiting symbolic zero.

  The positional `args` are accepted but not forwarded to the rule.
  """
  if cotangent is zero:
    return zero
  return transpose_rule(cotangent, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def defjvp(primitive, *jvprules):
  """Register per-argument JVP rules for `primitive` via `standard_jvp`.

  `jvprules` is matched position-by-position with the primitive's arguments;
  a None entry means that argument contributes no tangent.
  """
  assert isinstance(primitive, Primitive)
  rule = partial(standard_jvp, jvprules, primitive)
  primitive_jvps[primitive] = rule
|
|
|
|
|
|
|
|
|
|
|
|
def standard_jvp(jvprules, primitive, primals, tangents, **params):
  """JVP built from per-argument rules called as `rule(t, *primals, **params)`.

  Rules that are None and tangents that are the symbolic `zero` are skipped;
  the remaining contributions are summed with `add_tangents`.
  """
  val_out = primitive.bind(*primals, **params)
  # Evaluate every contribution first, then sum, as separate phases.
  contributions = []
  for rule, t in zip(jvprules, tangents):
    if rule is None or t is zero:
      continue
    contributions.append(rule(t, *primals, **params))
  total = zero
  for contribution in contributions:
    total = add_tangents(total, contribution)
  return val_out, total
|
|
|
|
|
|
|
|
def defjvp2(primitive, *jvprules):
  """Like `defjvp`, but registers the rules through `standard_jvp2`.

  With `standard_jvp2` each rule additionally receives the primal output.
  """
  assert isinstance(primitive, Primitive)
  rule = partial(standard_jvp2, jvprules, primitive)
  primitive_jvps[primitive] = rule
|
|
|
|
|
|
|
|
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """JVP built from rules called as `rule(t, val_out, *primals, **params)`.

  Rules that are None and tangents that are the symbolic `zero` are skipped;
  each contribution is added to the running sum as soon as it is computed.
  """
  val_out = primitive.bind(*primals, **params)
  total = zero
  for rule, t in zip(jvprules, tangents):
    if rule is None or t is zero:
      continue
    total = add_tangents(total, rule(t, val_out, *primals, **params))
  return val_out, total
|
|
|
|
|
|
|
|
def add_tangents(x, y):
  """Add two tangents, treating the symbolic `zero` as the additive identity."""
  if x is zero:
    return y
  if y is zero:
    return x
  return add_jaxvals(x, y)
|
|
|
|
|
|
|
|
|
2019-04-23 17:47:28 -07:00
|
|
|
def defvjp_all(prim, custom_vjp):
  """Install `custom_vjp` as the differentiation rule of `prim`.

  `custom_vjp(*primals)` must return `(primal_out, vjp_py)` where `vjp_py`
  maps an output cotangent to the input cotangents. Two helper primitives are
  created, named `'<prim.name>_jvp'` and `'<prim.name>_lin'`:

  * the JVP rule of `prim` binds the `_jvp` primitive on packed primals and
    (instantiated) tangents;
  * the `_jvp` primitive gets a custom partial-evaluation rule that calls
    `custom_vjp`, traces `vjp_py` to a jaxpr and binds the `_lin` primitive;
  * the `_lin` primitive gets a transpose rule that evaluates that jaxpr.
  """
  name = prim.name

  def fun_jvp(xs, ts):
    ts = map(instantiate_zeros, xs, ts)  # TODO(mattjj): avoid instantiation?
    primal_out, tangent_out = fun_jvp_p.bind(pack(xs), pack(ts))
    return primal_out, tangent_out

  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))

  def fun_jvp_partial_eval(trace, *tracers):
    primals_tracer, tangents_tracer = tracers
    primal_out, vjp_py = custom_vjp(*primals_tracer)

    # Trace the Python VJP with an unknown cotangent shaped like the output.
    in_aval = raise_to_shaped(get_aval(primal_out))
    ct_pval = pe.PartialVal((in_aval, core.unit))

    def packed_vjp(ct):
      return pack(vjp_py(ct))

    vjp_jaxpr, out_pval, residuals = pe.trace_unwrapped_to_jaxpr(
        packed_vjp, (ct_pval,))
    out_pv, out_const = out_pval
    tangent_out = fun_lin_p.bind(
        out_const, pack(residuals), tangents_tracer,
        in_aval=in_aval, out_pv=out_pv, vjp_jaxpr=vjp_jaxpr)

    return pack((primal_out, tangent_out))

  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.def_abstract_eval(lambda c, r, ts, in_aval, out_pv, vjp_jaxpr: in_aval)

  def fun_lin_transpose(ct, out_const, residuals, ts, in_aval, out_pv, vjp_jaxpr):
    # Only the tangent operand may be unknown when transposing.
    assert ts is None and out_const is not None and residuals is not None
    ans = core.eval_jaxpr(vjp_jaxpr, residuals, (), ct)
    out = pe.merge_pvals(ans, pe.PartialVal((out_pv, out_const)))
    # No cotangent for `out_const` or `residuals`; `out` goes to the tangents.
    return [None, None, out]

  primitive_transposes[fun_lin_p] = fun_lin_transpose
|
|
|
|
|
|
|
|
def defvjp(prim, *vjps):
  """Define a VJP for `prim` from per-argument rules `vjp(ct, *primals)`.

  A falsy entry in `vjps` makes the corresponding argument's cotangent
  `zeros_like_jaxval` of that argument.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      cotangents = []
      for x, rule in zip(primals, vjps):
        cotangents.append(rule(ct, *primals) if rule else zeros_like_jaxval(x))
      return cotangents

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
|
|
|
|
|
|
|
|
def defvjp2(prim, *vjps):
  """Like `defvjp`, but each rule is called as `vjp(ct, ans, *primals)`.

  `ans` is the primal output of `prim`. A falsy entry in `vjps` makes the
  corresponding argument's cotangent `zeros_like_jaxval` of that argument.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      cotangents = []
      for x, rule in zip(primals, vjps):
        cotangents.append(
            rule(ct, ans, *primals) if rule else zeros_like_jaxval(x))
      return cotangents

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
  """Register JVP and transpose rules for a bilinear two-operand primitive.

  The JVP with respect to each operand re-binds `prim` with that operand
  replaced by its tangent, after passing the tangent through `bcast` together
  with the other operand as used below. The transpose is `bilinear_transpose`
  bound to `lhs_rule` and `rhs_rule`.
  """
  assert isinstance(prim, Primitive)

  def lhs_jvp(g, x, y, **kwargs):
    return prim.bind(bcast(g, y), y, **kwargs)

  def rhs_jvp(g, x, y, **kwargs):
    return prim.bind(x, bcast(g, x), **kwargs)

  defjvp(prim, lhs_jvp, rhs_jvp)
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
|
|
|
|
# `defbilinear_broadcasting` with a no-op broadcast: the tangent is used as is.
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)
|
|
|
|
|
|
|
|
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Transpose a bilinear primitive with respect to whichever operand is None.

  Exactly one of `x` and `y` must be None. The result is a pair with the
  computed cotangent in the position of the None operand and None in the
  other; a symbolic `zero` cotangent is propagated without calling the rule.
  """
  assert (x is None) != (y is None)
  if x is None:
    if cotangent is zero:
      return zero, None
    return lhs_rule(cotangent, y, **kwargs), None
  if cotangent is zero:
    return None, zero
  return None, rhs_rule(cotangent, x, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def defjvp_zero(primitive):
  """Register `primitive` as having an identically-zero output tangent."""
  assert isinstance(primitive, Primitive)
  rule = partial(zero_jvp, primitive)
  primitive_jvps[primitive] = rule
|
|
|
|
|
|
|
|
def zero_jvp(primitive, primals, tangents, **params):
  """JVP rule that evaluates the primitive and returns a symbolic zero tangent."""
  primal_out = primitive.bind(*primals, **params)
  return primal_out, zero
|
|
|
|
|
|
|
|
|
2018-12-11 09:18:38 -08:00
|
|
|
# zeros_like: the transpose contributes a symbolic zero cotangent to its input.
deflinear(zeros_like_p, lambda t: [zero])
|
2018-11-17 18:03:33 -08:00
|
|
|
# identity: the cotangent is passed straight through to the single input.
deflinear(core.identity_p, lambda t: (t,))
|
|
|
|
# pack: the tuple cotangent is split into one cotangent per packed input.
# (linear_transpose already short-circuits a `zero` cotangent before calling
# this rule, so the `zero` check here is a second line of defense.)
deflinear(core.pack_p, lambda t: list(t) if t is not zero else zero)
|
|
|
|
# add: the same cotangent flows to both operands.
deflinear(add_jaxvals_p, lambda t: (t, t))
|
|
|
|
|
|
|
|
|
|
|
|
def instantiate_zeros(example, tangent):
  """Replace symbolic zeros in `tangent` by zeros shaped like `example`.

  A `zero` tangent becomes `zeros_like_jaxval(example)`, a `TangentTuple` is
  handled element-wise against `example` and packed, and any other tangent is
  returned unchanged.
  """
  if tangent is zero:
    return zeros_like_jaxval(example)
  if isinstance(tangent, TangentTuple):
    return pack([instantiate_zeros(sub_example, sub_tangent)
                 for sub_example, sub_tangent in zip(example, tangent)])
  return tangent
|
|
|
|
|
2019-05-07 08:52:08 -07:00
|
|
|
def instantiate_zeros_aval(aval, tangent):
  """Replace symbolic zeros in `tangent` by zeros built from `aval`.

  A `zero` tangent becomes `zeros_like_aval(aval)`, a `TangentTuple` is
  handled element-wise against `aval` and packed, and any other tangent is
  returned unchanged.
  """
  if tangent is zero:
    return zeros_like_aval(aval)
  if isinstance(tangent, TangentTuple):
    return pack([instantiate_zeros_aval(sub_aval, sub_tangent)
                 for sub_aval, sub_tangent in zip(aval, tangent)])
  return tangent
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
@transformation_with_aux
def traceable(in_tree_def, new_primals, new_tangents):
  """Adapt a (primals, tangents) function to flattened jaxtuple I/O.

  The incoming tangents are rebuilt with `in_tree_def` before the call; the
  `(primal, tangent)` output is flattened with `tree_to_jaxtuples`, and its
  tree definition is returned as the auxiliary value.
  """
  rebuilt_tangents = build_tree(in_tree_def, new_tangents)
  primal_out, tangent_out = yield (new_primals, rebuilt_tangents), {}
  flat_out, out_tree_def = tree_to_jaxtuples((primal_out, tangent_out))
  yield flat_out, out_tree_def
|
|
|
|
|
|
|
|
@transformation_with_aux
def transposed_fun(jaxpr, in_tree_def, args):
  """Adapt a `backward_pass`-style function to a single packed argument.

  `args` unpacks into `(args, consts, freevar_vals, ct)`. The args, cotangent
  and free-variable values are rebuilt with `in_tree_def`, and the
  `(cotangents_out, freevar_cts)` result is flattened, with its tree
  definition returned as the auxiliary value.
  """
  flat_args, consts, flat_freevars, flat_ct = args
  call_args, ct, freevar_vals = build_tree(
      in_tree_def, (flat_args, flat_ct, flat_freevars))
  freevar_cts, cotangents_out = yield (jaxpr, consts, freevar_vals, call_args, ct), {}
  flat_out, out_tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
  yield flat_out, out_tree_def
|
|
|
|
|
2019-02-21 18:37:51 -08:00
|
|
|
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
  """Transpose rule for call-like primitives with one bound sub-jaxpr.

  `jaxpr`, `consts` and `freevar_vals` must each be a one-element sequence.
  `backward_pass` over the sub-jaxpr is wrapped and bound with `primitive`,
  and the result is rebuilt into its original tree structure.
  """
  (sub_jaxpr,) = jaxpr
  (sub_consts,) = consts
  (sub_freevars,) = freevar_vals
  assert isinstance(sub_jaxpr, core.Jaxpr)

  (flat_args, flat_ct, flat_freevars), in_tree_def = tree_to_jaxtuples(
      (args, ct, sub_freevars))
  fun, out_tree_def = transposed_fun(
      wrap_init(backward_pass), sub_jaxpr, in_tree_def)

  packed_args = pack(flat_args)
  packed_consts = pack(sub_consts)
  packed_freevars = pack(flat_freevars)
  all_args = pack((packed_args, packed_consts, packed_freevars, flat_ct))
  # TODO(dougalm): consider signalling to bind that no traces in fun closure
  ans = primitive.bind(fun, all_args, **params)
  return build_tree(out_tree_def(), ans)
|
|
|
|
|
2019-02-23 20:34:14 -08:00
|
|
|
@transformation_with_aux
def transposed_mapped(jaxpr, in_tree_def, freevar_vals, args):
  """Like `transposed_fun`, but with free-variable values bound statically.

  `args` unpacks into `(args, consts, ct)`; the args and cotangent are rebuilt
  with `in_tree_def`. The `(cotangents_out, freevar_cts)` result is flattened,
  with its tree definition returned as the auxiliary value.
  """
  flat_args, consts, flat_ct = args
  call_args, ct = build_tree(in_tree_def, (flat_args, flat_ct))
  freevar_cts, cotangents_out = yield (jaxpr, consts, freevar_vals, call_args, ct), {}
  flat_out, out_tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
  yield flat_out, out_tree_def
|
|
|
|
|
|
|
|
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
  """Transpose rule for map-like primitives with one bound sub-jaxpr.

  `jaxpr`, `consts` and `freevar_vals` must each be a one-element sequence.
  The free-variable values are closed over rather than passed through
  `primitive.bind`, and each resulting free-variable cotangent is summed over
  its leading axis with `.sum(0)`.

  Returns:
    `(cts_out, freevar_cts)`.
  """
  (sub_jaxpr,) = jaxpr
  (sub_consts,) = consts
  (sub_freevars,) = freevar_vals

  (flat_args, flat_ct), in_tree_def = tree_to_jaxtuples((args, ct))
  fun, out_tree_def = transposed_mapped(
      wrap_init(backward_pass), sub_jaxpr, in_tree_def, tuple(sub_freevars))

  packed_args = pack(flat_args)
  packed_consts = pack(sub_consts)
  all_args = pack((packed_args, packed_consts, flat_ct))
  ans = primitive.bind(fun, all_args, **params)

  cts_out, mapped_freevar_cts = build_tree(out_tree_def(), ans)
  freevar_cts = tree_map(lambda x: x.sum(0), mapped_freevar_cts)
  return cts_out, freevar_cts
|
|
|
|
|
2019-04-01 16:03:56 -04:00
|
|
|
def jaxpr_as_fun(jaxpr, consts, *args):
  """Evaluate `jaxpr` on `args` with the given consts and no free variables."""
  no_freevars = ()
  return core.eval_jaxpr(jaxpr, consts, no_freevars, *args)
|
|
|
|
|
|
|
|
def get_zeros(tangent):
  """Describe which parts of `tangent` are symbolic zeros.

  Returns True for `zero`, a tuple of recursive results for a `TangentTuple`,
  and False for anything else.
  """
  if tangent is zero:
    return True
  if isinstance(tangent, TangentTuple):
    return tuple(get_zeros(element) for element in tangent)
  return False
|
|
|
|
|
|
|
|
def put_zeros(pack, iszero, x):
  """Re-insert symbolic zeros into `x` following the pattern `iszero`.

  `iszero` is True (result is `zero`), False (result is `x`), or a sequence of
  such patterns matched element-wise against `x` and combined with `pack`.
  """
  if iszero is True:
    return zero
  if iszero is False:
    return x
  return pack([put_zeros(pack, sub_iszero, sub_x)
               for sub_iszero, sub_x in zip(iszero, x)])
|
|
|
|
|
|
|
|
def strip_zeros(unit, pack, iszero, x):
  """Replace the parts of `x` flagged in `iszero` by the placeholder `unit`.

  `iszero` is True (result is `unit`), False (result is `x`), or a sequence of
  such patterns matched element-wise against `x` and combined with `pack`.
  """
  if iszero is True:
    return unit
  if iszero is False:
    return x
  return pack([strip_zeros(unit, pack, sub_iszero, sub_x)
               for sub_iszero, sub_x in zip(iszero, x)])
|
|
|
|
|
|
|
|
@transformation_with_aux
def f_jvp_traceable(zero_components, *primal_tangent_pairs):
  """Adapt a (primals, tangents) JVP function to (primal, tangent)-pair I/O.

  Args:
    zero_components: per-input patterns (as produced by `get_zeros`) saying
      which input tangents are symbolic zeros; they are re-inserted with
      `put_zeros` before the call.
    *primal_tangent_pairs: one `(primal, tangent)` pair per input.

  Yields:
    The packed `(primal, tangent)` output pairs, with symbolic-zero tangents
    replaced by `core.unit`, and, as the auxiliary value, the zero pattern of
    the output tangent.
  """
  primals, tangents = unzip2(primal_tangent_pairs)
  tangents_zeros = map(partial(put_zeros, TangentTuple), zero_components, tangents)
  # The transformation protocol expects `(args, kwargs)`. The wrapped JVP
  # function takes `primals` and `tangents` as its two positional arguments,
  # so they must be grouped into the args tuple; yielding them bare would pass
  # `tangents_zeros` as the kwargs.
  primal_out, tangent_out = yield (primals, tangents_zeros), {}
  # TODO check output is tuple
  zeros_out = get_zeros(tangent_out)
  tangent_out_nonzero = strip_zeros(core.unit, pack, zeros_out, tangent_out)
  primal_tangent_pairs_out = [pack((p, t)) for p, t in zip(primal_out, tangent_out_nonzero)]
  yield pack(primal_tangent_pairs_out), zeros_out
|
|
|
|
|
2019-05-08 07:50:38 -07:00
|
|
|
def jvp_jaxpr(jaxpr, zeros):
  """Build the typed jaxpr of the JVP of a typed jaxpr.

  Args:
    jaxpr: a typed jaxpr; its `jaxpr`, `literals` and `in_avals` attributes
      are read.
    zeros: per-input patterns saying which input tangents are symbolic zeros.

  Returns:
    `(jaxpr_out, out_zeros)`: a `core.TypedJaxpr` taking (primal, tangent)
    pairs, and the zero pattern of the output tangents.
  """
  # jaxpr :: d -> a -> b -> (c1, c2)
  # avals = (d, a, b)
  # f :: d -> a -> b -> (c1, c2)
  primal_fun = partial(jaxpr_as_fun, jaxpr.jaxpr, jaxpr.literals)
  f = wrap_init(primal_fun)
  f_jvp, out_zeros = f_jvp_traceable(jvp(f, instantiate=False), zeros)
  # f_jvp :: (d, d') -> (a, a') -> (b, b') -> ((c1, c1'), (c2, c2'))
  # Tangent avals: the input avals with zero tangents replaced by an empty
  # AbstractTuple.
  strip = partial(strip_zeros, core.AbstractTuple(()), core.AbstractTuple)
  tangent_avals = [strip(iszero, aval)
                   for iszero, aval in zip(zeros, jaxpr.in_avals)]
  pt_pvals = []
  for p_aval, t_aval in zip(jaxpr.in_avals, tangent_avals):
    pair_aval = core.AbstractTuple((p_aval, t_aval))
    pt_pvals.append(pe.PartialVal((pair_aval, core.unit)))
  jaxpr_out, pval_out, literals_out = pe.trace_to_jaxpr(
      f_jvp, pt_pvals, instantiate=True)
  # jaxpr_out :: (d, d') -> (a, a') -> (b, b') -> ((c1, c1'), (c2, c2'))
  # out_zeros :: (zeros(c1), zeros(c2))
  in_avals = tuple(core.AbstractTuple(pair)
                   for pair in zip(jaxpr.in_avals, tangent_avals))
  out_aval, _ = pval_out
  typed_jaxpr_out = core.TypedJaxpr(jaxpr_out, literals_out, in_avals, out_aval)
  return typed_jaxpr_out, out_zeros()
|
2019-04-10 09:42:17 -07:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# call_p is transposed by re-binding call_p on the wrapped backward pass.
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
|
|
|
|
# compiled_call_p is transposed the same way, re-binding compiled_call_p.
primitive_transposes[pe.compiled_call_p] = partial(call_transpose, pe.compiled_call_p)
|
|
|
|
|
|
|
|
|
|
|
|
# `process_pytree` with `pack` as the node constructor. Callers in this module
# unpack the result as a (flattened value, tree definition) pair. Defined at
# the end of the module; the functions above only look it up at call time.
tree_to_jaxtuples = partial(process_pytree, pack)
|