659 lines
24 KiB
Python
Raw Normal View History

2018-11-17 18:03:33 -08:00
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
2019-02-23 20:34:14 -08:00
import itertools as it
from typing import Any, Callable, Dict
2019-02-23 20:34:14 -08:00
2018-11-17 18:03:33 -08:00
from . import partial_eval as pe
from .. import core as core
from ..core import Trace, Tracer, new_master, get_aval, call_p, Primitive, Literal
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
zeros_like_p, zero)
from ..abstract_arrays import raise_to_shaped
from ..util import unzip2, safe_map, safe_zip, partial, split_list, wrap_name
from ..tree_util import register_pytree_node
from .. import linear_util as lu
from ..api_util import flatten_fun, flatten_fun_nokwargs
2019-07-26 23:17:21 -04:00
from ..tree_util import tree_flatten, tree_unflatten
2018-11-17 18:03:33 -08:00
zip = safe_zip
map = safe_map
2019-02-15 06:35:54 -08:00
def identity(x):
  """Return the argument unchanged."""
  return x
2018-11-17 18:03:33 -08:00
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
  """Wrap ``fun`` so that calling it computes both primals and tangents.

  Args:
    fun: the function to differentiate.
    has_aux: if true, ``fun`` also returns auxiliary data, and the thunk for
      it (from ``jvp_subtrace_aux``) is returned alongside the wrapped fun.
    instantiate: forwarded to ``jvpfun``; controls whether symbolic ``zero``
      output tangents are replaced by concrete zeros.

  Returns:
    The transformed ``WrappedFun``, or ``(wrapped_fun, aux_thunk)`` when
    ``has_aux`` is true.
  """
  if has_aux:
    sub_fun, aux = jvp_subtrace_aux(fun)
    return jvpfun(sub_fun, instantiate), aux
  return jvpfun(jvp_subtrace(fun), instantiate)
2018-11-17 18:03:33 -08:00
@lu.transformation
def jvpfun(instantiate, primals, tangents):
  """Outermost JVP transformation: open a JVPTrace master and run the subtrace.

  Args:
    instantiate: a single bool applied to every output, or one bool per
      output. A true flag replaces a symbolic ``zero`` output tangent with
      concrete zeros shaped like the corresponding primal output.
    primals: primal inputs, forwarded (with the master) to the wrapped fun.
    tangents: tangent inputs, forwarded likewise.
  """
  with new_master(JVPTrace) as master:
    out_primals, out_tangents = yield (master, primals, tangents), {}
    # Release the local reference to the master while still inside the context.
    del master
  flags = ([instantiate] * len(out_tangents) if type(instantiate) is bool
           else instantiate)
  out_tangents = [t if not flag else instantiate_zeros(p, t)
                  for p, t, flag in zip(out_primals, out_tangents, flags)]
  yield out_primals, out_tangents
