add jaxpr eqn structured input, transpose progress

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
Matthew Johnson 2019-04-25 10:43:50 -07:00
parent 1c9035efca
commit a17f8e4ca8
9 changed files with 130 additions and 79 deletions

View File

@ -176,3 +176,10 @@ array_types = [onp.ndarray, onp.float64, onp.float32, onp.float16,
for t in array_types:
core.pytype_aval_mappings[t] = ConcreteArray
ad_util.jaxval_zeros_likers[t] = zeros_like_array
def zeros_like_shaped_array(aval):
  """Build a zero-filled numpy array described by a ShapedArray.

  Args:
    aval: abstract value to mirror; asserted to be a ShapedArray instance.

  Returns:
    The result of ``onp.zeros`` called with ``aval.shape`` and
    ``dtype=aval.dtype``.
  """
  assert isinstance(aval, ShapedArray)
  shape, dtype = aval.shape, aval.dtype
  return onp.zeros(shape, dtype=dtype)
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array

View File

@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .core import JaxTuple, lattice_join, Primitive
from .core import JaxTuple, lattice_join, Primitive, AbstractTuple
from .tree_util import register_pytree_node
from .util import safe_map
@ -52,6 +52,15 @@ jaxval_zeros_likers = {}
jaxval_zeros_likers[JaxTuple] = zeros_like_impl_jaxtuple
# Registry keyed by abstract-value class; each entry is a one-argument callable
# that receives the abstract value and returns the corresponding zeros.
aval_zeros_likers = {}

def zeros_like_aval(aval):
  """Dispatch on the exact type of `aval` to build its zeros.

  The lookup uses ``type(aval)`` directly, so subclasses of a registered class
  are not matched unless they are registered themselves.

  Args:
    aval: an abstract value whose class has an entry in `aval_zeros_likers`.

  Returns:
    Whatever the registered handler returns for `aval`.

  Raises:
    KeyError: if ``type(aval)`` has no registered handler.
  """
  make_zeros = aval_zeros_likers[type(aval)]
  return make_zeros(aval)
def zeros_like_abstract_tuple(tup):
  """Build zeros for every element of an abstract tuple.

  Args:
    tup: iterable of abstract values, each handled by `zeros_like_aval`.

  Returns:
    An AbstractTuple constructed from the per-element results.
  """
  # NOTE(review): the per-element results are whatever zeros_like_aval returns
  # (the registered ShapedArray handler returns a concrete numpy array), yet
  # they are wrapped in AbstractTuple -- confirm this container is intended.
  element_zeros = map(zeros_like_aval, tup)
  return AbstractTuple(element_zeros)
aval_zeros_likers[AbstractTuple] = zeros_like_abstract_tuple
def zeros_like_jaxval(val):
  """Bind the `zeros_like_p` primitive to `val` and return the result.

  Args:
    val: the value handed to ``zeros_like_p.bind``.

  Returns:
    The value produced by ``zeros_like_p.bind(val)``.
  """
  zeros = zeros_like_p.bind(val)
  return zeros

View File

@ -80,7 +80,8 @@ def jaxpr_as_fun(typed_jaxpr, *args):
JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive',
'bound_subjaxprs', 'destructure', 'params'])
'bound_subjaxprs', 'restructure',
'destructure', 'params'])
class Primitive(object):
def __init__(self, name):
@ -137,7 +138,11 @@ def eval_jaxpr(jaxpr, consts, freevar_vals, *args):
map(write, jaxpr.invars, args)
map(write, jaxpr.freevars, freevar_vals)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.invars)
if not eqn.restructure:
in_vals = map(read, eqn.invars)
else:
in_vals = [pack(map(read, invars)) if type(invars) is tuple else read(invars)
for invars in eqn.invars]
subfuns = [partial(eval_jaxpr, subjaxpr, map(read, const_bindings),
map(read, freevar_bindings))
for subjaxpr, const_bindings, freevar_bindings
@ -615,7 +620,11 @@ def check_jaxpr(jaxpr):
map(write, jaxpr.freevars)
map(write, jaxpr.invars)
for eqn in jaxpr.eqns:
map(read, eqn.invars)
if not eqn.restructure:
map(read, eqn.invars)
else:
[map(read, invar) if type(invar) is tuple else read(invar)
for invar in eqn.invars]
for subjaxpr, constvars, freevars in eqn.bound_subjaxprs:
map(read, freevars)
map(read_const, constvars)

View File

@ -11,6 +11,7 @@ from jax.lax import _abstractify, _unpack_eqn
from jax.abstract_arrays import ShapedArray
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad
from jax import ad_util
def pvals_with_zeros(zero_components, aval):
@ -98,7 +99,7 @@ def _call_initial_partial_eval(trace, *tracers, **kwargs):
*in_consts, jaxpr=jaxpr_1, consts=consts_1)
residual_tracers = core.pack(map(trace.new_instantiated_const, residuals))
eqn = core.JaxprEqn((residual_tracers,) + tracers, None, call_initial_p, (),
False, dict(jaxpr=jaxpr_2, consts=consts_2))
False, False, dict(jaxpr=jaxpr_2, consts=consts_2))
return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
@ -150,7 +151,7 @@ def update_arrays(i, aval, xs, x):
else:
return lax.dynamic_update_index_in_dim(xs, x[None, ...], i, axis=0)
_scan_const = pe.gensym('_consts')
_scan_newvar = pe.gensym('_scan')
# scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
def scan_initial(f, init, xs):
@ -162,26 +163,27 @@ def scan_initial(f, init, xs):
lu.wrap_init(f), (carry_pval, x_pval), instantiate=True)
(carry_aval_out, y_aval), _ = pval_out
assert carry_aval == carry_aval_out
lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr, _scan_const)
lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr, _scan_newvar)
consts_aval, _ = _abstractify(core.pack(consts))
in_avals = (consts_aval, carry_aval, x_aval)
out_aval = core.AbstractTuple((carry_aval, y_aval))
jaxpr = core.TypedJaxpr(lifted_jaxpr, (), in_avals, out_aval)
length = leading_dim_size(xs)
return scan_initial_p.bind(core.pack(consts), init, xs,
length=length, jaxpr=jaxpr)
forward=True, length=length, jaxpr=jaxpr)
def _scan_initial_impl(consts, init, xs, length, jaxpr):
def _scan_initial_impl(consts, init, xs, forward, length, jaxpr):
_, _, x_aval = jaxpr.in_avals
_, y_aval = jaxpr.out_aval
ys_aval = promote_aval_rank(length, y_aval)
def body_fun(i, vals):
idx = i if forward else length - i - 1
carry, ys = vals
x = index_arrays(i, x_aval, xs)
x = index_arrays(idx, x_aval, xs)
carry_out, y = core.jaxpr_as_fun(jaxpr)(consts, carry, x)
ys_out = update_arrays(i, y_aval, ys, y)
ys_out = update_arrays(idx, y_aval, ys, y)
return (carry_out, ys_out)
ys_init = empty_arrays(ys_aval)
@ -189,7 +191,7 @@ def _scan_initial_impl(consts, init, xs, length, jaxpr):
return core.pack((carry, ys))
def _scan_initial_jvp(primals, tangents, length, jaxpr):
def _scan_initial_jvp(primals, tangents, forward, length, jaxpr):
consts, init, xs = primals
consts_dot, init_dot, xs_dot = tangents
consts_aval, carry_aval, x_aval = jaxpr.in_avals
@ -220,7 +222,8 @@ def _scan_initial_jvp(primals, tangents, length, jaxpr):
xs_dual = core.pack((xs, nonzero_xs_dot))
carry_out_dual, ys_dual = scan_initial_p.bind(
consts_dual, init_dual, xs_dual, length=length, jaxpr=jaxpr_jvp)
consts_dual, init_dual, xs_dual,
forward=forward, length=length, jaxpr=jaxpr_jvp)
ys, ys_dot = ys_dual
ys_dot = ad.put_zeros(ad.TangentTuple, where_ys_zeros, ys_dot)
@ -257,6 +260,8 @@ def binary_lattice_join(a, b):
def _scan_initial_partial_eval(trace, *tracers, **kwargs):
jaxpr = kwargs.pop('jaxpr')
length = kwargs.pop('length')
forward = kwargs.pop('forward')
assert not kwargs
in_pvs, in_consts = unzip2([t.pval for t in tracers])
fc_consts, fc_init, fc_xs = map(is_const, in_pvs)
@ -278,20 +283,13 @@ def _scan_initial_partial_eval(trace, *tracers, **kwargs):
out_pv = _put_known_pvs(fc_out, jaxpr.out_aval)
out_carry, (ys, residuals) = scan_initial_p.bind(
*in_consts, length=length, jaxpr=jaxpr_1)
*in_consts, forward=forward, length=length, jaxpr=jaxpr_1)
out_const = core.pack((out_carry, ys))
residual_tracers = core.pack(map(trace.new_instantiated_const, residuals))
residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
d, c, a = lifted_tracers
new_tracers = (d, c, core.pack((a, residual_tracers))) # TODO nonlin pack
# TODO adapt scan to
# option #1:
# scan :: (d -> c -> a -> b) -> d -> c -> [a] -> [b]
# scan :: (d -> c -> a -> alin -> b) -> d -> c -> [a] -> [alin] -> [b]
# option #2:
# extend jaxpr language to have destructuring tuples of variables in invars
# b = g(a, (x, a))
eqn = core.JaxprEqn(new_tracers, None, scan_initial_p, (), False,
dict(length=length, jaxpr=jaxpr_2))
new_tracers = (d, c, (a, residuals_tracer))
eqn = core.JaxprEqn(new_tracers, None, scan_initial_p, (), True, False,
dict(forward=forward, length=length, jaxpr=jaxpr_2))
return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
def _lift_tracer(trace, tracer, is_const):
@ -316,25 +314,57 @@ def _put_known_pvs(is_known, aval):
return pe.JaxprTracerTuple(map(_put_known_pvs, is_known, aval))
def _scan_initial_transpose(ct, consts, init, xs, length, jaxpr):
def _scan_initial_transpose(ct, consts, init, xs, forward, length, jaxpr):
assert consts is None and init is None
import ipdb; ipdb.set_trace() # TODO but xs is also None!
assert type(xs) is tuple
a, res = xs
assert a is None and res is not None
# jaxpr :: d -> c -> (a, res) -> (c, b)
# jaxpr_lifted :: res -> (d, c, a) -> (c, b)
# jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
jaxpr_lifted = _move_res_and_uncurry(jaxpr)
jaxpr_lifted = _move_res_and_uncurry(jaxpr, _scan_newvar)
import ipdb; ipdb.set_trace()
jaxpr_lifted_trans = transpose_jaxpr2(jaxpr_lifted)
jaxpr_trans = _move_stuff_and_add_add(jaxpr_lifted_trans)
assert False
# c_bar, bs_bar = ct
# d_bar = zeros
ct_c, ct_bs = ct
carry_ct = core.pack((ct_c, ad_util.zeros_like_aval(jaxpr.in_avals[0])))
# jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT d, CT c), CT a)
# scan_p.bind :: (d -> c -> a -> (c, b)) -> d -> c -> [a] -> (c, [b])
scan_initial_p.bind(
core.unit, carry_ct, core.pack((ct_bs, res)),
forward=not forward, length=length, jaxpr=jaxpr_trans)
def _move_res_and_uncurry(jaxpr, newvar):
  """Re-sign a typed jaxpr so `res` comes first and the rest is one tuple.

  Type sketch:
    jaxpr        :: d -> c -> (a, res) -> (c, b)
    jaxpr_lifted :: res -> (d, c, a) -> (c, b)

  Args:
    jaxpr: typed jaxpr with exactly three input avals, the third of which is a
      core.AbstractTuple that unpacks into ``(a_aval, res_aval)``.
    newvar: zero-argument callable; invoked three times to obtain the
      variables used for `a`, `res` and the packed `(d, c, a)` input.

  Returns:
    A core.TypedJaxpr whose input avals are ``[res, (d, c, a)]`` and whose
    literals and output aval are taken unchanged from `jaxpr`.
  """
  typed_in_avals = jaxpr.in_avals
  assert len(typed_in_avals) == 3
  assert type(typed_in_avals[2]) is core.AbstractTuple
  d_aval, c_aval, (a_aval, res_aval) = typed_in_avals
  lifted_in_avals = [res_aval, core.AbstractTuple((d_aval, c_aval, a_aval))]

  inner = jaxpr.jaxpr
  d, c, a_res = inner.invars
  # Tuple elements are evaluated left to right, so the variables are requested
  # in the order a, res, d_c_a.
  a, res, d_c_a = newvar(), newvar(), newvar()

  # Prologue: split the packed (d, c, a) input into the original variables,
  # then rebuild the (a, res) pair that the original equations refer to.
  eqns = [pe._unpack_eqn(d_c_a, [d, c, a]),
          pe._pack_eqn([a, res], a_res)]
  eqns.extend(inner.eqns)

  new_jaxpr = core.Jaxpr(inner.constvars, inner.freevars,
                         [res, d_c_a], inner.outvar, eqns)
  if not core.skip_checks:
    core.check_jaxpr(new_jaxpr)
  return core.TypedJaxpr(new_jaxpr, jaxpr.literals, lifted_in_avals,
                         jaxpr.out_aval)
