Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
Matthew Johnson 2019-05-07 08:52:08 -07:00
parent 0988f6d8d5
commit 4c2ec3e442
5 changed files with 89 additions and 42 deletions

View File

@ -57,7 +57,7 @@ def zeros_like_aval(aval):
aval_zeros_likers = {}
def zeros_like_abstract_tuple(tup):
return AbstractTuple(map(zeros_like_aval, tup))
return JaxTuple(map(zeros_like_aval, tup))
aval_zeros_likers[AbstractTuple] = zeros_like_abstract_tuple

View File

@ -67,11 +67,12 @@ class TypedJaxpr(namedtuple('TypedJaxpr', ['jaxpr', 'literals', 'in_avals', 'out
@curry
def jaxpr_as_fun(typed_jaxpr, *args):
from jax.lax import _abstractify
for arg, in_aval in zip(args, typed_jaxpr.in_avals):
from jax.lax import _abstractify # TODO
invars = typed_jaxpr.jaxpr.invars
for arg, in_aval, varname in zip(args, typed_jaxpr.in_avals, invars):
arg_aval, _ = _abstractify(arg)
if arg_aval != in_aval:
raise TypeError("input type mismatch")
raise TypeError("input type mismatch for arg {}".format(varname))
out = eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, (), *args)
out_aval, _ = _abstractify(out)
if out_aval != typed_jaxpr.out_aval:
@ -157,7 +158,10 @@ def eval_jaxpr(jaxpr, consts, freevar_vals, *args):
# TODO enforce a specific set of types for jaxpr vars
def pat_fmap(f, v, *xs):
if type(v) in (tuple, list):
return tuple(map(partial(pat_fmap, f), v, *xs))
if len(xs) == 1 and xs[0] is None:
return tuple(map(partial(pat_fmap, f), v, [None] * len(v)))
else:
return tuple(map(partial(pat_fmap, f), v, *xs))
else:
return f(v, *xs)

View File

@ -186,9 +186,17 @@ def _scan_initial_impl(consts, init, xs, forward, length, jaxpr):
return (carry_out, ys_out)
ys_init = empty_arrays(ys_aval)
carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
# carry, ys = lax.fori_loop(0, length, body_fun, (init, ys_init))
carry, ys = fori_loop(0, length, body_fun, (init, ys_init))
return core.pack((carry, ys))
# TODO remove
def fori_loop(start, stop, body_fun, init_val):
carry = init_val
for i in range(start, stop):
carry = body_fun(i, carry)
return carry
def _scan_initial_jvp(primals, tangents, forward, length, jaxpr):
consts, init, xs = primals
@ -322,22 +330,31 @@ def _scan_initial_transpose(ct, consts, init, xs, forward, length, jaxpr):
# jaxpr :: d -> c -> (a, res) -> (c, b) # TODO assuming restructuring input
# jaxpr_lifted :: res -> (d, c, a) -> (c, b)
# jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
jaxpr_lifted = rearrange_binders(
lambda d, c, a_res: (a_res[1], (d, c, a_res[0])), jaxpr)
jaxpr_lifted_trans, out_tree = transpose_jaxpr2(jaxpr_lifted)
jaxpr_lifted_trans = transpose_jaxpr2(jaxpr_lifted)
jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)
import ipdb; ipdb.set_trace()
ct_c, ct_bs = ct
carry_ct = core.pack((ct_c, ad_util.zeros_like_aval(jaxpr.in_avals[0])))
c_aval, b_aval = jaxpr.out_aval
d_aval, c_aval2, _ = jaxpr.in_avals
assert c_aval == c_aval2
bs_aval = promote_aval_rank(length, b_aval)
ct_d = ad_util.zeros_like_aval(d_aval)
ct_c, ct_bs = ad.instantiate_zeros_aval(core.AbstractTuple((c_aval, bs_aval)), ct)
carry_ct = core.pack((ct_c, ct_d))
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
core.check_jaxpr(jaxpr_trans.jaxpr)
unit_aval, (ct_c_aval, ct_d_aval), (ct_b_aval, _) = jaxpr_trans.in_avals
assert core.lattice_join(ct_c_aval, core.get_aval(ct_c)) == ct_c_aval
assert core.lattice_join(ct_d_aval, core.get_aval(ct_d)) == ct_d_aval
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
# scan_p.bind :: (d -> c -> a -> (c, b)) -> d -> c -> [a] -> (c, [b])
out = scan_initial_p.bind(
core.unit, carry_ct, core.pack((ct_bs, res)),
forward=not forward, length=length, jaxpr=jaxpr_trans)
import ipdb; ipdb.set_trace()
(ct_init, ct_consts), ct_as = out
return ct_consts, ct_init, (ct_as, None)
def rearrange_binders(f, typed_jaxpr):
jaxpr = typed_jaxpr.jaxpr.copy()
@ -351,21 +368,24 @@ _scan_newvar = pe.gensym('_scan')
def _move_stuff_and_add_add(typed_jaxpr):
# jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)
res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals
CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval
assert CTc_aval == CTc_aval2
in_avals = (core.AbstractTuple(()), core.AbstractTuple((CTc_aval, CTd_aval)),
core.AbstractTuple((CTb_aval, res_aval)))
out_aval = core.AbstractTuple((core.AbstractTuple((CTd_aval, CTc_aval)),
out_aval = core.AbstractTuple((core.AbstractTuple((CTc_aval, CTd_aval)),
CTa_aval))
jaxpr = typed_jaxpr.jaxpr.copy()
# TODO assume not restructuring input
assert not any(type(invar) is tuple for invar in jaxpr.invars)
# munge input side
CTc_in = _scan_newvar()
CTb_in = _scan_newvar()
CTd_in = _scan_newvar()
res_in, CTc_CTb_in = jaxpr.invars
jaxpr.invars = ((), (CTc_in, CTd_in), (CTb_in, res_in))
jaxpr.eqns = (
@ -373,7 +393,6 @@ def _move_stuff_and_add_add(typed_jaxpr):
jaxpr.eqns)
# munge output side
CTd_in = _scan_newvar()
CTd_new = _scan_newvar()
CTd_sum = _scan_newvar()
CTc = _scan_newvar()
@ -384,33 +403,32 @@ def _move_stuff_and_add_add(typed_jaxpr):
jaxpr.eqns +
[_unpack_eqn(jaxpr.outvar, [CTd_new, CTc, CTa]),
_add_any_eqn(CTd_sum, CTd_new, CTd_in),
_pack_eqn([CTd_sum, CTc], partial_out),
_pack_eqn([CTc, CTd_sum], partial_out),
_pack_eqn([partial_out, CTa], outvar)])
jaxpr.outvar = outvar
# TODO should really have a check_typed_jaxpr
core.skip_checks or core.check_jaxpr(jaxpr)
return core.TypedJaxpr(jaxpr, typed_jaxpr.literals,
in_avals, out_aval)
def _add_any_eqn(tot, a, b):
return core.JaxprEqn([a, b], [tot], ad_util.add_jaxvals_p, (), False, False, {})
# transpose_jaxpr :: (res -> a -> b) -> (res -> CT b -> CT a)
def transpose_jaxpr2(jaxpr):
assert len(jaxpr.in_avals) == 2
nones = core.pat_fmap(lambda _: None, jaxpr.jaxpr.invars[1])
@lu.wrap_init
def transposed(res, b_bar):
_, a_bar = ad.backward_pass(jaxpr.jaxpr, jaxpr.literals, (),
(res, nones), b_bar)
_, (_, a_bar) = ad.backward_pass(jaxpr.jaxpr, jaxpr.literals, (),
(res, None), b_bar)
a_bar = ad.instantiate_zeros_aval(jaxpr.in_avals[1], a_bar)
return a_bar
@lu.transformation_with_aux
def flatten_out(*args):
ans = yield args
yield pytree_to_jaxtupletree(ans)
transposed, out_tree = flatten_out(transposed)
transposed_jaxpr = make_typed_jaxpr(transposed, (jaxpr.in_avals[0], jaxpr.out_aval))
return transposed_jaxpr, out_tree()
return transposed_jaxpr
def make_typed_jaxpr(traceable, in_avals):
pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]

View File

@ -21,7 +21,7 @@ import itertools as it
from . import partial_eval as pe
from .. import core as core
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
zeros_like_p, zero, Zero)
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
from ..tree_util import process_pytree, build_tree, register_pytree_node, tree_map
@ -185,12 +185,29 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
cts_out = [zero for _ in eqn.invars]
map(write_cotangent, eqn.invars, cts_out)
cotangents_out = core.pat_fmap(
lambda var, argval: read_cotangent(var) if argval is None else None,
jaxpr.invars, args)
freevar_cts = core.pat_fmap(read_cotangent, jaxpr.freevars)
cotangents_out = core.pat_fmap(lambda v, _: read_cotangent(v), jaxpr.invars, None)
cotangents_out = tuple(map(pack_cotangents_like_caller, args, cotangents_out))
return freevar_cts, cotangents_out
def pack_cotangents_like_caller(arg, ct):
if type(arg) is tuple:
return tuple(map(pack_cotangents_like_caller, arg, ct))
elif arg is None:
return recursively_pack(ct)
else:
return None
def recursively_pack(ct):
if type(ct) is tuple:
ct = tuple(map(recursively_pack, ct))
if any(elt is zero or isinstance(elt, TangentTuple) for elt in ct):
return TangentTuple(ct)
else:
return pack(ct)
else:
return ct
def get_primitive_transpose(p):
try:
return primitive_transposes[p]
@ -409,6 +426,14 @@ def instantiate_zeros(example, tangent):
else:
return tangent
def instantiate_zeros_aval(aval, tangent):
if tangent is zero:
return zeros_like_aval(aval)
elif isinstance(tangent, TangentTuple):
return pack(map(instantiate_zeros_aval, aval, tangent))
else:
return tangent
@transformation_with_aux
def traceable(in_tree_def, new_primals, new_tangents):
new_tangents = build_tree(in_tree_def, new_tangents)

