2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
2018-11-21 13:27:26 -08:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
import itertools as it
|
|
|
|
from collections import namedtuple, Counter, defaultdict
|
|
|
|
|
|
|
|
from .. import core
|
|
|
|
from .. import linear_util as lu
|
2019-02-21 11:47:26 -08:00
|
|
|
from ..abstract_arrays import ShapedArray, ConcreteArray
|
2018-11-17 18:03:33 -08:00
|
|
|
from ..linear_util import thunk, transformation, transformation_with_aux
|
|
|
|
from ..util import unzip2, safe_zip, safe_map, toposort, partial
|
|
|
|
from ..core import (Trace, Tracer, new_master, Jaxpr, JaxprEqn, get_aval, pack,
|
|
|
|
AbstractValue, AbstractTuple, unit, unitvar, Primitive,
|
2019-04-18 09:00:11 -07:00
|
|
|
call_p, TypedJaxpr)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Rebind the builtins to the `safe_*` variants from ..util for the rest of
# this module (NOTE(review): per their names these are stricter versions of
# map/zip; confirm exact semantics in util).
zip = safe_zip
map = safe_map
|
2019-02-15 06:35:54 -08:00
|
|
|
def identity(x):
  """Return the argument unchanged."""
  return x
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-05-09 07:23:39 -07:00
|
|
|
# A partial value (pval) is modeled as a pair (pv, const), as per
|
|
|
|
# type PVal = (PV, Const)
|
|
|
|
# data PV = NonePV | AbstractPV AbstractValue | JaxprTracerTuple [PV]
|
|
|
|
# type Const = MaybeTraced JaxType
|
|
|
|
# where the NonePV arm indicates a known (constant) value, the AbstractPV arm
|
|
|
|
# indicates an unknown value, and the JaxprTracerTuple indicates a finer-grained
|
|
|
|
# representation that might be a mixture.
|
|
|
|
# There are two additional invariants:
|
|
|
|
# 1. when the pv is a JaxprTracerTuple, then the const is a JaxTuple of the
|
|
|
|
# same length (or a traced version);
|
|
|
|
# 2. when the pv is an AbstractValue, then the const must be unit.
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTrace(Trace):
  """Partial-evaluation trace.

  Every value in this trace is a JaxprTracer holding a PartialVal (pv, const).
  Known values (pv None) are computed eagerly; unknown values carry a recipe
  (equation, constvar, free variable, lambda binding, ...) from which
  `tracers_to_jaxpr` later builds a jaxpr.
  """

  def pure(self, val):
    return self.new_const(val)

  def lift(self, val):
    return self.new_const(val)

  def sublift(self, val):
    # `val` belongs to an enclosing sublevel of this trace: reference it as a
    # free variable of the jaxpr being built.
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val):
    """Wrap the known value `val` as a tracer with pv None and recipe unit.

    Raises:
      Exception: if `val` is a tracer from this trace's own level.
    """
    if isinstance(val, Tracer) and val.trace.level == self.level:
      raise Exception("new_const: cannot wrap a tracer from the current trace "
                      "level as a known constant")
    return JaxprTracer(self, PartialVal((None, val)), unit)

  def new_instantiated_const(self, val):
    """Stage the known value `val` into the jaxpr as a constvar (now unknown)."""
    return JaxprTracer(self, PartialVal((get_aval(val), unit)), ConstVar(val))

  def new_arg(self, pval):
    """Tracer for a jaxpr input (lambda binding) with partial value `pval`."""
    return JaxprTracer(self, pval, LambdaBinding())

  def instantiate_const(self, tracer):
    """Return a tracer equivalent to `tracer` whose value is fully unknown.

    Known parts are staged in as constvars; tuples are handled elementwise.

    Raises:
      TypeError: if the tracer's pv is not a valid pv type.
    """
    pv, const = tracer.pval
    if isinstance(pv, AbstractValue):
      return tracer
    elif isinstance(pv, JaxprTracerTuple):
      return pack(map(lambda t: self.instantiate_const(self.full_raise(t)), tracer))
    elif pv is None:
      return self.new_instantiated_const(const)
    else:
      raise TypeError(pv)

  def process_primitive(self, primitive, tracers, params):
    if primitive in custom_partial_eval_rules:
      partial_eval = custom_partial_eval_rules[primitive]
      return partial_eval(self, *tracers, **params)
    else:
      # Default rule: the output is unknown whenever the primitive is staged,
      # so every input is instantiated and the output aval is computed
      # abstractly.
      tracers = map(self.instantiate_const, tracers)
      avals = [t.aval for t in tracers]
      out_aval = primitive.abstract_eval(*avals, **params)
      eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
      return JaxprTracer(self, PartialVal((out_aval, unit)), eqn)

  def pack(self, tracers):
    eqn = JaxprEqn(tracers, None, core.pack_p, (), False, False, {})
    pval = pack_pvals([t.pval for t in tracers])
    return JaxprTracer(self, pval, eqn)

  def process_call(self, call_primitive, f, tracers, params):
    if call_primitive in map_primitives:
      return self.process_map(call_primitive, f, tracers, params)
    in_pvs, in_consts = unzip2([t.pval for t in tracers])
    fun, aux = partial_eval(f, self, in_pvs)
    # The call runs on the known consts and returns the known output plus the
    # constants the staged sub-jaxpr closes over.
    out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
    out_pv, jaxpr, env = aux()
    const_tracers = map(self.new_instantiated_const, consts)
    env_tracers = map(self.full_raise, env)
    bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
    eqn = JaxprEqn(tracers, None, call_primitive, (bound_subjaxpr,),
                   False, False, params)
    return JaxprTracer(self, PartialVal((out_pv, out_pv_const)), eqn)

  def process_map(self, call_primitive, f, tracers, params):
    in_pvs, in_consts = unzip2([t.pval for t in tracers])
    # The body is traced per-slice, so strip the mapped axis from input pvs
    # and add it back to the output pv.
    reduced_pvs = map(remove_axis_from_pv, in_pvs)
    fun, aux = partial_eval(f, self, reduced_pvs)
    out_const, consts = call_primitive.bind(fun, *in_consts, **params)
    out_pv_reduced, jaxpr, env = aux()
    out_pv = add_axis_to_pv(params['axis_size'], out_pv_reduced)
    const_tracers = map(self.new_instantiated_const, consts)
    env_tracers = map(self.full_raise, env)
    # Constvars become leading invars so the constants are passed (mapped)
    # as ordinary inputs.
    jaxpr_converted = jaxpr.copy()
    jaxpr_converted.constvars = []
    jaxpr_converted.invars = list(it.chain(jaxpr.constvars, jaxpr.invars))
    invars = tuple(it.chain(const_tracers, tracers))
    # Bind the full-raised env tracers, as process_call does; previously the
    # raw `env` was bound and `env_tracers` was computed but unused.
    bound_subjaxpr = (jaxpr_converted, (), env_tracers)
    eqn = JaxprEqn(invars, None, call_primitive, (bound_subjaxpr,),
                   False, False, params)
    return JaxprTracer(self, PartialVal((out_pv, out_const)), eqn)

  def post_process_call(self, call_primitive, out_tracer, params):
    # TODO(mattjj): post_process_map
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracer)
    out_pv, out_pv_const = out_tracer.pval
    out = pack((out_pv_const, pack(consts)))
    master = self.master
    def todo(x):
      # Rebuild the output tracer in a fresh trace at the current sublevel.
      out_pv_const, consts = x
      trace = JaxprTrace(master, core.cur_sublevel())
      const_tracers = map(trace.new_instantiated_const, consts)
      env_tracers = map(trace.full_raise, env)
      bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
      eqn = JaxprEqn([], None, call_primitive, (bound_subjaxpr,),
                     False, False, params)
      return JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), eqn)

    return out, todo