# return scan_initial_p.bind(core.unit,
# transpose_jaxpr :: (res -> a -> b) -> (res -> CT b -> CT a)
# TODO either top-level restructure, or else munge somehow
def transpose_jaxpr2(jaxpr):
assert len(jaxpr.in_avals) == 2
def transposed(res, b_bar):

View File

@ -145,6 +145,9 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
def read_cotangent(v):
return ct_env.get(v, zero)
def read_primal(v):
return primal_env.get(v)
primal_env = {v: val for v, val in zip(jaxpr.freevars, freevar_vals)
if val is not None}
primal_env.update(zip(jaxpr.constvars, consts))
@ -155,12 +158,16 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangent_in):
for eqn in jaxpr.eqns[::-1]:
cts_in = map(read_cotangent, eqn.outvars)
ct_in = TangentTuple(cts_in) if eqn.destructure else cts_in[0]
invals = map(primal_env.get, eqn.invars)
if not eqn.restructure:
invals = map(read_primal, eqn.invars)
else:
invals = [tuple(map(read_primal, v)) if type(v) is tuple
else read_primal(v) for v in eqn.invars]
if eqn.bound_subjaxprs:
subjaxprs, sub_consts, sub_freevar_vals = unzip3([
(subjaxpr,
map(primal_env.get, const_vars),
map(primal_env.get, bound_vars))
map(read_primal, const_vars),
map(read_primal, bound_vars))
for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs])
cts_out, ct_free_vars_out = get_primitive_transpose(eqn.primitive)(
eqn.params, subjaxprs, sub_consts, sub_freevar_vals, invals, ct_in)