@lu.transformation
def jvp_subtrace(master, primals, tangents):
  """Run the wrapped fun on JVPTracers at a new sublevel of ``master``.

  Inputs whose tangent is the symbolic ``zero`` are passed through unwrapped.
  Yields ``(out_primals, out_tangents)`` as two tuples.
  """
  trace = JVPTrace(master, core.cur_sublevel())
  # Every incoming tracer must belong to an enclosing (lower-level) trace.
  for val in it.chain(primals, tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < trace.level
  in_tracers = [x if t is zero else JVPTracer(trace, x, t)
                for x, t in zip(primals, tangents)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2((tracer.primal, tracer.tangent) for tracer in out_tracers)
2018-11-17 18:03:33 -08:00
@lu.transformation_with_aux
def jvp_subtrace_aux(master, primals, tangents):
  """Like ``jvp_subtrace`` but for a fun returning ``(outputs, aux)``.

  Yields ``(out_primals, out_tangents)`` and reports, as auxiliary data, the
  lowered primal parts of ``aux`` (its tangents are discarded).
  """
  trace = JVPTrace(master, core.cur_sublevel())
  for val in it.chain(primals, tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < trace.level
  in_tracers = map(partial(JVPTracer, trace), primals, tangents)
  ans, aux = yield in_tracers, {}
  ans_tracers = map(trace.full_raise, ans)
  aux_tracers = map(trace.full_raise, aux)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  aux_primals = [core.full_lower(t.primal) for t in aux_tracers]
  yield (out_primals, out_tangents), aux_primals
def linearize(traceable, *primals, has_aux=False):
  """Linearize ``traceable`` at ``primals`` by partially evaluating its JVP.

  Primal inputs are passed as known values and tangent inputs as unknown
  values, so tracing stages out only the tangent computation into ``jaxpr``.

  Args:
    traceable: the ``lu.WrappedFun`` to linearize.
    *primals: primal input values.
    has_aux: keyword-only. If true, ``traceable`` also returns auxiliary
      data, which is appended to the result.

  Returns:
    ``(out_primals, out_tangent_pvals, jaxpr, consts)``, followed by the
    auxiliary output when ``has_aux`` is true.

  Raises:
    TypeError: if an unexpected keyword argument is passed.
  """
  if has_aux:
    jvp_fun, aux = jvp(traceable, has_aux=True)
  else:
    jvp_fun = jvp(traceable)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvp_fun_flat, out_tree = flatten_fun(jvp_fun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvp_fun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  # Primal outputs depend only on known inputs, so they must all be known.
  assert all(pval.is_known() for pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  if has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
  return out_primals_consts, out_tangents_pvals, jaxpr, consts
2018-11-17 18:03:33 -08:00
def vjp(traceable, primals, has_aux=False):
  """Linearize ``traceable`` at ``primals`` and return a VJP closure.

  Returns:
    ``(out_primals, vjp_)``, or ``(out_primals, vjp_, aux)`` when
    ``has_aux`` is true. ``vjp_(*cts)`` maps output cotangents to a list of
    input cotangents, with symbolic zeros instantiated against ``primals``.
  """
  if has_aux:
    out_primals, pvals, jaxpr, consts, aux = linearize(
        traceable, *primals, has_aux=True)
  else:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)

  def vjp_(*cts):
    cts = tuple(map(ignore_consts, cts, pvals))
    # Cotangent slots for the primal outputs are filled with core.unit.
    dummy_primals_and_cts = (core.unit,) * len(cts) + cts
    dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    all_cts = backward_pass(jaxpr, consts, dummy_args, dummy_primals_and_cts)
    # Only the cotangents for the tangent inputs (after the primals) matter.
    return map(instantiate_zeros, primals, all_cts[len(primals):])

  return (out_primals, vjp_, aux) if has_aux else (out_primals, vjp_)
2018-11-17 18:03:33 -08:00
def ignore_consts(ct, pval):
  """Keep ``ct`` for an unknown output; replace it with ``core.unit`` for a known one.

  Raises:
    TypeError: if the pval's aval is neither ``None`` nor an AbstractValue.
  """
  aval, _ = pval
  if aval is None:
    return core.unit
  if isinstance(aval, core.AbstractValue):
    return ct
  raise TypeError(aval)
def unpair_pval(pval):
  """Split a pval whose aval and const are both pairs into two pvals.

  A ``None`` aval (known value) is propagated to both halves.
  """
  aval, (const_1, const_2) = pval
  if aval is not None:
    aval_1, aval_2 = aval
    return (aval_1, const_1), (aval_2, const_2)
  return (None, const_1), (None, const_2)
2018-11-17 18:03:33 -08:00
def backward_pass(jaxpr: core.Jaxpr, consts, args, cotangents_in):
  """Transpose ``jaxpr``, mapping output cotangents to input cotangents.

  Forward phase: every equation whose inputs are all known primals is
  evaluated, and the rest (linear equations) are collected. Backward phase:
  the linear equations are visited in reverse and each primitive's transpose
  rule accumulates cotangents onto its inputs.

  Args:
    jaxpr: the jaxpr to transpose.
    consts: values bound to ``jaxpr.constvars``.
    args: values bound to ``jaxpr.invars``; ``UndefinedPrimal`` entries mark
      unknown (linear) inputs.
    cotangents_in: one cotangent per ``jaxpr.outvars`` entry.

  Returns:
    A list with one cotangent per ``jaxpr.invars`` entry (``zero`` where
    nothing was accumulated).
  """
  if all(ct is zero for ct in cotangents_in):
    return [zero] * len(jaxpr.invars)

  primal_env: Dict[Any, Any] = {}
  ct_env: Dict[Any, Any] = {}

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # Undefined values are never stored: absence from primal_env is what
    # marks a variable as linear.
    if not is_undefined_primal(val):
      primal_env[v] = val

  def is_linear(var):
    return type(var) is not Literal and var not in primal_env

  def read_cotangent(v):
    return ct_env.get(v, zero)

  def write_cotangent(v, ct):
    # Literals take no cotangent; a None cotangent means "no contribution".
    if ct is None or type(v) is Literal:
      return
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct

  # ---- forward phase ----
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  map(write_primal, jaxpr.invars, args)

  linear_eqns = []
  for eqn in jaxpr.eqns:
    if eqn.primitive.call_primitive:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      if any(is_linear(v) for v in eqn.invars):
        linear_eqns.append(eqn)
      # A call with some known inputs is also partially evaluated, so its
      # known outputs become available to later equations.
      if any(not is_linear(v) for v in eqn.invars):
        ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
                                     map(read_primal, eqn.invars), params)
        map(write_primal, eqn.outvars, ans)
    elif any(is_linear(v) for v in eqn.invars):
      linear_eqns.append(eqn)
    else:
      ans = eqn.primitive.bind(*map(read_primal, eqn.invars), **eqn.params)
      if eqn.primitive.multiple_results:
        map(write_primal, eqn.outvars, ans)
      else:
        write_primal(eqn.outvars[0], ans)

  # ---- backward phase ----
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in reversed(linear_eqns):
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    if eqn.primitive.call_primitive:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      cts_out = get_primitive_transpose(eqn.primitive)(
          params, call_jaxpr, invals, cts_in)
    else:
      cts_out = get_primitive_transpose(eqn.primitive)(
          cts_in, *invals, **eqn.params)
    if cts_out is zero:
      cts_out = [zero] * len(eqn.invars)
    map(write_cotangent, eqn.invars, cts_out)

  return map(read_cotangent, jaxpr.invars)
2018-11-17 18:03:33 -08:00
def _eval_subjaxpr_primals(prim, jaxpr, in_vals, params):
  """Bind ``prim`` on a fun that evaluates only the known part of ``jaxpr``.

  ``in_vals`` may contain ``UndefinedPrimal`` entries; the result mirrors the
  structure returned by ``_eval_primals``.
  """
  assert not jaxpr.constvars
  flat_args, in_tree = tree_flatten((in_vals,))
  wrapped = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
  wrapped, out_tree = flatten_fun_nokwargs(wrapped, in_tree)
  # Bind first: the out_tree thunk is populated only once ``wrapped`` runs.
  out_flat = prim.bind(wrapped, *flat_args, **params)
  return tree_unflatten(out_tree(), out_flat)
def _eval_primals(jaxpr, args):
  """Evaluate the equations of ``jaxpr`` whose inputs are all known.

  ``args`` may contain ``UndefinedPrimal`` entries (unknown inputs).
  Equations depending on them are skipped. Outputs that cannot be computed
  are returned as ``UndefinedPrimal`` placeholders.
  """
  primal_env = {}

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  def is_known(v):
    return type(v) is Literal or v in primal_env

  write_primal(core.unitvar, core.unit)
  assert not jaxpr.constvars
  map(write_primal, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    if eqn.primitive.call_primitive:
      call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
      if any(is_known(v) for v in eqn.invars):
        ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
                                     map(read_primal, eqn.invars), params)
        map(write_primal, eqn.outvars, ans)
    elif all(is_known(v) for v in eqn.invars):
      ans = eqn.primitive.bind(*map(read_primal, eqn.invars), **eqn.params)
      if eqn.primitive.multiple_results:
        map(write_primal, eqn.outvars, ans)
      else:
        write_primal(eqn.outvars[0], ans)
  return map(read_primal, jaxpr.outvars)
class UndefinedPrimal:
  """Placeholder for an unknown (linear) primal, carrying only its aval."""
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return f'UndefinedPrimal({self.aval})'


def is_undefined_primal(x):
  """Return True iff ``x``'s exact type is ``UndefinedPrimal``."""
  return type(x) is UndefinedPrimal
# UndefinedPrimal flattens to no leaves; its aval travels as the aux data.
register_pytree_node(
    UndefinedPrimal,
    lambda undef: ((), undef.aval),
    lambda aval, _children: UndefinedPrimal(aval))
2018-11-17 18:03:33 -08:00
def get_primitive_transpose(p):
  """Return the transpose rule registered for primitive ``p``.

  Raises:
    NotImplementedError: if ``primitive_transposes`` has no entry for ``p``.
  """
  try:
    rule = primitive_transposes[p]
  except KeyError as err:
    msg = ("Transpose rule (for reverse-mode differentiation) for '{}' "
           "not implemented").format(p)
    raise NotImplementedError(msg) from err
  return rule
2018-11-17 18:03:33 -08:00
class JVPTrace(Trace):
  """Trace whose tracers carry a (primal, tangent) pair for forward-mode AD."""

  def pure(self, val):
    return JVPTracer(self, val, zero)

  def lift(self, val):
    return JVPTracer(self, val, zero)

  def sublift(self, val):
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    """Apply ``primitive``'s registered JVP rule to the tracers' pairs.

    Raises:
      NotImplementedError: if no JVP rule is registered for ``primitive``.
    """
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    try:
      rule = primitive_jvps[primitive]
    except KeyError as err:
      raise NotImplementedError(
          "Forward-mode differentiation rule for '{}' not implemented"
          .format(primitive)) from err
    primal_out, tangent_out = rule(primals_in, tangents_in, **params)
    if not primitive.multiple_results:
      return JVPTracer(self, primal_out, tangent_out)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    """Bind ``call_primitive`` on a JVP-transformed version of ``f``.

    The tangents are flattened with ``tree_flatten`` and their tree is passed
    to ``traceable`` so the callee can rebuild them. Presumably symbolic zeros
    flatten to no leaves (hence the name ``nonzero_tangents``); TODO confirm.
    """
    assert call_primitive.multiple_results
    primals = [t.primal for t in tracers]
    tangents = [t.tangent for t in tracers]
    nonzero_tangents, in_tree_def = tree_flatten(tangents)
    f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
                                    len(primals), in_tree_def)
    name = params.get('name', f.__name__)
    new_params = dict(params, name=wrap_name(name, 'jvp'))
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents,
                                 **new_params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Return the call's outputs as ``primals + tangents`` plus a re-wrapping thunk."""
    primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
    out = primals + tangents
    del primals, tangents
    master = self.master

    def todo(x):
      half = len(x) // 2
      head, tail = x[:half], x[half:]
      trace = JVPTrace(master, core.cur_sublevel())
      return map(partial(JVPTracer, trace), head, tail)

    return out, todo

  def process_custom_jvp_call(self, _, __, f_jvp, tracers):
    """Run the user JVP ``f_jvp`` on lowered primals and instantiated tangents."""
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    primals_in = map(core.full_lower, primals_in)
    tangents_in = map(instantiate_zeros, primals_in, tangents_in)
    outs = f_jvp.call_wrapped(*primals_in, *tangents_in)
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
    """Run ``fwd`` on the primals; emit ``custom_lin_p`` for the tangents."""
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    tangents_in = map(instantiate_zeros, primals_in, tangents_in)
    res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
    _, res_tree = out_trees()
    num_res = res_tree.num_leaves
    res, primals_out = split_list(res_and_primals_out, [num_res])
    avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=num_res, bwd=bwd, avals_out=avals_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def join(self, xt, yt):
    """Make two tangents agree on symbolic-zero-ness by instantiating one."""
    xz, yz = xt is zero, yt is zero
    if xz == yz:
      return xt, yt
    # Exactly one side is the symbolic zero here.
    if yz:
      return xt, zeros_like_jaxval(xt)
    return zeros_like_jaxval(yt), yt
2018-11-17 18:03:33 -08:00
class JVPTracer(Tracer):
  """Tracer holding a primal value and its tangent (``zero`` when symbolic)."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    if not core.skip_checks:
      _primal_tangent_shapes_match(primal, tangent)
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def full_lower(self):
    # With a symbolic-zero tangent the tracer carries no derivative
    # information, so lower to the primal.
    if self.tangent is not zero:
      return self
    return core.full_lower(self.primal)
2019-05-10 15:52:12 -07:00
def _primal_tangent_shapes_match(primal, tangent):
  """Debug check that a non-``zero`` tangent has the primal's shaped aval.

  Raises:
    AssertionError: if the shaped avals differ. The message shows both, so a
      mismatch can be diagnosed. Like any assert, it is skipped under ``-O``.
  """
  if tangent is zero:
    return
  primal_aval = raise_to_shaped(get_aval(primal))
  tangent_aval = raise_to_shaped(get_aval(tangent))
  assert primal_aval == tangent_aval, (
      "primal and tangent avals differ: {} vs {}".format(primal_aval,
                                                          tangent_aval))
2018-11-17 18:03:33 -08:00
# -------------------- Primitives --------------------
# Forward-mode rules, keyed by primitive; consulted by JVPTrace.process_primitive.
primitive_jvps: Dict[core.Primitive, Callable] = {}

# Transpose rules, keyed by primitive; looked up via get_primitive_transpose.
primitive_transposes: Dict[core.Primitive, Callable] = {}
2018-11-17 18:03:33 -08:00
def deflinear(primitive, transpose_rule):
  """Register JVP and transpose rules for a primitive linear in all inputs.

  The transpose calls ``transpose_rule(cotangent, **params)``; primal
  arguments are not passed to it.
  """
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)


def linear_jvp(primitive, primals, tangents, **params):
  """JVP of a linear primitive: apply it to the (instantiated) tangents."""
  val_out = primitive.bind(*primals, **params)
  if all(t is zero for t in tangents):
    return val_out, zero
  dense_tangents = map(instantiate_zeros, primals, tangents)
  return val_out, primitive.bind(*dense_tangents, **params)


def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Call ``transpose_rule`` unless the cotangent is the symbolic zero."""
  if cotangent is zero:
    return zero
  return transpose_rule(cotangent, **kwargs)
def deflinear2(primitive, transpose_rule):
  """Like ``deflinear``, but the transpose rule also receives the primal args."""
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)


def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  """Call ``transpose_rule(cotangent, *args, **kwargs)`` unless cotangent is zero."""
  if cotangent is zero:
    return zero
  return transpose_rule(cotangent, *args, **kwargs)
2018-11-17 18:03:33 -08:00
def defjvp(primitive, *jvprules):
  """Register per-argument JVP rules ``rule(t, *primals, **params)``.

  A ``None`` rule marks an argument the output does not depend on
  differentiably.
  """
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(standard_jvp, jvprules, primitive)


def standard_jvp(jvprules, primitive, primals, tangents, **params):
  """Sum the contributions of each non-None rule with a nonzero tangent."""
  val_out = primitive.bind(*primals, **params)
  # Evaluate every rule first, then add; keeps the order of traced ops.
  contributions = [rule(t, *primals, **params)
                   for rule, t in zip(jvprules, tangents)
                   if not (rule is None or t is zero)]
  return val_out, functools.reduce(add_tangents, contributions, zero)
2018-11-17 18:03:33 -08:00
def defjvp2(primitive, *jvprules):
  """Like ``defjvp``, but rules also receive the primal output: ``rule(t, ans, *primals)``."""
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)