|
|
|
|
|
2019-02-21 11:47:26 -08:00
|
|
|
# Call primitives that JaxprTrace.process_call routes to process_map instead.
map_primitives = set()
|
|
|
|
|
2019-03-27 17:16:54 -04:00
|
|
|
def unzip_scan_jaxpr(jaxpr, consts, init, xs, avals):
  """Split a scan body jaxpr into known and unknown parts (not implemented).

  Intended to return (jaxpr1, jaxpr2, avals1, avals2, ans_pv), as unpacked by
  `scan_process_primitive`.

  Raises:
    NotImplementedError: always; this is an unfinished stub.
  """
  # Previously `assert False`. That is stripped under `python -O`, making this
  # silently return None, which then fails obscurely at the caller's 5-way
  # unpacking.
  raise NotImplementedError("unzip_scan_jaxpr is not implemented")
|
|
|
|
|
|
|
|
|
|
|
|
def scan_process_primitive(trace, consts, init, xs, avals, jaxpr):
  """Sketch of a partial-evaluation rule for scan (unfinished).

  Intended to run the known part of the scan now (`scan_p.bind` on the known
  constants) and to stage the unknown part as a `scan_p` equation.

  NOTE(review): this cannot currently run. `unzip_scan_jaxpr` is an
  unimplemented stub, and `scan_p` is neither defined nor imported in this
  module.
  """
  jaxpr1, jaxpr2, avals1, avals2, ans_pv = unzip_scan_jaxpr(
      jaxpr, consts, init, xs, avals)
  # consts/init/xs are presumably JaxprTracers (they are used as eqn invars
  # below), so take the known component of each one's partial value.
  # Previously the tracers themselves were unpacked.
  _, consts_const = consts.pval
  _, inits_const = init.pval
  _, xs_const = xs.pval
  ans = scan_p.bind(consts_const, inits_const, xs_const,
                    avals=avals1, jaxpr=jaxpr1)
  params_out = {'avals' : avals2, 'jaxpr' : jaxpr2}
  eqn = JaxprEqn([consts, init, xs], None, scan_p, (), False, False, params_out)
  # PartialVal is ordered (pv, const), and JaxprTracer requires a recipe.
  # Previously the pair was reversed and `eqn` was never passed.
  return JaxprTracer(trace, PartialVal((ans_pv, ans)), eqn)
|
|
|
|
|
|
|
|
|
|
|
|
|
2019-02-21 11:47:26 -08:00
|
|
|
def remove_axis_from_pv(pv):
  """Drop the leading (mapped) axis from every abstract value inside `pv`.

  A known pv (None) is returned as-is; JaxprTracerTuples are handled
  elementwise.

  Raises:
    TypeError: if `pv` is not a valid pv type.
  """
  if pv is None:
    return None
  if isinstance(pv, AbstractValue):
    return remove_axis_from_aval(pv)
  if type(pv) is JaxprTracerTuple:
    return JaxprTracerTuple([remove_axis_from_pv(elt) for elt in pv])
  raise TypeError(type(pv))