View File

@ -74,11 +74,11 @@ class JaxprTrace(Trace):
avals = [t.aval for t in tracers]
out_aval = primitive.abstract_eval(*avals, **params)
partial_val = PartialVal((out_aval, unit))
eqn = JaxprEqn(tracers, None, primitive, (), False, params)
eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
return JaxprTracer(self, partial_val, eqn)
def pack(self, tracers):
eqn = JaxprEqn(tracers, None, core.pack_p, (), False, {})
eqn = JaxprEqn(tracers, None, core.pack_p, (), False, False, {})
pval = pack_pvals([t.pval for t in tracers])
return JaxprTracer(self, pval, eqn)
@ -92,7 +92,8 @@ class JaxprTrace(Trace):
const_tracers = map(self.new_instantiated_const, consts)
env_tracers = map(self.full_raise, env)
bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
eqn = JaxprEqn(tracers, None, call_primitive, (bound_subjaxpr,), False, params)
eqn = JaxprEqn(tracers, None, call_primitive, (bound_subjaxpr,),
False, False, params)
return JaxprTracer(self, PartialVal((out_pv, out_pv_const)), eqn)
def process_map(self, call_primitive, f, tracers, params):
@ -109,7 +110,8 @@ class JaxprTrace(Trace):
jaxpr_converted.invars = list(it.chain(jaxpr.constvars, jaxpr.invars))
invars = tuple(it.chain(const_tracers, tracers))
bound_subjaxpr = (jaxpr_converted, (), env)
eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,), False, params)
eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,),
False, False, params)
return JaxprTracer(self, PartialVal((out_pv, out_const)), eqn)
def post_process_call(self, call_primitive, out_tracer):
@ -124,7 +126,8 @@ class JaxprTrace(Trace):
const_tracers = map(trace.new_instantiated_const, consts)
env_tracers = map(trace.full_raise, env)
bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,), False, {})
eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,),
False, False, {})
return JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), eqn)
return out, todo
@ -149,7 +152,7 @@ def scan_process_primitive(trace, consts, init, xs, avals, jaxpr):
avals=avals1, jaxpr=jaxpr1)
params_out = {'avals' : avals2, 'jaxpr' : jaxpr2}
eqn = JaxprEqn([consts, init, xs], None, scan_p, (), False, params_out)
eqn = JaxprEqn([consts, init, xs], None, scan_p, (), False, False, params_out)
return JaxprTracer(trace, PartialVal((ans, ans_pv)), )
# in_pvs, in_consts = unzip2([t.pval for t in tracers])
@ -278,7 +281,7 @@ class JaxprTracer(Tracer):
if isinstance(pv, AbstractValue):
const = [unit for _ in range(n)]
key = object()
eqn = JaxprEqn([self], [None]*n, core.identity_p, (), True, {})
eqn = JaxprEqn([self], [None]*n, core.identity_p, (), False, True, {})
def child_tracer(i, pval, c):
d = Destructuring(i, eqn, key)
return JaxprTracer(self.trace, PartialVal((pval, c)), d).full_lower()
@ -416,12 +419,16 @@ ConstVar = namedtuple('ConstVar', ['val'])
LambdaBinding = namedtuple('LambdaBinding', [])
def eqn_tracer_to_var(var, outvars, eqn):
invars, _, primitive, bound_subjaxprs, destructure, params = eqn
invars = map(var, invars)
invars, _, primitive, bound_subjaxprs, restructure, destructure, params = eqn
if not restructure:
invars = map(var, invars)
else:
invars = [tuple(map(var, v)) if type(v) is tuple else var(v)
for v in invars]
new_bound_subjaxprs = [(j, map(var, c), map(var, f))
for j, c, f in bound_subjaxprs]
return JaxprEqn(invars, outvars, primitive,
new_bound_subjaxprs, destructure, params)
new_bound_subjaxprs, restructure, destructure, params)
def tracers_to_jaxpr(in_tracers, out_tracer):
@ -486,36 +493,18 @@ class Var(object):
def eqn_parents(eqn):
subjaxpr_tracers = [it.chain(c, f) for _, c, f in eqn.bound_subjaxprs]
return list(it.chain(eqn.invars, *subjaxpr_tracers))
if not eqn.restructure:
return list(it.chain(eqn.invars, *subjaxpr_tracers))
else:
invars = []
for v in eqn.invars:
if type(v) is tuple:
invars.extend(v)
else:
invars.append(v)
return list(it.chain(invars, *subjaxpr_tracers))
def eval_jaxpr_raw(jaxpr, consts, freevar_vals, *args):
  """Evaluate `jaxpr` by calling each primitive's `impl` rather than `bind`.

  Args:
    jaxpr: the jaxpr to run; its constvars, invars, freevars, eqns and outvar
      are read.
    consts: values for ``jaxpr.constvars``.
    freevar_vals: values for ``jaxpr.freevars``.
    *args: values for ``jaxpr.invars``.

  Returns:
    The value bound to ``jaxpr.outvar`` after all equations have run.
  """
  for group in (consts, freevar_vals, args):
    assert all(map(core.valid_jaxtype, group))

  bindings = {}

  def read(v):
    return bindings[v]

  def write(v, val):
    bindings[v] = val

  # NOTE(review): `map` is used for its side effects and its result is added
  # to a list below, so it is presumably a list-returning variant such as
  # safe_map rather than the lazy builtin -- confirm against the module header.
  write(unitvar, unit)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  map(write, jaxpr.freevars, freevar_vals)
  for eqn in jaxpr.eqns:
    operands = map(read, eqn.invars)
    # Each bound sub-jaxpr becomes a callable with its consts and free
    # variables already supplied.
    subfuns = [partial(core.eval_jaxpr, subjaxpr, map(read, const_bindings),
                       map(read, freevar_bindings))
               for subjaxpr, const_bindings, freevar_bindings
               in eqn.bound_subjaxprs]
    result = eqn.primitive.impl(*(subfuns + operands), **eqn.params)  # not bind!
    if eqn.destructure:
      outvals = list(result)
    else:
      outvals = [result]
    map(write, eqn.outvars, outvals)
  return read(jaxpr.outvar)
