2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
2018-11-21 13:27:26 -08:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
import itertools as it
|
|
|
|
from collections import namedtuple, Counter, defaultdict
|
2019-11-27 19:15:53 -08:00
|
|
|
import contextlib
|
|
|
|
import threading
|
2019-11-19 12:26:30 -08:00
|
|
|
from weakref import ref
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-05-28 22:38:06 -07:00
|
|
|
import numpy as onp
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from .. import core
|
|
|
|
from .. import linear_util as lu
|
2019-11-22 10:53:11 -08:00
|
|
|
from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
|
2018-11-17 18:03:33 -08:00
|
|
|
from ..linear_util import thunk, transformation, transformation_with_aux
|
2019-07-27 15:46:14 -07:00
|
|
|
from ..util import unzip2, safe_zip, safe_map, toposort, partial, split_list
|
2019-11-19 12:26:30 -08:00
|
|
|
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
|
|
|
|
AbstractValue, unit, unitvar, abstract_unit, Primitive,
|
|
|
|
call_p, TypedJaxpr, new_jaxpr_eqn)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Module-wide rebinding: inside this file `map` and `zip` refer to the
# `safe_map` / `safe_zip` helpers imported from ..util, not the builtins.
map, zip = safe_map, safe_zip
|
2019-02-15 06:35:54 -08:00
|
|
|
def identity(x):
  """Return ``x`` unchanged."""
  result = x
  return result
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-05-09 07:23:39 -07:00
|
|
|
# A partial value (pval) is modeled as a pair (pv, const), as per
|
|
|
|
# type PVal = (PV, Const)
|
2019-09-20 15:35:43 -07:00
|
|
|
# data PV = Known | Unknown AbstractValue
|
2019-05-09 07:23:39 -07:00
|
|
|
# type Const = MaybeTraced JaxType
|
2019-09-20 15:48:39 -07:00
|
|
|
# where the Known arm, represented by a None, indicates a known (constant) value
|
|
|
|
# and the Unknown arm, represented by an AbstractValue instance, indicates an
|
|
|
|
# unknown value.
|
|
|
|
# When the pv is an AbstractValue, then the const must be unit.
|
2019-05-09 07:23:39 -07:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTrace(Trace):
  """Partial-evaluation trace.

  Every tracer of this trace carries a PartialVal ``(pv, const)``: ``pv`` is
  None for a known value (held in ``const``) or an AbstractValue for an
  unknown one. Primitives whose inputs are all known are evaluated right away
  with ``bind``. Otherwise known inputs are instantiated as constants/literals
  and the operation is recorded as an equation recipe on the output tracers;
  ``tracers_to_jaxpr`` later assembles those recipes into a jaxpr.
  """

  def pure(self, val):
    return self.new_const(val)

  def lift(self, val):
    return self.new_const(val)

  def sublift(self, val):
    # A tracer from a lower sublevel of this trace: keep its partial value and
    # record it as a free variable of the jaxpr being built.
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val):
    """Wrap ``val`` as a known tracer (pv=None) of this trace."""
    if isinstance(val, Tracer) and val.trace.level == self.level:
      # A tracer already at this trace's level must not be wrapped again as a
      # known constant of the same trace.
      raise Exception(
          "JaxprTrace.new_const got a tracer at this trace's own level "
          "({}).".format(self.level))
    return JaxprTracer(self, PartialVal((None, val)), unit)

  def new_instantiated_literal(self, val):
    """Unknown tracer whose recipe embeds ``val`` as a jaxpr Literal."""
    return JaxprTracer(self, PartialVal((get_aval(val), unit)), Literal(val))

  def new_instantiated_const(self, val):
    """Unknown tracer whose recipe makes ``val`` a constvar of the jaxpr."""
    return JaxprTracer(self, PartialVal((get_aval(val), unit)), ConstVar(val))

  def new_arg(self, pval):
    """Tracer for a lambda-bound input of the jaxpr with partial value ``pval``."""
    return JaxprTracer(self, pval, LambdaBinding())

  def instantiate_const(self, tracer):
    """Make ``tracer`` unknown so that it can be referenced from the jaxpr.

    Unknown tracers are returned as-is. Known scalars of a literalable type
    become Literals; other known values become constvars. Any other ``pv``
    raises TypeError.
    """
    pv, const = tracer.pval
    if isinstance(pv, AbstractValue):
      return tracer
    elif pv is None:
      if type(const) in core.literalable_types and onp.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)
    else:
      raise TypeError(pv)

  def process_primitive(self, primitive, tracers, params):
    """Evaluate ``primitive`` if all inputs are known; otherwise stage it out."""
    if primitive in custom_partial_eval_rules:
      return custom_partial_eval_rules[primitive](self, *tracers, **params)
    pvs, consts = unzip2(t.pval for t in tracers)
    if all(pv is None for pv in pvs):
      return primitive.bind(*consts, **params)
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval = primitive.abstract_eval(*avals, **params)
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal((aval, unit)), None)
                     for aval in out_aval]
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, (), params)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    out_tracer = JaxprTracer(self, PartialVal((out_aval, unit)), None)
    out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, (), params)
    return out_tracer

  def process_call(self, call_primitive, f, tracers, params):
    """Partially evaluate a call of ``f``.

    The known part is run through ``call_primitive.bind``; the unknown part is
    recorded as an equation with a bound sub-jaxpr.
    """
    if call_primitive in call_partial_eval_rules:
      return call_partial_eval_rules[call_primitive](self, f, tracers, params)
    if call_primitive in map_primitives:
      return self.process_map(call_primitive, f, tracers, params)
    in_pvs, in_consts = unzip2([t.pval for t in tracers])
    fun, aux = partial_eval(f, self, in_pvs)
    out_flat = call_primitive.bind(fun, *in_consts, **params)
    out_pvs, jaxpr, env = aux()
    # The call returns the known output consts followed by the jaxpr's consts.
    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    const_tracers = map(self.new_instantiated_const, consts)
    bound_subjaxpr = (jaxpr, const_tracers, map(self.full_raise, env))
    out_tracers = [JaxprTracer(self, PartialVal((out_pv, out_pv_const)), None)
                   for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
    eqn = new_eqn_recipe(tracers, out_tracers, call_primitive, (bound_subjaxpr,), params)
    for t in out_tracers:
      t.recipe = eqn
    return out_tracers

  def process_map(self, map_primitive, f, tracers, params):
    """Like process_call, but for primitives that map over a leading axis.

    Input pvs are reduced with _mapped_aval before tracing, and output pvs are
    expanded back with _unmapped_aval using ``params['axis_size']``. The jaxpr's
    consts are closure-converted into explicit leading inputs of the equation.
    """
    in_pvs, in_consts = unzip2([t.pval for t in tracers])
    reduced_pvs = [None if pv is None else _mapped_aval(pv) for pv in in_pvs]
    fun, aux = partial_eval(f, self, reduced_pvs)
    out_flat = map_primitive.bind(fun, *in_consts, **params)
    out_pvs_reduced, jaxpr, env = aux()
    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvs = [None if pv is None else _unmapped_aval(params['axis_size'], pv)
               for pv in out_pvs_reduced]
    const_tracers = map(self.new_instantiated_const, consts)
    lifted_jaxpr = closure_convert_jaxpr(jaxpr)
    bound_subjaxpr = (lifted_jaxpr, (), map(self.full_raise, env))
    out_tracers = [JaxprTracer(self, PartialVal((out_pv, out_pv_const)), None)
                   for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
    eqn = new_eqn_recipe(tuple(it.chain(const_tracers, tracers)),
                         out_tracers, map_primitive, (bound_subjaxpr,), params)
    for t in out_tracers:
      t.recipe = eqn
    return out_tracers

  def post_process_call(self, call_primitive, out_tracers, params):
    """Handle tracers of this trace escaping a call.

    Returns ``(out, todo)``. ``out`` holds the known output consts followed by
    the jaxpr's consts. ``todo`` rebuilds output tracers from those values at
    the current sublevel.
    """
    if call_primitive in map_primitives:
      return self.post_process_map(call_primitive, out_tracers, params)
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
    out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
    out = out_pv_consts + consts
    del consts, out_pv_consts
    master = self.master  # capture only the master in `todo`, not `self`
    def todo(x):
      n = len(jaxpr.outvars)
      out_pv_consts, consts = x[:n], x[n:]
      trace = JaxprTrace(master, core.cur_sublevel())
      const_tracers = map(trace.new_instantiated_const, consts)
      env_tracers = map(trace.full_raise, env)
      bound_subjaxpr = (jaxpr, const_tracers, env_tracers)
      out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
                     for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
      eqn = new_eqn_recipe([], out_tracers, call_primitive, (bound_subjaxpr,), params)
      for t in out_tracers:
        t.recipe = eqn
      return out_tracers
    return out, todo

  def post_process_map(self, map_primitive, out_tracers, params):
    """Map-primitive analogue of post_process_call.

    Output pvs are expanded with _unmapped_aval, and the jaxpr's consts are
    closure-converted into inputs of the equation.
    """
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
    out_pvs_reduced, out_pv_consts = unzip2(t.pval for t in out_tracers)
    out_pvs = [None if pv is None else _unmapped_aval(params['axis_size'], pv)
               for pv in out_pvs_reduced]
    out = out_pv_consts + consts
    del consts, out_pv_consts
    master = self.master  # capture only the master in `todo`, not `self`
    def todo(x):
      n = len(jaxpr.outvars)
      out_pv_consts, consts = x[:n], x[n:]
      trace = JaxprTrace(master, core.cur_sublevel())
      const_tracers = map(trace.new_instantiated_const, consts)
      env_tracers = map(trace.full_raise, env)
      lifted_jaxpr = closure_convert_jaxpr(jaxpr)
      bound_subjaxpr = (lifted_jaxpr, (), env_tracers)
      out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
                     for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
      eqn = new_eqn_recipe(const_tracers, out_tracers, map_primitive,
                           (bound_subjaxpr,), params)
      for t in out_tracers:
        t.recipe = eqn
      return out_tracers
    return out, todo
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def _mapped_aval(aval):
  """Aval of one slice of ``aval`` along its leading axis.

  ``core.abstract_unit`` passes through unchanged; a ShapedArray (or subclass)
  loses its first dimension; anything else raises TypeError.
  """
  if aval is core.abstract_unit:
    return aval
  if not isinstance(aval, ShapedArray):
    raise TypeError(aval)
  # A ConcreteArray input comes back as a plain ShapedArray (abstraction raised).
  return ShapedArray(aval.shape[1:], aval.dtype)
|
2019-02-21 11:47:26 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def _unmapped_aval(size, aval):
  """Aval obtained by stacking ``size`` copies of ``aval`` along a new leading axis.

  ``core.abstract_unit`` passes through unchanged; anything that is not a
  ShapedArray raises TypeError.
  """
  if aval is core.abstract_unit:
    return aval
  if not isinstance(aval, ShapedArray):
    raise TypeError(aval)
  stacked_shape = (size,) + aval.shape
  return ShapedArray(stacked_shape, aval.dtype)
|
|
|
|
|
|
|
|
# Call primitives that JaxprTrace routes to process_map / post_process_map.
map_primitives = set()
# primitive -> rule(trace, *tracers, **params), used in place of the default
# JaxprTrace.process_primitive logic.
custom_partial_eval_rules = dict()
# call primitive -> rule(trace, f, tracers, params), used in place of the
# default JaxprTrace.process_call logic.
call_partial_eval_rules = dict()
|
2019-02-21 11:47:26 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def partial_eval(f, trace, pvs):
  """Wrap ``f`` so that calling it partially evaluates under ``trace.master``.

  Args:
    f: the lu.WrappedFun to transform.
    trace: a JaxprTrace whose master is used for the sub-trace.
    pvs: per-input pv (None for known, an AbstractValue for unknown).

  Returns:
    ``(fun, aux)`` from partial_eval_wrapper. ``fun`` takes the input consts;
    ``aux()`` yields ``(out_pvs, jaxpr, env)`` after it has been called.
  """
  subtraced = trace_to_subjaxpr(f, trace.master, False)
  pvs_tuple = tuple(pvs)
  return partial_eval_wrapper(subtraced, pvs_tuple)
|
|
|
|
|
|
|
|
|
|
|
|
@transformation_with_aux
def partial_eval_wrapper(avals, *consts):
  """Pair pvs with consts, trace, and flatten the result.

  The traced function receives one argument: the list of PartialVals built
  from ``zip(avals, consts)``. The flat output is the known output consts
  followed by the jaxpr's consts; the aux data is ``(out_pvs, jaxpr, env)``.
  """
  in_pvals = map(PartialVal, zip(avals, consts))
  jaxpr, (out_pvals, jaxpr_consts, env) = yield (in_pvals,), {}
  out_pvs, out_consts = unzip2(out_pvals)
  flat_out = tuple(out_consts) + tuple(jaxpr_consts)  # TODO: can consts be traced?
  yield flat_out, (out_pvs, jaxpr, env)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-13 14:28:30 -08:00
|
|
|
def abstract_eval_fun(fun, *avals, **params):
  """Compute the output avals of ``fun(*args, **params)`` given input ``avals``.

  All inputs are treated as unknown, and outputs are instantiated, so every
  returned entry is an AbstractValue.
  """
  in_pvals = [PartialVal((aval, unit)) for aval in avals]
  wrapped = lu.wrap_init(fun, params)
  _, out_pvals, _ = trace_to_jaxpr(wrapped, in_pvals, instantiate=True)
  out_avals, _ = unzip2(out_pvals)
  # instantiate=True makes every output unknown, so each pv is an AbstractValue.
  assert all(isinstance(aval, AbstractValue) for aval in out_avals)
  return out_avals
|
2019-02-13 14:28:30 -08:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTracer(Tracer):
  """Tracer of a JaxprTrace.

  ``pval`` is the tracer's PartialVal. ``recipe`` describes how the value
  enters the jaxpr: a JaxprEqnRecipe, LambdaBinding, FreeVar, ConstVar,
  Literal, ``unit``, or None while the recipe is still being attached.
  """
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace, pval, recipe):
    assert isinstance(pval, PartialVal)
    const = pval[1]
    if isinstance(const, Tracer):
      # A known const may only be a tracer of a strictly lower-level trace.
      assert const.trace.level < trace.level
    self.trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return 'Traced<{}:{}>'.format(self.aval, self.trace)

  @property
  def aval(self):
    return partial_val_aval(*self.pval)

  @property
  def parents(self):
    # Only equation recipes have upstream tracers.
    recipe = self.recipe
    return eqn_parents(recipe) if isinstance(recipe, JaxprEqnRecipe) else []

  def ispure(self):
    """True when the value is known (pv is None)."""
    known_pv = self.pval[0]
    return known_pv is None  # or pv is core.abstract_unit

  def full_lower(self):
    if not self.ispure():
      return self
    return core.full_lower(self.pval[1])
|
|
|
|
|
|
|
|
class PartialVal(tuple):
  """A ``(pv, const)`` pair describing a partially known value.

  ``pv`` is None when the value is known (then ``const`` holds it), or an
  AbstractValue when it is unknown (then ``const`` must be ``core.unit``).
  The checks below are skipped when ``core.skip_checks`` is set.
  """

  def __new__(cls, xs):
    pv, const = xs
    if core.skip_checks:
      return tuple.__new__(cls, xs)
    # Type checks.
    assert isinstance(pv, valid_pv_types), xs
    assert isinstance(const, core.Tracer) or core.valid_jaxtype(const), xs
    # Invariant: an unknown value carries unit as its const.
    if isinstance(pv, AbstractValue):
      assert const == core.unit, xs
    return tuple.__new__(cls, xs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
# Types allowed for the `pv` half of a PartialVal: an AbstractValue (unknown
# value) or None (known value). Checked in PartialVal.__new__.
valid_pv_types = (AbstractValue, type(None))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def merge_pvals(val, pval):
  """Combine a computed value with a partial value.

  Returns the known const of ``pval`` when it is known (pv is None), and
  ``val`` when ``pval`` is unknown (pv is an AbstractValue). Any other pv
  raises TypeError.
  """
  pv, const = pval
  if pv is None:
    return const
  if isinstance(pv, AbstractValue):
    return val
  raise TypeError(pv)
|
|
|
|
|
|
|
|
def partial_val_aval(pv, const):
  """Abstract value of the partial value ``(pv, const)``.

  For an unknown value (pv an AbstractValue) this is ``pv``. For a known value
  (pv None) it is ``get_aval(const)``. Any other pv raises TypeError.
  """
  if pv is None:
    return get_aval(const)
  if isinstance(pv, AbstractValue):
    return pv
  raise TypeError(pv)
|
|
|
|
|
|
|
|
def trace_to_jaxpr(fun, pvals, **kwargs):
  """Traces a function, given abstract inputs, to a jaxpr.

  Args:
    fun: the lu.WrappedFun to trace.
    pvals: list of PartialVal, one per input.
    instantiate (keyword, default False): bool, or per-output list of bools,
      saying which known outputs to force into the jaxpr as unknowns.

  Returns:
    ``(jaxpr, out_pvals, consts)``: the traced jaxpr, the output partial
    values, and the values for the jaxpr's constvars.
  """
  instantiate = kwargs.pop('instantiate', False)
  with new_master(JaxprTrace) as master:
    fun = trace_to_subjaxpr(fun, master, instantiate)
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    # A top-level trace has no enclosing trace to close over.
    assert not env
    # Drop this frame's reference to the master before the context exits.
    del master

  return jaxpr, out_pvals, consts
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@transformation
def trace_to_subjaxpr(master, instantiate, pvals):
  """Trace the wrapped function with a JaxprTrace at the current sublevel.

  Args:
    master: the MasterTrace to build the JaxprTrace on.
    instantiate: bool, or per-output list of bools; True entries force the
      corresponding outputs to be unknown (see instantiate_const_at).
    pvals: list of PartialVal, one per input of the wrapped function.

  Yields (to the lu.transformation machinery):
    ``jaxpr, (out_pvals, consts, env)``.
  """
  assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
  trace = JaxprTrace(master, core.cur_sublevel())
  in_tracers = map(trace.new_arg, pvals)
  ans = yield in_tracers, {}
  # Broadcast a scalar bool to one flag per output.
  instantiate = [instantiate] * len(ans) if type(instantiate) is bool else instantiate
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  # Release this frame's references to the trace and its tracers before
  # yielding results (presumably so the master can be freed/leak-checked).
  del trace, in_tracers, out_tracers
  yield jaxpr, (out_pvals, consts, env)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-05-10 08:58:05 -07:00
|
|
|
def instantiate_const_at(trace, instantiate, tracer):
  """Raise ``tracer`` to ``trace`` and instantiate it, if ``instantiate`` is True.

  ``instantiate`` must be a bool. When it is False, ``tracer`` is returned
  untouched.
  """
  assert type(instantiate) is bool
  if not instantiate:
    return tracer
  raised = trace.full_raise(tracer)
  return trace.instantiate_const(raised)
|
2019-05-10 08:58:05 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Recipes recording how a JaxprTracer's value enters the jaxpr being built.
FreeVar = namedtuple('FreeVar', 'val')      # closed-over value (becomes a freevar)
ConstVar = namedtuple('ConstVar', 'val')    # constant (becomes a constvar)
LambdaBinding = namedtuple('LambdaBinding', ())  # an input (invar) of the jaxpr
# An equation producing one or more output tracers; `outvars` holds weakrefs.
JaxprEqnRecipe = namedtuple(
    'JaxprEqnRecipe',
    'eqn_id invars outvars primitive bound_subjaxprs params')
|
|
|
|
|
|
|
|
def new_eqn_recipe(invars, outvars, primitive, bound_subjaxprs, params):
  """Build a JaxprEqnRecipe with a fresh identity.

  A new ``object()`` serves as ``eqn_id``, so that an equation shared by
  several output tracers is emitted only once. Output tracers are stored as
  weak references; recipe_to_eqn turns dead ones into unused vars.
  """
  fresh_id = object()
  out_refs = map(ref, outvars)
  return JaxprEqnRecipe(fresh_id, invars, out_refs, primitive,
                        bound_subjaxprs, params)
|
|
|
|
|
2019-11-20 09:12:15 -08:00
|
|
|
def recipe_to_eqn(unused_var, getvar, recipe):
  """Convert a JaxprEqnRecipe into a jaxpr equation.

  Args:
    unused_var: zero-argument callable producing a fresh var, used for output
      tracers that have already been garbage collected.
    getvar: maps a tracer to its jaxpr var.
    recipe: the JaxprEqnRecipe to convert.
  """
  _, in_tracers, out_tracer_refs, primitive, bound_subjaxprs, params = recipe
  # Dereference all outputs first; this keeps them alive for the whole call.
  out_tracers = [t_ref() for t_ref in out_tracer_refs]
  invars = [getvar(t) for t in in_tracers]
  outvars = [getvar(t) if t is not None else unused_var() for t in out_tracers]
  new_bound_subjaxprs = []
  for subjaxpr, const_tracers, env_tracers in bound_subjaxprs:
    const_vars = map(getvar, const_tracers)
    env_vars = map(getvar, env_tracers)
    new_bound_subjaxprs.append((subjaxpr, const_vars, env_vars))
  return new_jaxpr_eqn(invars, outvars, primitive, new_bound_subjaxprs, params)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-26 16:48:17 -04:00
|
|
|
def tracers_to_jaxpr(in_tracers, out_tracers):
  """Assemble the recipes reachable from ``out_tracers`` into a Jaxpr.

  Args:
    in_tracers: tracers with LambdaBinding recipes; they become the invars.
    out_tracers: tracers whose vars become the outvars.

  Returns:
    ``(jaxpr, const_vals, env_vals)``. ``const_vals`` are the values of
    ConstVar tracers (the jaxpr's constvars); ``env_vals`` are the values of
    FreeVar tracers (the jaxpr's freevars).

  Raises:
    TypeError: on a tracer with an unrecognized recipe.
  """
  newvar = core.gensym('')
  t_to_var = defaultdict(newvar)
  getvar = lambda t: t_to_var[id(t)]
  sorted_tracers = toposort(out_tracers)
  invars = map(getvar, in_tracers)
  eqns = []
  env = {}
  consts = {}
  # The same constant object, wherever it appears, shares a single constvar.
  const_to_var = defaultdict(newvar)
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      # A multi-output equation is reachable from each of its outputs; emit once.
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(newvar, getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      # Check for "no args" first so that case gets its specific message.
      assert in_tracers, "Lambda binding with no args"
      assert any(t is in_tracer for in_tracer in in_tracers), "Encountered unexpected tracer"
    elif isinstance(recipe, FreeVar):
      env[getvar(t)] = recipe.val
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = const_to_var[id(recipe.val)]
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  jaxpr = Jaxpr(const_vars, env_vars, invars, list(map(getvar, out_tracers)), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
|
|
|
|
|
|
|
|
|
2019-11-20 09:12:15 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def eqn_parents(eqn):
  """Tracers an equation recipe depends on, as a list.

  The order is: the input tracers first, then, for each bound sub-jaxpr, its
  const tracers followed by its env tracers.
  """
  parents = list(eqn.invars)
  for _, const_tracers, env_tracers in eqn.bound_subjaxprs:
    parents.extend(const_tracers)
    parents.extend(env_tracers)
  return parents
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def closure_convert_jaxpr(jaxpr):
  """Return a new Jaxpr whose constvars become leading invars.

  The result has no constvars, and its invars are
  ``jaxpr.constvars + jaxpr.invars``. Freevars, outvars and eqns are reused
  from ``jaxpr``. Both jaxprs are checked unless ``core.skip_checks`` is set.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  new_invars = jaxpr.constvars + jaxpr.invars
  lifted_jaxpr = Jaxpr(constvars=(), freevars=jaxpr.freevars, invars=new_invars,
                       outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(lifted_jaxpr)
  return lifted_jaxpr
|
|
|
|
|
2019-11-27 14:28:13 -08:00
|
|
|
def convert_freevars_jaxpr(jaxpr):
  """Return a new Jaxpr whose freevars become leading invars.

  The result has no freevars, and its invars are
  ``jaxpr.freevars + jaxpr.invars``. Constvars, outvars and eqns are reused
  from ``jaxpr``. Both jaxprs are checked unless ``core.skip_checks`` is set.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  new_invars = jaxpr.freevars + jaxpr.invars
  lifted_jaxpr = Jaxpr(constvars=jaxpr.constvars, freevars=(), invars=new_invars,
                       outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(lifted_jaxpr)
  return lifted_jaxpr
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
  """Split a TypedJaxpr into a known part and an unknown part.

  Args:
    jaxpr: the TypedJaxpr to split.
    unknowns: per-input bools; True marks an input as unknown.
    instantiate: forwarded to trace_to_jaxpr when tracing the unknown part
      (a bool, or a per-output list of bools).

  Returns:
    ``(typed_jaxpr_1, typed_jaxpr_2, uk_out)``:
      typed_jaxpr_1 takes the known inputs (unknown positions get
        abstract_unit) and returns the known outputs plus the residuals.
      typed_jaxpr_2 takes the unknown inputs followed by the residuals and
        returns the unknown outputs (known positions get abstract_unit).
      uk_out holds per-output bools; True means the output is unknown.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # `fun` is traced below as jaxpr_1; while tracing it, it also traces the
  # unknown part and stashes (out_pvs_2, jaxpr_2, num_residuals) in `cell`.
  cell = []
  def fun(*vals):
    pvals = [PartialVal((aval, unit)) if uk else PartialVal((None, val))
            for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    # Known outputs followed by the residuals (jaxpr_2's constants).
    return out_consts_2 + consts_2

  # Unknown inputs are fed to jaxpr_1 as units; known inputs are abstract args.
  pvals = [PartialVal((abstract_unit, unit)) if uk else PartialVal((aval, unit))
           for aval, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = closure_convert_jaxpr(jaxpr_2)
  # closure_convert_jaxpr puts residuals first; rotate them to the end.
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  uk_out = [pv is not None for pv in out_pvs_2]

  in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
  out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
  # out_avals_1 and in_avals_2 need the residuals added
  out_pvs, _ = unzip2(out_pvals)
  res_avals = out_pvs[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res
  out_avals_1 = out_avals_1 + res_avals
  in_avals_2 = in_avals_2 + res_avals

  typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
  typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
  return typed_jaxpr_1, typed_jaxpr_2, uk_out
|
|
|
|
|
|
|
|
def _split_aval(unknown, aval):
  """Return ``(known_aval, unknown_aval)``; the inactive side gets abstract_unit."""
  if unknown:
    return abstract_unit, aval
  return aval, abstract_unit
|
2019-04-11 14:50:58 -07:00
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
# Call primitive behind `remat`. It binds like a call primitive
# (core.call_bind), evaluates like one (core.call_impl) and returns a list of
# outputs. Its partial-evaluation rule, _remat_partial_eval, is registered in
# call_partial_eval_rules later in this file.
remat_call_p = core.Primitive('remat_call')
remat_call = partial(core.call_bind, remat_call_p)
remat_call_p.def_custom_bind(remat_call)
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True
|
|
|
|
|
|
|
|
def _remat_partial_eval(trace, f, tracers, params):
  """Partial-eval rule for remat_call_p (used in place of JaxprTrace.process_call).

  Traces the whole callee into one jaxpr, with every input treated as unknown.
  It then uses partial_eval_jaxpr to find which outputs are really known,
  computes those known outputs, and stages out a DCE'd copy of the full jaxpr
  for the unknown outputs.

  Args:
    trace: the current JaxprTrace.
    f: the lu.WrappedFun being called.
    tracers: input tracers.
    params: call params; ``params['concrete']`` selects whether concrete input
      avals are kept (True) or raised to the Shaped level (False).

  Returns:
    A list of output JaxprTracers sharing one remat_call_p equation recipe.
  """
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  instantiated_tracers = map(trace.instantiate_const, tracers)
  if not concrete:
    instantiated_tracers = [
        JaxprTracer(trace, PartialVal((raise_to_shaped(t.pval[0]), unit)), t.recipe)
        if type(t.pval[0]) is ConcreteArray else t for t in instantiated_tracers]

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
  fun, aux = partial_eval(f, trace, in_pvs)
  if concrete:
    # TODO(mattjj): remove `remat_context` when confident no accidental FLOPs
    with remat_context():
      out_flat = remat_call_p.bind(fun, *in_consts, **params)
  else:
    out_flat = remat_call_p.bind(fun, *in_consts, **params)
  out_pvs, jaxpr, env = aux()
  env = map(trace.full_raise, env)
  # The call returns the known output consts followed by the jaxpr's consts.
  out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
  out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  jaxpr_converted = convert_freevars_jaxpr(jaxpr)
  # NOTE(review): if an env tracer can be known here (t.pval[0] is None), this
  # passes None to raise_to_shaped; confirm env tracers are always unknown, or
  # use partial_val_aval(*t.pval).
  in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
              + [raise_to_shaped(pv) for pv in in_pvs])
  out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
               for pv, const in zip(out_pvs, out_pval_consts1)]
  typed_jaxpr = core.TypedJaxpr(jaxpr_converted, consts, in_avals, out_avals)
  # Known/unknown is judged from the original (non-instantiated) input tracers.
  in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
  jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
  num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

  # First, we prune the jaxpr to be staged out not to have too many outputs.
  typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of reverse-mode ad in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  to_compute = [not uk and type(pv) is not ConcreteArray
                for uk, pv in zip(out_unknowns, out_pvs)]
  jaxpr_1_primals = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
  _, in_consts = unzip2(t.pval for t in it.chain(env, tracers))
  # `-num_res or None` keeps every output when there are no residuals.
  out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
  out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)

  # Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
  instantiated_tracers = env + instantiated_tracers
  const_tracers = map(trace.new_instantiated_const, consts)
  bound_subjaxpr = (typed_jaxpr.jaxpr, const_tracers, ())
  out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
  eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
                       (bound_subjaxpr,), params)
  for t in out_tracers: t.recipe = eqn
  return out_tracers

# Route remat_call_p through the rule above instead of the default process_call.
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
|
|
|
|
2019-11-27 15:25:49 -08:00
|
|
|
def _dce_jaxpr(typed_jaxpr, outputs):
  """Dead-code-eliminate ``typed_jaxpr`` down to the selected outputs.

  Args:
    typed_jaxpr: the TypedJaxpr to prune.
    outputs: per-output bools. Outputs marked False are replaced by
      ``core.unitvar`` with aval ``core.abstract_unit``.

  Returns:
    A new TypedJaxpr with the same constvars, freevars, invars, literals and
    in_avals. It keeps, in their original order, only the equations needed to
    compute the selected outputs.
  """
  # This dead-code elimination is pretty rudimentary, and in particular doesn't
  # nontrivially DCE through scan, call, or other higher-order primitives.
  # TODO(mattjj): better DCE
  jaxpr = typed_jaxpr.jaxpr
  outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
  out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
               for var, aval, output in zip(outvars, out_avals, outputs)]
  new_outvars, new_out_avals = unzip2(out_pairs)

  # Literals are never defined by an equation, so they never need tracking.
  def _defined_vars(vs):
    return (v for v in vs if not isinstance(v, Literal))

  needed_vars = set(_defined_vars(new_outvars))
  new_eqns = []
  for eqn in jaxpr.eqns[::-1]:
    if set(eqn.outvars) & needed_vars:
      new_eqns.append(eqn)
      needed_vars.update(_defined_vars(eqn.invars))
      # Vars bound into sub-jaxprs (their consts and free vars) are also read
      # by this equation, so the equations defining them must be kept too.
      for _, const_vars, env_vars in eqn.bound_subjaxprs:
        needed_vars.update(_defined_vars(it.chain(const_vars, env_vars)))
  new_eqns = new_eqns[::-1]

  new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.freevars, jaxpr.invars,
                         new_outvars, new_eqns)
  return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals,
                         new_out_avals)
|
|
|
|
|
|
|
|
def _reconstruct_pval(pval1, const2, unknown):
  """Choose the final PartialVal for one output of a remat call.

  Unknown outputs, and outputs that were already known in ``pval1``, keep
  ``pval1``. Otherwise the output becomes known: it takes the value carried by
  a ConcreteArray pv if there is one, and ``const2`` if not.
  """
  pv1, _ = pval1
  if unknown or pv1 is None:
    return pval1
  known_val = pv1.val if type(pv1) is ConcreteArray else const2
  return PartialVal((None, known_val))
|
2019-11-27 14:28:13 -08:00
|
|
|
|
2019-11-27 19:15:53 -08:00
|
|
|
# TODO(mattjj): for https://github.com/google/jax/pull/1749 we allowed
# standard_abstract_eval to perform concrete evaluation (i.e. FLOPs), but we
# don't think that should happen outside of a remat context.
|
|
|
|
@contextlib.contextmanager
def remat_context():
  """Set the thread-local ``remat`` flag to True for the duration of a ``with``.

  The previous value is restored on exit, including when the body raises, so
  nested uses compose.
  """
  # Read the previous value before entering `try`, so that `finally` never
  # refers to an unbound name.
  prev_state = _thread_local_state.remat
  _thread_local_state.remat = True
  try:
    yield
  finally:
    _thread_local_state.remat = prev_state


class _ThreadLocalState(threading.local):
  """Per-thread state for this module."""

  def __init__(self):
    # True while inside `remat_context()`; each thread starts out False.
    self.remat = False

_thread_local_state = _ThreadLocalState()
|
|
|
|
|
2019-11-27 14:28:13 -08:00
|
|
|
|
|
|
|
def move_binders_to_front(typed_jaxpr, to_move):
  """Reorder a closed TypedJaxpr's inputs so that flagged ones come first.

  Args:
    typed_jaxpr: a TypedJaxpr with no constvars or freevars.
    to_move: per-input bools; True inputs are moved to the front. Relative
      order is preserved within both groups.

  Returns:
    A new TypedJaxpr with permuted invars and in_avals.
  """
  jaxpr = typed_jaxpr.jaxpr
  assert not jaxpr.constvars and not jaxpr.freevars
  assert len(typed_jaxpr.in_avals) == len(to_move)
  reordered = core.Jaxpr((), (), _move_to_front(jaxpr.invars, to_move),
                         jaxpr.outvars, jaxpr.eqns)
  return core.TypedJaxpr(reordered, typed_jaxpr.literals,
                         _move_to_front(typed_jaxpr.in_avals, to_move),
                         typed_jaxpr.out_avals)
|
|
|
|
|
|
|
|
def _move_to_front(lst, to_move):
  """Stable partition: elements whose flag is truthy first, then the rest."""
  front, back = [], []
  for elt, move in zip(lst, to_move):
    (front if move else back).append(elt)
  return front + back
|
|
|
|
|