|
|
|
|
|
|
|
|
def remove_axis_from_aval(aval):
  """Return `aval` with its leading axis removed, recursing into tuples.

  Raises:
    NotImplementedError: for abstract values other than tuples and
      ShapedArrays.
  """
  if type(aval) is AbstractTuple:
    return AbstractTuple([remove_axis_from_aval(elt) for elt in aval])
  if not isinstance(aval, ShapedArray):
    raise NotImplementedError  # TODO(mattjj)
  # The result is always a plain ShapedArray, so a more concrete input (e.g.
  # a ConcreteArray, if it subclasses ShapedArray) is raised in abstraction.
  return ShapedArray(aval.shape[1:], aval.dtype)
|
|
|
|
|
|
|
|
def add_axis_to_pv(size, pv):
  """Prepend an axis of length `size` to every abstract value inside `pv`.

  A known pv (None) is returned as-is; JaxprTracerTuples are handled
  elementwise.

  Raises:
    TypeError: if `pv` is not a valid pv type.
  """
  if pv is None:
    return None
  if isinstance(pv, AbstractValue):
    return add_axis_to_aval(size, pv)
  if type(pv) is JaxprTracerTuple:
    return JaxprTracerTuple([add_axis_to_pv(size, elt) for elt in pv])
  raise TypeError(type(pv))
|
|
|
|
|
|
|
|
def add_axis_to_aval(size, aval):
  """Return `aval` with a new leading axis of length `size`, recursing into tuples.

  Raises:
    NotImplementedError: for abstract values other than tuples and
      ShapedArrays.
  """
  if type(aval) is AbstractTuple:
    return AbstractTuple([add_axis_to_aval(size, elt) for elt in aval])
  if not isinstance(aval, ShapedArray):
    raise NotImplementedError  # TODO(mattjj)
  return ShapedArray((size,) + aval.shape, aval.dtype)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def partial_eval(f, trace, pvs):
|
2019-04-01 16:03:56 -04:00
|
|
|
f = trace_to_subjaxpr(f, trace.master, False)
|
2018-11-17 18:03:33 -08:00
|
|
|
return partial_eval_wrapper(f, tuple(pvs))
|
|
|
|
|
|
|
|
|
|
|
|
@transformation_with_aux
def partial_eval_wrapper(avals, *consts):
  """Feed PartialVals built from (avals, consts) to the wrapped function.

  The wrapped function is called with one positional argument, the list of
  input PartialVals. It yields back (jaxpr, (out_pval, consts, env)). This
  transformation then outputs pack((out_const, pack(consts))) with aux data
  (out_pv, jaxpr, env).
  """
  in_pvals = map(PartialVal, zip(avals, consts))
  jaxpr, (out_pval, residual_consts, env) = yield (in_pvals,), {}
  out_pv, out_const = out_pval
  packed_out = pack((out_const, pack(residual_consts)))
  yield packed_out, (out_pv, jaxpr, env)
|
|
|
|
|
|
|
|
|
2019-02-13 14:28:30 -08:00
|
|
|
def abstract_eval_fun(fun, *avals, **params):
  """Trace `fun` on fully-unknown inputs `avals` and return its output aval."""
  unknown_inputs = [PartialVal((aval, unit)) for aval in avals]
  wrapped = lu.wrap_init(fun, params)
  _, out_pval, _ = trace_to_jaxpr(wrapped, unknown_inputs)
  out_aval, _ = out_pval
  return out_aval
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTracer(Tracer):
  """Tracer of a JaxprTrace: a PartialVal plus the recipe that produced it.

  The recipe is one of JaxprEqn, Destructuring, LambdaBinding, FreeVar,
  ConstVar, or `unit` (the cases handled by `tracers_to_jaxpr`).
  """
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace, pval, recipe):
    assert isinstance(pval, PartialVal)
    _, const = pval
    if isinstance(const, Tracer):
      # A known part may only be traced by a lower-level trace.
      assert const.trace.level < trace.level
    self.trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return 'Traced<{}:{}>'.format(self.aval, self.trace)

  @property
  def aval(self):
    return partial_val_aval(*self.pval)

  @property
  def parents(self):
    """Tracers read by the equation behind this tracer's recipe (if any)."""
    recipe = self.recipe
    if isinstance(recipe, Destructuring):
      recipe = recipe.eqn
    elif not isinstance(recipe, JaxprEqn):
      return []
    return eqn_parents(recipe)

  def ispure(self):
    """True iff this tracer's value is fully known (pv is None)."""
    return self.pval[0] is None

  def full_lower(self):
    if not self.ispure():
      return self
    return core.full_lower(self.pval[1])

  def unpack(self):
    """Destructure a tuple-valued tracer into a list of element values.

    Raises:
      TypeError: if the pv is not a valid pv type.
    """
    pv, const = self.pval
    if pv is None:
      return const
    if not isinstance(pv, (AbstractValue, JaxprTracerTuple)):
      raise TypeError(pv)
    num_elts = len(pv)
    if isinstance(pv, AbstractValue):
      # Fully unknown tuple: every element's known part is unit.
      const = [unit] * num_elts
    # All children share `key`, so tracers_to_jaxpr emits `eqn` only once.
    key = object()
    eqn = JaxprEqn([self], [None] * num_elts, core.identity_p, (), False, True, {})
    def make_child(i, elt_pv, elt_const):
      recipe = Destructuring(i, eqn, key)
      child = JaxprTracer(self.trace, PartialVal((elt_pv, elt_const)), recipe)
      return child.full_lower()
    return map(make_child, range(num_elts), pv, const)