def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """Sum rule contributions lazily, interleaving rule calls with additions."""
  val_out = primitive.bind(*primals, **params)
  contributions = (rule(t, val_out, *primals, **params)
                   for rule, t in zip(jvprules, tangents)
                   if not (rule is None or t is zero))
  return val_out, functools.reduce(add_tangents, contributions, zero)
2018-11-17 18:03:33 -08:00
def add_tangents(x, y):
  """Add two tangents, treating the symbolic ``zero`` as the additive identity."""
  if x is zero:
    return y
  if y is zero:
    return x
  return add_jaxvals(x, y)
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
  """Register JVP and transpose rules for a primitive bilinear in (x, y).

  ``bcast(g, other)`` adapts an incoming tangent ``g`` against the other
  operand before re-binding ``prim``.
  """
  assert isinstance(prim, Primitive)

  def lhs_jvp(g, x, y, **kwargs):
    return prim.bind(bcast(g, y), y, **kwargs)

  def rhs_jvp(g, x, y, **kwargs):
    return prim.bind(x, bcast(g, x), **kwargs)

  defjvp(prim, lhs_jvp, rhs_jvp)
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)


# Bilinear registration with no broadcasting (the tangent is used as-is).
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Transpose a bilinear op with respect to whichever operand is undefined.

  Exactly one of ``x``/``y`` must be an ``UndefinedPrimal``. The cotangent
  goes to that position and ``None`` goes to the other.
  """
  assert is_undefined_primal(x) ^ is_undefined_primal(y)
  x_is_linear = is_undefined_primal(x)
  if cotangent is zero:
    out = zero
  elif x_is_linear:
    out = lhs_rule(cotangent, y, **kwargs)
  else:
    out = rhs_rule(cotangent, x, **kwargs)
  return (out, None) if x_is_linear else (None, out)
def defjvp_zero(primitive):
  """Register a JVP that always yields the symbolic ``zero`` tangent."""
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(zero_jvp, primitive)


def zero_jvp(primitive, primals, tangents, **params):
  """Bind ``primitive`` on the primals; the output tangent is ``zero``."""
  val_out = primitive.bind(*primals, **params)
  return val_out, zero
# zeros_like has no dependence on its input: its transpose is zero.
deflinear(zeros_like_p, lambda t: [zero])
# identity passes the cotangent straight through.
deflinear(core.identity_p, lambda t: (t,))
# x + y: the cotangent fans out unchanged to both addends.
deflinear(add_jaxvals_p, lambda t: (t, t))
def instantiate_zeros(example, tangent):
  """Replace a symbolic ``zero`` with concrete zeros shaped like ``example``."""
  return zeros_like_jaxval(example) if tangent is zero else tangent


def instantiate_zeros_aval(aval, tangent):
  """Replace a symbolic ``zero`` with concrete zeros of abstract value ``aval``."""
  return zeros_like_aval(aval) if tangent is zero else tangent
@lu.transformation_with_aux
def traceable(num_primals, in_tree_def, *primals_and_tangents):
  """Adapt a flat ``(*primals, *flat_tangents)`` call to ``(primals, tangents)``.

  The tangents are rebuilt from ``in_tree_def``. The output pair is flattened,
  and its treedef is reported as auxiliary data.
  """
  new_primals = primals_and_tangents[:num_primals]
  new_tangents = tree_unflatten(in_tree_def, primals_and_tangents[num_primals:])
  primal_out, tangent_out = yield (new_primals, new_tangents), {}
  yield tree_flatten((primal_out, tangent_out))
2018-11-17 18:03:33 -08:00
def call_transpose(primitive, params, call_jaxpr, args, ct):
  """Transpose a call primitive by binding it on ``backward_pass`` of its jaxpr."""
  flat_args, in_tree = tree_flatten(((), args, ct))  # empty consts
  transposed = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  transposed, out_tree = flatten_fun_nokwargs(transposed, in_tree)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
  out_flat = primitive.bind(transposed, *flat_args, **new_params)
  return tree_unflatten(out_tree(), out_flat)


primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
2019-02-23 20:34:14 -08:00
def map_transpose(primitive, params, call_jaxpr, args, ct):
  """Transpose a map primitive whose params include ``mapped_invars``.

  The primitive is bound on ``backward_pass`` of ``call_jaxpr``. Cotangents of
  unmapped inputs (fanned out across the mapped axis) are summed over axis 0.

  Returns:
    A list with one cotangent per input.
  """
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  params = dict(params, name=wrap_name(params['name'], 'transpose'))
  out_flat = primitive.bind(fun, *all_args, **params)
  arg_cts = tree_unflatten(out_tree(), out_flat)
  mapped_invars = params['mapped_invars']  # True for each mapped invar
  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(mapped_invars) == len(arg_cts)
  # Build the result eagerly (not as a generator) so callers can index it,
  # take len(), or iterate it more than once.
  return [arg_ct if arg_mapped or arg_ct is zero else arg_ct.sum(0)
          for arg_ct, arg_mapped in zip(arg_cts, mapped_invars)]
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Build a TypedJaxpr computing the JVP of ``jaxpr``.

  The result takes all primals followed by tangents for the inputs flagged in
  ``nonzeros``. Also returns which output tangents are nonzero.
  """
  assert len(jaxpr.in_avals) == len(nonzeros)
  fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jvp, out_nonzeros = f_jvp_traceable(jvp(fun, instantiate=instantiate),
                                        nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = [*jaxpr.in_avals, *tangent_avals]
  in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
  jaxpr_out, pvals_out, literals_out = pe.trace_to_jaxpr(
      f_jvp, in_pvals, instantiate=True)
  avals_out, _ = unzip2(pvals_out)
  typed_out = core.TypedJaxpr(jaxpr_out, literals_out, avals_in, avals_out)
  return typed_out, out_nonzeros()
@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  """Expand compact nonzero tangents into a full tangent list, and compact the outputs.

  Symbolic ``zero`` fills the slots where ``nonzeros`` is False. Output
  tangents that are ``zero`` are dropped, and the kept mask is reported as
  auxiliary data.
  """
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  supplied = iter(primals_and_nztangents[num_primals:])
  tangents = [next(supplied) if nz else zero for nz in nonzeros]
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [t is not zero for t in tangents_out]
  kept_tangents = [t for t in tangents_out if t is not zero]
  yield [*primals_out, *kept_tangents], out_nonzeros
def rearrange_binders(jaxpr: core.TypedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Reorder binders from [all primals, all tangents] into interleaved groups.

  The count arguments give group sizes, and ``_perm`` produces
  ``[p_0, t_0, p_1, t_1, ...]`` for inputs and outputs alike.
  """
  def perm_in(seq):
    return _perm(primals_in, tangents_in, seq)

  def perm_out(seq):
    return _perm(primals_out, tangents_out, seq)

  inner = jaxpr.jaxpr
  new_jaxpr = core.Jaxpr(inner.constvars, perm_in(inner.invars),
                         perm_out(inner.outvars), inner.eqns)
  return core.TypedJaxpr(new_jaxpr, jaxpr.literals, perm_in(jaxpr.in_avals),
                         perm_out(jaxpr.out_avals))
def _perm(primal_counts, tangent_counts, lst):
n = sum(primal_counts)
primals, tangents = lst[:n], lst[n:]
primal_groups = split_list(primals, primal_counts[:-1])
tangent_groups = split_list(tangents, tangent_counts[:-1])
return _interleave(primal_groups, tangent_groups)
def _interleave(xs, ys):
assert len(xs) == len(ys)
return [e for pair in zip(xs, ys) for l in pair for e in l]
# Linear primitive standing for the tangent map of a custom_vjp function.
# Its output avals are passed explicitly via the ``avals_out`` parameter.
custom_lin_p = core.Primitive('custom_lin')


def _custom_lin_abstract_eval(*_, avals_out, **__):
  return avals_out


custom_lin_p.def_abstract_eval(_custom_lin_abstract_eval)
custom_lin_p.multiple_results = True
def _raise_custom_vjp_error_on_jvp(*_, **__):
raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
"function.")
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)


def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
  """Transpose ``custom_lin_p`` by calling the user ``bwd`` on residuals and cotangents.

  The first ``num_res`` operands are residuals and receive ``None``. The
  remaining operands receive the flattened cotangents returned by ``bwd``.
  """
  residuals = invals[:num_res]
  dense_cts = map(instantiate_zeros_aval, avals_out, cts_out)
  cts_in = bwd.call_wrapped(*residuals, *dense_cts)
  cts_in_flat, _ = tree_flatten(cts_in)  # already checked tree structure
  return [None] * num_res + cts_in_flat


primitive_transposes[custom_lin_p] = _custom_lin_transpose
# TODO(mattjj): delete everything below here (deprecated custom_transforms)
def defvjp_all(prim, custom_vjp):
  """Give ``prim`` a JVP rule derived from a user VJP (deprecated custom_transforms).

  ``custom_vjp(*primals, **params)`` must return ``(primals_out, vjp_py)``.
  The JVP binds an auxiliary ``<name>_jvp`` primitive. Its partial-eval rule
  stages ``vjp_py`` out to a jaxpr and emits a ``<name>_lin`` primitive, whose
  transpose evaluates that jaxpr.
  See https://github.com/google/jax/pull/636.
  """
  name = prim.name

  def fun_jvp(xs, ts, **params):
    ts = map(instantiate_zeros, xs, ts)
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    primals, tangents = split_list(primals_and_tangents,
                                   [len(primals_and_tangents) // 2])
    if prim.multiple_results:
      return primals, tangents
    primal, = primals
    tangent, = tangents
    return primal, tangent
  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  fun_jvp_p.multiple_results = True

  def fun_jvp_partial_eval(trace, *tracers, **params):
    primals, tangents = split_list(tracers, [len(tracers) // 2])
    primals_out, vjp_py = custom_vjp(*primals, **params)
    # Normalize to a list. A single output is wrapped. A tuple returned for a
    # multiple-results primitive would otherwise fail to concatenate with the
    # tangent outputs below.
    primals_out = list(primals_out) if prim.multiple_results else [primals_out]
    out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
    with core.initial_style_staging():
      jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
                                        instantiate=True)
    tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
                                  num_res=len(res), out_avals=out_avals)
    return primals_out + list(tangents_out)
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.multiple_results = True
  fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])

  def fun_lin_transpose(cts, *args, **kwargs):
    # Operands are (residuals, tangents); residuals receive None cotangents.
    num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
    res, _ = split_list(args, [num_res])
    cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
    outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
    return [None] * num_res + outs
  primitive_transposes[fun_lin_p] = fun_lin_transpose
def _make_vjpmaker(prim, vjps, with_ans):
  """Build a ``custom_vjp`` for ``defvjp_all`` from per-argument VJP rules.

  Each rule is called as ``rule(ct, *primals)``, or as
  ``rule(ct, ans, *primals)`` when ``with_ans`` is true. A falsy rule yields
  zeros shaped like its primal.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)
    extra = (ans,) if with_ans else ()

    def vjpfun(ct):
      return [rule(ct, *extra, *primals) if rule else zeros_like_jaxval(x)
              for x, rule in zip(primals, vjps)]
    return ans, vjpfun
  return vjpmaker


def defvjp(prim, *vjps):
  """Register per-argument VJP rules ``rule(ct, *primals)`` for ``prim``."""
  defvjp_all(prim, _make_vjpmaker(prim, vjps, with_ans=False))


def defvjp2(prim, *vjps):
  """Register per-argument VJP rules ``rule(ct, ans, *primals)`` for ``prim``."""
  defvjp_all(prim, _make_vjpmaker(prim, vjps, with_ans=True))