def compiled_call_impl(fun, *args, **kwargs):
with new_master(JaxprTrace, True) as master:
pvals = map(abstractify, args)
@ -616,10 +605,10 @@ def _closure_convert_jaxpr(jaxpr, newvar):
return lifted_jaxpr
def _unpack_eqn(invar, outvars):
return core.JaxprEqn([invar], outvars, core.identity_p, (), True, {})
return core.JaxprEqn([invar], outvars, core.identity_p, (), False, True, {})
def _pack_eqn(invars, outvar):
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, {})
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, False, {})
def partial_eval_jaxpr2(jaxpr, first_components):

View File

@ -222,7 +222,7 @@ def replicated_comp(jaxpr, ax_env, const_vals, freevar_shapes, *arg_shapes):
map(write, all_freevars, map(c.ParameterWithShape, freevar_shapes))
map(write, jaxpr.invars, map(c.ParameterWithShape, arg_shapes))
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
in_nodes = map(read, eqn.invars) # TODO
if eqn.primitive in parallel_translation_rules:
name = eqn.params['axis_name']
params = {k: eqn.params[k] for k in eqn.params if k != 'axis_name'}

View File

@ -163,7 +163,7 @@ def jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes):
map(write, all_freevars, map(c.ParameterWithShape, freevar_shapes))
map(write, jaxpr.invars, map(c.ParameterWithShape, arg_shapes))
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
in_nodes = map(read, eqn.invars) # TODO
in_shapes = map(c.GetShape, in_nodes)
subcs = [
jaxpr_computation(

View File

@ -3819,10 +3819,10 @@ xla.translations[while_p] = _while_loop_translation_rule
batching.primitive_batchers[while_p] = _while_loop_batching_rule
def _unpack_eqn(invar, outvars):
return core.JaxprEqn([invar], outvars, core.identity_p, (), True, {})
return core.JaxprEqn([invar], outvars, core.identity_p, (), False, True, {})
def _pack_eqn(invars, outvar):
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, {})
return core.JaxprEqn(invars, [outvar], core.pack_p, (), False, False, {})
def _cond_abstract_eval(pred, true_op, true_consts, false_op, false_consts,