|
|
|
|
|
|
|
|
class JaxprTracerTuple(tuple):
  """pv for a tuple whose elements carry their own, possibly mixed, pvs."""


# Recipe for one element of a destructured tuple tracer: element index `i`,
# the shared destructuring `eqn`, and a `key` identifying that eqn.
Destructuring = namedtuple('Destructuring', 'i eqn key')
|
|
|
|
|
|
|
|
class PartialVal(tuple):
  """A partial value: the pair (pv, const).

  pv is None (known; the value is `const`), an AbstractValue (unknown; const
  must be unit), or a JaxprTracerTuple (mixed; const has the same length).
  The invariants are asserted unless `core.skip_checks` is set.
  """

  def __new__(cls, xs):
    pv, const = xs
    if not core.skip_checks:
      # Types of the two components.
      assert isinstance(pv, valid_pv_types), xs
      assert isinstance(const, core.Tracer) or core.valid_jaxtype(const), xs
      # Structural invariants linking pv and const.
      if type(pv) is JaxprTracerTuple:
        assert len(pv) == len(const), xs
      elif isinstance(pv, AbstractValue):
        assert const == core.unit, xs
    return super(PartialVal, cls).__new__(cls, xs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Allowed types for a PartialVal's pv component (None marks a known value).
valid_pv_types = (AbstractValue, JaxprTracerTuple, type(None))

# The abstract value of the empty tuple.
abstract_unit = AbstractTuple()
|
|
|
|
|
|
|
|
def merge_pvals(val, pval):
  """Combine `val` (computed unknown parts) with the known parts of `pval`.

  Where pval is known its const is used, where it is unknown `val` is used,
  and JaxprTracerTuples are merged elementwise and repacked.

  Raises:
    TypeError: if pval's pv is not a valid pv type.
  """
  pv, const = pval
  if pv is None:
    return const
  if isinstance(pv, AbstractValue):
    return val
  if isinstance(pv, JaxprTracerTuple):
    return pack(map(merge_pvals, val, zip(pv, const)))
  raise TypeError(pv)
|
|
|
|
|
add partial value lattice join, cond support
This change allows one side of a cond to have a different const-ness
from the other side, from the point-of-view of partial evaluation. In
other words, this now works as expected:
```python
lax.cond(x < 0, x, lambda x: 0., x, lambda x: x) # relu
```
The partial evaluation logic works with tuples, so this works too:
```python
lax.cond(x < 0,
x, lambda x: (x, x, 1, 1, 1),
x, lambda x: (x, 1, x, 1, 2))
```
in that true_fun is resolved to something like `lambda x: (x, x, 1, *, 1)`
and false_fun is resolved to something like `lambda x: (x, 1, x, *, 2)`,
where `*` means unit and corresponds to a known constant that isn't
staged into the computation.
For forward-mode autodiff support, we'll need to add yet another lattice
join on the lattice of symbolic-zero-or-not.
2019-03-02 17:37:38 -08:00
|
|
|
def join_pvals(pval1, pval2):
  """Lattice join of two PartialVals (e.g. the outputs of a cond's branches).

  The result is known only where both inputs are known with equal abstract
  values. Otherwise it is unknown, with the lattice join of the abstract
  values. Tuples are joined elementwise.
  """
  pv1, const1 = pval1
  pv2, const2 = pval2
  if pv1 is None and pv2 is None:
    aval1, aval2 = core.get_aval(const1), core.get_aval(const2)
    if aval1 == aval2:
      return pval1  # both pvals known, equal constants
    else:
      aval = core.lattice_join(aval1, aval2)
      return PartialVal((aval, unit))  # both pvals known, different constants
  elif pv1 is None and isinstance(pv2, AbstractValue):
    aval = pv2
    return PartialVal((aval, unit))  # first pval known, second not known
  elif isinstance(pv1, AbstractValue) and pv2 is None:
    aval = pv1
    return PartialVal((aval, unit))  # first pval not known, second known
  elif isinstance(pv1, AbstractValue) and isinstance(pv2, AbstractValue):
    aval = core.lattice_join(pv1, pv2)
    return PartialVal((aval, unit))  # neither is known
  else:
    # the pvals are tuples with some mixtures of known/unknown
    assert isinstance(pv1, JaxprTracerTuple) or isinstance(pv2, JaxprTracerTuple)
    # Put both sides in elementwise form. A known side (pv None) gets all-None
    # pvs. A fully-unknown side (abstract tuple pv, const unit) gets one unit
    # const per element; its single `unit` const would not line up
    # elementwise with its pvs.
    if pv1 is None:
      pv1 = [None] * len(pv2)
    elif isinstance(pv1, AbstractValue):
      const1 = [unit] * len(pv1)
    if pv2 is None:
      pv2 = [None] * len(pv1)
    elif isinstance(pv2, AbstractValue):
      const2 = [unit] * len(pv2)
    pvals1, pvals2 = zip(pv1, const1), zip(pv2, const2)
    join_pvs, join_consts = unzip2(map(join_pvals, pvals1, pvals2))
    if all(isinstance(pv, AbstractValue) for pv in join_pvs):
      # Fully unknown: the const must be unit, as PartialVal's invariant
      # requires (and as pack_pvals returns). Previously this returned
      # pack(join_consts), a tuple of units, which fails that check.
      return PartialVal((AbstractTuple(join_pvs), unit))
    else:
      return PartialVal((JaxprTracerTuple(join_pvs), pack(join_consts)))
|
add partial value lattice join, cond support
This change allows one side of a cond to have a different const-ness
from the other side, from the point-of-view of partial evaluation. In
other words, this now works as expected:
```python
lax.cond(x < 0, x, lambda x: 0., x, lambda x: x) # relu
```
The partial evaluation logic works with tuples, so this works too:
```python
lax.cond(x < 0,
x, lambda x: (x, x, 1, 1, 1),
x, lambda x: (x, 1, x, 1, 2))
```
in that true_fun is resolved to something like `lambda x: (x, x, 1, *, 1)`
and false_fun is resolved to something like `lambda x: (x, 1, x, *, 2)`,
where `*` means unit and corresponds to a known constant that isn't
staged into the computation.
For forward-mode autodiff support, we'll need to add yet another lattice
join on the lattice of symbolic-zero-or-not.
2019-03-02 17:37:38 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def as_abstract_val(pv):
  """Return the abstract value described by an unknown pv.

  Raises:
    TypeError: if `pv` is None (a known value, which has no abstract pv), or
      is not a valid pv type.
  """
  if isinstance(pv, AbstractValue):
    return pv
  elif isinstance(pv, JaxprTracerTuple):
    return AbstractTuple(map(as_abstract_val, pv))
  elif pv is None:
    raise TypeError("{} is not abstract".format(pv))
  else:
    # Previously fell through and silently returned None; now consistent
    # with partial_val_aval.
    raise TypeError(pv)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def partial_val_aval(pv, const):
  """Abstract value of the whole value described by the pair (pv, const).

  Raises:
    TypeError: if `pv` is not a valid pv type.
  """
  if pv is None:
    return get_aval(const)
  if isinstance(pv, AbstractValue):
    return pv
  if isinstance(pv, JaxprTracerTuple):
    return AbstractTuple(map(partial_val_aval, pv, const))
  raise TypeError(pv)
|
|
|
|
|
|
|
|
def pack_pvals(pvals):
  """Combine a sequence of PartialVals into the PartialVal of their tuple.

  All known gives a known packed tuple. All unknown gives an AbstractTuple pv
  with const unit. Anything else gives a JaxprTracerTuple pv with the packed
  consts.
  """
  pvs, consts = unzip2(pvals)
  if all(pv is None for pv in pvs):
    return PartialVal((None, pack(consts)))
  if all(isinstance(pv, AbstractValue) for pv in pvs):
    return PartialVal((AbstractTuple(pvs), unit))
  return PartialVal((JaxprTracerTuple(pvs), pack(consts)))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def abstractify(x):
  """Treat `x` as fully unknown, with pv `core.concrete_aval(x)`."""
  aval = core.concrete_aval(x)
  return PartialVal((aval, unit))
|
|
|
|
|
|
|
|
def trace_unwrapped_to_jaxpr(fun, pvals, **kwargs):
  """Like `trace_to_jaxpr`, but takes a plain callable.

  `kwargs` are handed to `lu.wrap_init` along with `fun`. They are not
  forwarded as options to trace_to_jaxpr.
  """
  wrapped = lu.wrap_init(fun, kwargs)
  return trace_to_jaxpr(wrapped, pvals)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def trace_to_jaxpr(fun, pvals, **kwargs):
  """Traces a function, given abstract inputs, to a jaxpr.

  Args:
    fun: the wrapped function to trace.
    pvals: PartialVals for the inputs.
    instantiate: optional keyword (default False); if true, a known output is
      still staged into the jaxpr (via JaxprTrace.instantiate_const).

  Returns:
    A tuple (jaxpr, out_pval, consts).
  """
  instantiate = kwargs.pop('instantiate', False)
  with new_master(JaxprTrace) as master:
    fun = trace_to_subjaxpr(fun, master, instantiate)
    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
    # A top-level trace has nothing to close over.
    assert not env
    del master

  return jaxpr, out_pval, consts
|
|
|
|
|
|
|
|
@transformation
def trace_to_subjaxpr(master, instantiate, pvals):
  """Trace the wrapped function on `pvals` into a jaxpr at the current sublevel.

  The wrapped function receives one JaxprTracer per input PartialVal. This
  transformation outputs (jaxpr, (out_pval, consts, env)).
  """
  assert all(isinstance(pval, PartialVal) for pval in pvals), pvals
  trace = JaxprTrace(master, core.cur_sublevel())
  in_tracers = map(trace.new_arg, pvals)
  out_tracer = yield in_tracers, {}
  out_tracer = trace.full_raise(out_tracer)
  if instantiate:
    out_tracer = trace.instantiate_const(out_tracer)
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracer)
  out_pval = out_tracer.pval
  # Drop local references to the trace and tracers before the final yield,
  # since this generator frame stays alive while suspended there.
  del trace, in_tracers, out_tracer
  yield jaxpr, (out_pval, consts, env)