View File

@ -32,19 +32,19 @@ def f(c, a):
as_ = np.ones((5, 3))
c = np.ones(4)
# print scan_reference(f, c, as_)
# print scan_initial(f, c, as_)
# print
print scan_reference(f, c, as_)
print scan_initial(f, c, as_)
print
# print jvp(lambda c, as_: scan_reference(f, c, as_), (c, as_), (c, as_))[1]
# print jvp(lambda c, as_: scan_initial(f, c, as_), (c, as_), (c, as_))[1]
# print
print jvp(lambda c, as_: scan_reference(f, c, as_), (c, as_), (c, as_))[1]
print jvp(lambda c, as_: scan_initial(f, c, as_), (c, as_), (c, as_))[1]
print
# print linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
# print linearize(lambda c, as_: scan_initial(f, c, as_), c, as_)[1](c, as_)
# print
print linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_)
print linearize(lambda c, as_: scan_initial(f, c, as_), c, as_)[1](c, as_)
print
# print grad(lambda c, as_: scan_reference(f, c, as_)[0].sum())(c, as_)
print grad(lambda c, as_: list(scan_reference(f, c, as_))[0].sum())(c, as_)
print grad(lambda c, as_: list(scan_initial(f, c, as_))[0].sum())(c, as_)
print