|
|
|
|
|
|
|
|
|
|
|
|
# Recipes for JaxprTracers that are not produced by an equation:
# a tracer from an enclosing sublevel, referenced as a free variable;
FreeVar = namedtuple('FreeVar', 'val')
# a known value staged into the jaxpr as a constvar;
ConstVar = namedtuple('ConstVar', 'val')
# an input argument of the jaxpr.
LambdaBinding = namedtuple('LambdaBinding', ())
|
|
|
|
|
|
|
|
def eqn_tracer_to_var(var, outvars, eqn):
  """Rebuild `eqn` with its tracers replaced by jaxpr Vars via `var`."""
  (invars, _, primitive, bound_subjaxprs,
   restructure, destructure, params) = eqn
  if restructure:
    # Restructuring eqns may take a tuple of tracers as a single input.
    new_invars = [tuple(map(var, v)) if type(v) is tuple else var(v)
                  for v in invars]
  else:
    new_invars = map(var, invars)
  new_bound_subjaxprs = [(subjaxpr, map(var, sub_consts), map(var, sub_env))
                         for subjaxpr, sub_consts, sub_env in bound_subjaxprs]
  return JaxprEqn(new_invars, outvars, primitive, new_bound_subjaxprs,
                  restructure, destructure, params)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def tracers_to_jaxpr(in_tracers, out_tracer):
  """Build a jaxpr from the recipes of the tracers `out_tracer` depends on.

  Visits the tracers returned by `toposort(out_tracer)` and converts each
  recipe:
    JaxprEqn -> an equation;
    FreeVar -> an env (free) variable;
    ConstVar -> a constvar;
    Destructuring -> its shared destructuring equation, emitted once per key;
    unit -> `unitvar`.

  Args:
    in_tracers: tracers bound, in order, as the jaxpr's invars.
    out_tracer: tracer whose variable becomes the jaxpr's outvar.

  Returns:
    (jaxpr, const_vals, env_vals): the values for the jaxpr's constvars and
    env vars, in the same order as those variables.

  Raises:
    TypeError: if a tracer has an unrecognized recipe.
  """
  newvar = gensym('')
  # Keyed by id(tracer); a fresh Var is created on first lookup.
  t_to_var = defaultdict(newvar)
  var = lambda t: t_to_var[id(t)]
  sorted_tracers = toposort(out_tracer)
  invars = map(var, in_tracers)
  eqns = []
  env = {}
  consts = {}
  # Destructuring key -> outvars of the single eqn emitted for that key.
  destructuring_vars = {}
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqn):
      eqns.append(eqn_tracer_to_var(var, [var(t)], recipe))
    elif isinstance(recipe, LambdaBinding):
      # Inputs already received their vars when `invars` was built above.
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[var(t)] = recipe.val
    elif isinstance(recipe, ConstVar):
      consts[var(t)] = recipe.val
    elif isinstance(recipe, Destructuring):
      i, eqn, key = recipe
      if key not in destructuring_vars:
        # First element seen for this eqn: emit it with fresh outvars.
        outvars = [newvar() for _ in eqn.outvars]
        eqns.append(eqn_tracer_to_var(var, outvars, eqn))
        destructuring_vars[key] = outvars
      else:
        outvars = destructuring_vars[key]
      t_to_var[id(t)] = outvars[i]
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  jaxpr = Jaxpr(const_vars, env_vars, invars, var(out_tracer), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
|
|
|
|
|
|
|
|
|
|
|
|
def gensym(suffix):
  """Return a thunk that makes fresh Vars numbered 0, 1, 2, ... with `suffix`."""
  fresh_ids = it.count()
  def fresh_var():
    return Var(next(fresh_ids), suffix)
  return fresh_var


class Var(object):
  """A jaxpr variable, printed as a lowercase base-26 name plus a suffix."""

  def __init__(self, count, suffix):
    self.count = count
    self.suffix = suffix

  def __repr__(self):
    # Spell `count` in base 26 using digits 'a'..'z' (0 -> 'a', 26 -> 'ba').
    letters = []
    rem = self.count
    while True:
      rem, digit = divmod(rem, 26)
      letters.append(chr(97 + digit))
      if not rem:
        break
    return ''.join(reversed(letters)) + self.suffix
|
|
|
|
|
|
|
|
def eqn_parents(eqn):
  """List the tracers `eqn` reads: its inputs, then each subjaxpr's consts/env."""
  subjaxpr_tracers = [it.chain(sub_consts, sub_env)
                      for _, sub_consts, sub_env in eqn.bound_subjaxprs]
  if eqn.restructure:
    # Restructuring eqns may have tuple-valued inputs; flatten one level.
    invars = []
    for v in eqn.invars:
      invars.extend(v if type(v) is tuple else [v])
  else:
    invars = eqn.invars
  return list(it.chain(invars, *subjaxpr_tracers))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-04-10 22:09:14 -07:00
|
|
|
def compiled_call_impl(fun, *args):
  """Impl rule for compiled_call.

  Treats every argument as unknown and traces `fun` to a jaxpr under a fresh
  JaxprTrace master. It then evaluates that jaxpr on the concrete `args` and
  merges the result with any output parts that were already known at trace
  time.
  """
  with new_master(JaxprTrace, True) as master:
    pvals = map(abstractify, args)
    jaxpr, (pval, consts, env) = trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    # `eval_jaxpr_raw` was never defined or imported in this module, so every
    # call raised NameError. NOTE(review): assumes the signature
    # core.eval_jaxpr(jaxpr, consts, freevar_vals, *args).
    jaxpr_ans = core.eval_jaxpr(jaxpr, consts, env, *args)
    ans = merge_pvals(jaxpr_ans, pval)
    del master, pvals, pval, consts, env, jaxpr_ans, jaxpr
    return ans
|
|
|
|
|
|
|
|
# A call primitive whose impl traces the callee to a jaxpr and evaluates it
# (see compiled_call_impl); binding goes through core.call_bind.
compiled_call_p = Primitive('compiled_call')
compiled_call = partial(core.call_bind, compiled_call_p)
compiled_call_p.def_custom_bind(compiled_call)
compiled_call_p.def_impl(compiled_call_impl)
|
2019-04-01 16:03:56 -04:00
|
|
|
|
|
|
|
|
2019-04-05 12:02:24 -07:00
|
|
|
def unzip_tracer_tuple(pvals):
  """Combine PartialVals into one with a JaxprTracerTuple pv and packed consts."""
  pvs, consts = unzip2(pvals)
  tuple_pv = JaxprTracerTuple(pvs)
  return PartialVal((tuple_pv, pack(consts)))
|
|
|
|
|
2019-05-10 08:20:40 -07:00
|
|
|
def as_pval(aval, is_unknown, val):
  """Build a PartialVal from an aval, an unknown-ness mask, and a value.

  `is_unknown` is a bool or a nested tuple of bools that mirrors the tuple
  structure of `aval` and `val`. Unknown leaves become (aval, unit) and known
  leaves become (None, val).

  Raises:
    TypeError: if a mask entry is neither a bool nor a tuple.
  """
  mask_type = type(is_unknown)
  if mask_type is bool:
    if not is_unknown:
      return PartialVal((None, val))
    return PartialVal((aval, core.unit))
  if mask_type is tuple:
    return unzip_tracer_tuple(map(as_pval, aval, is_unknown, val))
  raise TypeError(mask_type)
|
|
|
|
|
2019-05-10 08:20:40 -07:00
|
|
|
def as_pval2(aval, is_unknown):
  """Like `as_pval` but value-free: every leaf is abstract with const unit.

  Known leaves keep `aval`; unknown leaves are replaced by an empty
  AbstractTuple.

  Raises:
    TypeError: if a mask entry is neither a bool nor a tuple.
  """
  mask_type = type(is_unknown)
  if mask_type is tuple:
    return unzip_tracer_tuple(map(as_pval2, aval, is_unknown))
  if mask_type is bool:
    leaf_aval = core.AbstractTuple(()) if is_unknown else aval
    return PartialVal((leaf_aval, core.unit))
  raise TypeError(mask_type)
|
|
|
|
|
2019-05-10 08:20:40 -07:00
|
|
|
def unknown(x):
  """Map a pv to its unknown-ness mask.

  None maps to False, an AbstractValue maps to True, and a JaxprTracerTuple
  maps to a tuple of masks.

  Raises:
    TypeError: for any other input.
  """
  if isinstance(x, core.AbstractValue):
    return True
  if type(x) is JaxprTracerTuple:
    return tuple(unknown(elt) for elt in x)
  if x is None:
    return False
  raise TypeError(type(x))
|
|
|
|
|
2019-05-01 15:47:01 -07:00
|
|
|
def _closure_convert_jaxpr(jaxpr):
  """Return a copy of `jaxpr` whose constvars become one leading tuple invar."""
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  converted = jaxpr.copy()
  converted.constvars = ()
  converted.invars = [tuple(jaxpr.constvars)] + list(jaxpr.invars)
  if not core.skip_checks:
    core.check_jaxpr(converted)
  return converted
|
|
|
|
|
|
|
|
def _unpack_eqn(invar, outvars):
  """Build an identity eqn that destructures `invar` into `outvars`."""
  restructure, destructure = False, True
  return core.JaxprEqn([invar], outvars, core.identity_p, (),
                       restructure, destructure, {})
|
2019-04-09 08:45:34 -07:00
|
|
|
|
|
|
|
def _pack_eqn(invars, outvar):
  """Build a pack eqn that tuples `invars` into the single `outvar`."""
  restructure, destructure = False, False
  return core.JaxprEqn(invars, [outvar], core.pack_p, (),
                       restructure, destructure, {})
|
2019-04-01 16:03:56 -04:00
|
|
|
|
|
|
|
|
2019-05-10 08:20:40 -07:00
|
|
|
def partial_eval_jaxpr(jaxpr, second_components):
  """Split a TypedJaxpr into a known-stage and an unknown-stage TypedJaxpr.

  Args:
    jaxpr: a TypedJaxpr of shape d -> c -> a -> (c, b) (per the comments
      below). NOTE(review): d/c/a look like consts/carry/per-iteration inputs;
      confirm with callers.
    second_components: per-input unknown-ness masks (bool or nested tuples of
      bool), with True meaning unknown.

  Returns:
    (typed_jaxpr_1, typed_jaxpr_2, sc_out):
      typed_jaxpr_1: the known stage, which also outputs residuals `res`.
      typed_jaxpr_2: the unknown stage, which takes `res` paired with its last
        input.
      sc_out: the output unknown-ness mask (sc_c_out, sc_b_out).
  """
  # jaxpr :: d -> c -> a -> (c, b)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # Side channel: `fun` records results of its inner trace here so they can
  # be read after the outer trace_to_jaxpr call below.
  cell = []
  # we do some final-style output munging to place residuals
  # fun :: d1 -> c1 -> a1 -> (c1, (b1, res))
  def fun(*vals):
    pvals = map(as_pval, jaxpr.in_avals, second_components, vals)
    # Inner trace: stages the unknown parts of `f` into jaxpr_2, whose
    # constants consts_2 are the residuals.
    jaxpr_2, out_pval, consts_2 = trace_to_jaxpr(f, pvals)
    (out_pv_c, out_pv_b), out_const = out_pval
    if out_const is core.unit:
      # Whole output unknown: both components' known parts are unit.
      out_const_c, out_const_b = core.unit, core.unit
    else:
      out_const_c, out_const_b = out_const
    cell.append((out_pv_c, out_pv_b, jaxpr_2))
    return pack((out_const_c, pack((out_const_b, pack(consts_2)))))

  # Outer trace: runs `fun` on abstract known-stage inputs to build jaxpr_1.
  pvals = map(as_pval2, jaxpr.in_avals, second_components)
  jaxpr_1, out_pval, consts_1 = trace_to_jaxpr(
      lu.wrap_init(fun), pvals, instantiate=True)
  out_pv_c, out_pv_b, jaxpr_2 = cell[0]

  # jaxpr_1 :: d1 -> c1 -> a1 -> (c1, (b1, res))
  # jaxpr_2 :: res | d2 -> c2 -> a2 -> (c2, b2)
  # lifted_jaxpr_2 :: res -> d2 -> c2 -> a2 -> (c2, b2)
  # doubly_lifted_jaxpr_2 :: d2 -> c2 -> (a2, res) -> (c2, b2)
  lifted_jaxpr_2 = _closure_convert_jaxpr(jaxpr_2)
  doubly_lifted_jaxpr_2 = _move_and_pair_arg(lifted_jaxpr_2)
  sc_out = sc_c_out, sc_b_out = unknown(out_pv_c), unknown(out_pv_b)

  in_avals_1, in_avals_2 = unzip2(map(_split_avals, second_components,
                                      jaxpr.in_avals))
  out_aval_1, out_aval_2 = _split_avals(sc_out, jaxpr.out_aval)

  # in_avals_1 is already (d1, c1, a1), and out_aval_2 is already (c2, b2), but
  # we must munge:
  # 1. form out_aval_1 to include the residuals as (c1, (b1, res))
  # 2. form in_avals_2 to include the residuals as (d2, c2, (a2, res))

  # The residuals' abstract value is read off jaxpr_1's output pv.
  out_pv, _ = out_pval
  _, (_, res) = out_pv
  assert isinstance(res, AbstractValue)

  c1, b1 = out_aval_1
  lifted_out_aval_1 = AbstractTuple((c1, AbstractTuple((b1, res))))

  d2, c2, a2 = in_avals_2
  lifted_in_avals_2 = (d2, c2, AbstractTuple((a2, res)))

  typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, lifted_out_aval_1)
  typed_jaxpr_2 = TypedJaxpr(doubly_lifted_jaxpr_2, (), lifted_in_avals_2,
                             out_aval_2)
  return typed_jaxpr_1, typed_jaxpr_2, sc_out
|
2019-04-11 14:50:58 -07:00
|
|
|
|
2019-05-01 15:47:01 -07:00
|
|
|
def _move_and_pair_arg(jaxpr):
  """On a copy of `jaxpr`, reorder invars (res, d, c, a) to (d, c, (a, res))."""
  res, d, c, a = jaxpr.invars
  moved = jaxpr.copy()
  moved.invars = [d, c, (a, res)]
  if not core.skip_checks:
    core.check_jaxpr(moved)
  return moved
|
|
|
|
|
2019-05-10 08:20:40 -07:00
|
|
|
def _split_avals(second_component, aval):
  """Split `aval` by an unknown-ness mask into (known_part, unknown_part).

  A leaf goes to the unknown side when its mask is True and to the known
  side otherwise. The other side gets an empty AbstractTuple. Tuple masks
  split elementwise.

  Raises:
    TypeError: if a mask entry is neither a bool nor a tuple.
  """
  mask_type = type(second_component)
  if mask_type is bool:
    empty = AbstractTuple(())
    return (empty, aval) if second_component else (aval, empty)
  if mask_type is tuple:
    assert type(aval) is AbstractTuple
    known_avals, unknown_avals = unzip2(map(_split_avals, second_component, aval))
    return AbstractTuple(known_avals), AbstractTuple(unknown_avals)
  raise TypeError(mask_type)
|
|
|
|
|
2019-04-11 14:50:58 -07:00
|
|
|
|
2019-04-01 16:03:56 -04:00
|
|
|
# Primitive -> rule called by JaxprTrace.process_primitive as
# rule(trace, *tracers, **params) in place of the default staging.
custom_partial_eval_rules = dict()
|