# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
2018-11-21 13:27:26 -08:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from operator import attrgetter
|
|
|
|
from contextlib import contextmanager
|
|
|
|
from collections import namedtuple, Counter, defaultdict
|
2019-10-08 10:57:36 -07:00
|
|
|
import itertools as it
|
2018-11-17 18:03:33 -08:00
|
|
|
from weakref import ref
|
2019-07-23 09:53:27 -04:00
|
|
|
import threading
|
2018-11-17 18:03:33 -08:00
|
|
|
import types
|
|
|
|
|
2019-06-18 21:51:51 -07:00
|
|
|
import six
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from . import linear_util as lu
|
2019-12-06 22:28:41 -08:00
|
|
|
from .util import safe_zip, safe_map, partial, curry
|
2018-11-17 18:03:33 -08:00
|
|
|
from .pprint_util import pp, vcat, hcat, pp_kv_pairs
|
|
|
|
|
|
|
|
# TODO(dougalm): the trace cache breaks the leak detector. Consider solving.
# When True, new_master/new_sublevel check on exit (via a weak reference) that
# the finished master trace / sublevel is no longer referenced.
check_leaks = False
# TODO(dougalm): put this behind a flag that's enabled during testing
# When True, every `assert skip_checks or ...` sanity check in this module is
# short-circuited.
skip_checks = True  # not __debug__  # google doesn't use -O

# Shadow the builtins with the helpers imported from .util for the rest of
# this module. NOTE(review): presumably these are length-checking,
# list-returning variants; confirm against .util.
zip = safe_zip
map = safe_map
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------- jaxprs --------------------
|
|
|
|
|
|
|
|
class Jaxpr(object):
  """Holds the variable lists and equations that make up a jaxpr.

  Attributes:
    constvars: list built from the `constvars` argument.
    freevars: list built from the `freevars` argument.
    invars: list built from the `invars` argument.
    outvars: list built from the `outvars` argument.
    eqns: list built from the `eqns` argument.
  """

  def __init__(self, constvars, freevars, invars, outvars, eqns):
    # Each argument is copied into a fresh list, so any iterable is accepted
    # and later mutation of the caller's container is not observed here.
    self.constvars = list(constvars)
    self.freevars = list(freevars)
    self.invars = list(invars)
    self.outvars = list(outvars)
    self.eqns = list(eqns)

  def __str__(self):
    """Return the pretty-printed form produced by pp_jaxpr."""
    return str(pp_jaxpr(self))

  __repr__ = __str__
|
2019-02-06 11:49:21 -05:00
|
|
|
|
2019-05-10 14:00:21 -07:00
|
|
|
class TypedJaxpr(object):
  """A Jaxpr bundled with its constant values and input/output avals.

  Attributes:
    jaxpr: the wrapped Jaxpr; it must have no freevars.
    literals: list of values, one per entry of `jaxpr.constvars`.
    in_avals: list of AbstractValue, one per entry of `jaxpr.invars`.
    out_avals: list of AbstractValue.
  """

  def __init__(self, jaxpr, literals, in_avals, out_avals):
    # Internal invariants, checked with asserts.
    assert type(jaxpr) is Jaxpr
    assert len(literals) == len(jaxpr.constvars)
    assert len(in_avals) == len(jaxpr.invars)
    assert all(isinstance(aval, AbstractValue) for aval in in_avals)
    assert all(isinstance(aval, AbstractValue) for aval in out_avals)
    assert not jaxpr.freevars

    self.jaxpr = jaxpr
    self.literals = list(literals)
    self.in_avals = list(in_avals)
    self.out_avals = list(out_avals)

  def __iter__(self):
    """Iterate over (jaxpr, literals, in_avals, out_avals), enabling unpacking."""
    return iter((self.jaxpr, self.literals, self.in_avals, self.out_avals))

  def __str__(self):
    # TODO(mattjj): improve this with type annotations?
    return str(pp_jaxpr(self.jaxpr))

  __repr__ = __str__
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-04-23 09:15:16 -07:00
|
|
|
@curry
def jaxpr_as_fun(typed_jaxpr, *args):
  """Evaluate `typed_jaxpr` on `args`.

  Calls eval_jaxpr with the typed jaxpr's own literals as constants and an
  empty tuple of free-variable values, and returns its result.
  """
  return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, (), *args)
|
2019-04-23 09:15:16 -07:00
|
|
|
|
|
|
|
|
2019-11-19 12:26:30 -08:00
|
|
|
# One equation of a jaxpr: `primitive` applied to `invars` (with `params` and
# any `bound_subjaxprs`) and bound to `outvars`.
JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive',
                                   'bound_subjaxprs', 'params'])
# Print equations with the pretty-printer; pp_eqn is looked up at call time.
JaxprEqn.__repr__ = JaxprEqn.__str__ = lambda eqn: str(pp_eqn(eqn)).rstrip()
# Alias kept so callers can construct equations through a function-style name.
new_jaxpr_eqn = JaxprEqn
|
2019-07-26 18:01:38 -04:00
|
|
|
|
2019-10-03 17:56:25 -07:00
|
|
|
|
2019-10-08 10:57:36 -07:00
|
|
|
class Var(object):
  """A jaxpr variable identified by an integer `count` and a string `suffix`."""

  def __init__(self, count, suffix):
    self.count = count
    self.suffix = suffix

  def __repr__(self):
    """Render `count` in base 26 using letters a-z, followed by `suffix`.

    For example 0 -> 'a', 25 -> 'z', 26 -> 'ba'.
    """
    rem = self.count
    s = ''
    # At least one digit is always emitted, so a count of 0 prints as 'a'.
    while True:
      rem, i = divmod(rem, 26)
      s = chr(ord('a') + i) + s
      if not rem:
        break
    return s + self.suffix
|
|
|
|
|
|
|
|
def gensym(suffix):
  """Return a zero-argument factory of fresh `Var`s sharing `suffix`.

  Each call to the returned function yields a Var whose count is the next
  integer from a private counter starting at 0.
  """
  counter = it.count()
  return lambda: Var(next(counter), suffix)
|
|
|
|
|
2019-05-28 22:50:52 -07:00
|
|
|
class Literal(object):
  """Wraps a constant value that is used in place of a jaxpr variable.

  Attributes:
    val: the wrapped value.
    hash: a precomputed hash for `val`, or None when none could be computed.
      When it is None, hashing and equality fall back to the identity of `val`.
  """
  __slots__ = ["val", "hash"]

  def __init__(self, val):
    self.val = val
    # Start from "no hash available" so the slot is always populated, also for
    # unhashable values whose type is not registered in literalable_types.
    self.hash = None
    try:
      self.hash = hash(val)
    except TypeError:
      if type(val) in literalable_types:
        try:
          # Unhashable registered values are hashed through their scalar
          # content and dtype when they offer `.item()` and `.dtype`.
          self.hash = hash((val.item(), val.dtype))
        except (TypeError, AttributeError):
          pass

  def __hash__(self):
    return id(self.val) if self.hash is None else self.hash

  def __eq__(self, other):
    # Let Python handle comparison with unrelated objects (e.g. variables that
    # share a hash bucket) instead of failing on a missing `.val`.
    if not isinstance(other, Literal):
      return NotImplemented
    return self.val is other.val if self.hash is None else self.val == other.val

  def __repr__(self):
    if self.hash is None:
      # Reaching this branch means no hash could be computed for `val`.
      return 'Literal(val={}, hashable={})'.format(self.val, False)
    else:
      return '{}'.format(self.val)

# Types whose instances may be wrapped as Literals even though `hash()` on
# them raises TypeError; other modules register their types here.
literalable_types = set()
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class Primitive(object):
  """A named operation with pluggable evaluation rules.

  Attributes:
    name: the primitive's name; also its repr.
    multiple_results: class-level flag; when true, `bind` treats the result of
      processing as a sequence of outputs rather than a single output.
  """
  multiple_results = False  # override for multi-output primitives

  def __init__(self, name):
    self.name = name

  def __repr__(self):
    return '{}'.format(self.name)

  def bind(self, *args, **kwargs):
    """Apply the primitive to `args`, dispatching on the top trace.

    If no argument is a Tracer, `impl` is called directly. Otherwise all
    arguments are raised to the top trace, which processes the primitive, and
    the outputs are passed through full_lower.
    """
    assert skip_checks or all(isinstance(arg, Tracer)
                              or valid_jaxtype(arg) for arg in args), args
    top_trace = find_top_trace(args)
    if top_trace is None:
      return self.impl(*args, **kwargs)

    tracers = map(top_trace.full_raise, args)
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    if self.multiple_results:
      return map(full_lower, out_tracer)
    else:
      return full_lower(out_tracer)

  def def_impl(self, impl):
    """Install `impl` as this instance's evaluation rule and return it."""
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    """Install `abstract_eval` as this instance's abstract rule and return it."""
    self.abstract_eval = abstract_eval
    return abstract_eval

  def def_custom_bind(self, bind):
    """Replace this instance's `bind` with `bind` and return it."""
    self.bind = bind
    return bind

  def impl(self, *args, **kwargs):
    """Default evaluation rule: always raises NotImplementedError."""
    raise NotImplementedError("Evaluation rule for '{}' not implemented"
                              .format(self.name))

  def abstract_eval(self, *args, **kwargs):
    """Default abstract evaluation rule: always raises NotImplementedError."""
    raise NotImplementedError("Abstract evaluation for '{}' not implemented"
                              .format(self.name))
|
2019-02-21 11:47:26 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# -------------------- lifting --------------------
|
|
|
|
|
|
|
|
|
|
|
|
def eval_jaxpr(jaxpr, consts, freevar_vals, *args):
  """Evaluate `jaxpr` equation by equation and return its output values.

  Args:
    jaxpr: the Jaxpr to evaluate.
    consts: values bound to `jaxpr.constvars`.
    freevar_vals: values bound to `jaxpr.freevars`.
    *args: values bound to `jaxpr.invars`.

  Returns:
    The values of `jaxpr.outvars`, as produced by the module-level `map`.
  """
  def read(v):
    # Literals carry their value inline; everything else lives in `env`.
    if type(v) is Literal:
      return v.val
    else:
      return env[v]

  def write(v, val):
    env[v] = val

  env = {}
  write(unitvar, unit)
  # `map` is the module-level rebinding (see `map = safe_map`), used here for
  # its side effects on `env`.
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  map(write, jaxpr.freevars, freevar_vals)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.invars)
    # Each bound sub-jaxpr becomes a callable closed over the current values
    # of its const and freevar bindings.
    subfuns = [partial(eval_jaxpr, subjaxpr, map(read, const_bindings),
                       map(read, freevar_bindings))
               for subjaxpr, const_bindings, freevar_bindings
               in eqn.bound_subjaxprs]
    subfuns = map(lu.wrap_init, subfuns)
    ans = eqn.primitive.bind(*(subfuns + in_vals), **eqn.params)
    if eqn.primitive.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
  return map(read, jaxpr.outvars)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def full_lower(val):
  """Return `val.full_lower()` for a Tracer, and `val` itself otherwise."""
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val
|
|
|
|
|
|
|
|
|
|
|
|
def find_top_trace(xs):
  """Return a trace for the highest-level Tracer in `xs`, or None.

  The trace with the greatest `level` among the Tracers in `xs` is selected and
  a new trace of the same type is built from its master and the current
  sublevel. None is returned when `xs` contains no Tracer.
  """
  try:
    top_trace = max((x.trace for x in xs if isinstance(x, Tracer)),
                    key=attrgetter('level'))
  except ValueError:
    # max() of an empty sequence: there were no tracers among `xs`.
    return None
  else:
    return type(top_trace)(top_trace.master, cur_sublevel())
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------- tracing --------------------
|
|
|
|
|
|
|
|
|
|
|
|
class Trace(object):
  """Base class for traces; subclasses supply `pure`, `lift` and `sublift`.

  Attributes:
    master: the master trace this trace belongs to.
    level: copied from `master.level`.
    sublevel: the sublevel this trace was created at.
  """

  def __init__(self, master, sublevel):
    self.master = master
    self.level = master.level
    self.sublevel = sublevel

  def full_raise(self, val):
    """Bring `val` into this trace.

    Non-Tracer values go through `pure`. A Tracer of the same master is
    returned unchanged at the same sublevel or passed to `sublift` from a
    lower sublevel. A Tracer of a lower level goes through `lift`. Every other
    combination raises Exception.
    """
    if not isinstance(val, Tracer):
      return self.pure(val)
    level = self.level
    sublevel = self.sublevel
    if val.trace.master is self.master:
      if val.trace.sublevel == sublevel:
        return val
      elif val.trace.sublevel < sublevel:
        return self.sublift(val)
      else:
        raise Exception("Can't lift sublevels {} to {}"
                        .format(val.trace.sublevel, sublevel))
    elif val.trace.level < level:
      if val.trace.sublevel > sublevel:
        raise Exception("Incompatible sublevel: {}, {}"
                        .format(val.trace, (level, sublevel)))
      return self.lift(val)
    elif val.trace.level > level:
      raise Exception("Can't lift {} to {}".format(val, self))
    elif val.trace.level == self.level:
      raise Exception("Different traces at same level: {}, {}".format(val, self))
    else:
      raise Exception("Can't lift {} to {}".format(val, self))

  def pure(self, val):
    """Hook for wrapping a non-Tracer value; must be overridden."""
    assert False, "must override"

  def lift(self, tracer):
    """Hook for lifting a lower-level Tracer; must be overridden."""
    assert False, "must override"

  def sublift(self, tracer):
    """Hook for lifting a Tracer from a lower sublevel; must be overridden."""
    assert False, "must override"

  def __repr__(self):
    return '{}(level={}/{})'.format(
        self.__class__.__name__, self.level, self.sublevel)
|
|
|
|
|
|
|
|
|
|
|
|
class Tracer(object):
  """Base class for traced values.

  Operators and most special methods are forwarded to the abstract value
  returned by the `aval` property, which subclasses must provide: each
  `__op__` calls `self.aval._op(self, ...)`.

  Attributes:
    trace: the trace this tracer belongs to.
  """
  __array_priority__ = 1000
  __slots__ = ['trace', '__weakref__']

  def __array__(self, *args, **kw):
    raise Exception("Tracer can't be used with raw numpy functions. "
                    "You might have\n import numpy as np\ninstead of\n import jax.numpy as np")

  def __init__(self, trace):
    self.trace = trace

  def __iter__(self):
    return iter(self.aval._iter(self))

  def __len__(self):
    return self.aval._len(self)

  @property
  def aval(self):
    """The abstract value of this tracer; must be overridden by subclasses."""
    assert False

  def __neg__(self): return self.aval._neg(self)
  def __pos__(self): return self.aval._pos(self)
  def __eq__(self, other): return self.aval._eq(self, other)
  def __ne__(self, other): return self.aval._ne(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __le__(self, other): return self.aval._le(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __ge__(self, other): return self.aval._ge(self, other)
  def __abs__(self): return self.aval._abs(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __sub__(self, other): return self.aval._sub(self, other)
  def __rsub__(self, other): return self.aval._rsub(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __div__(self, other): return self.aval._div(self, other)
  def __rdiv__(self, other): return self.aval._rdiv(self, other)
  def __truediv__(self, other): return self.aval._truediv(self, other)
  def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
  def __floordiv__(self, other): return self.aval._floordiv(self, other)
  def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
  def __divmod__(self, other): return self.aval._divmod(self, other)
  def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
  def __mod__(self, other): return self.aval._mod(self, other)
  def __rmod__(self, other): return self.aval._rmod(self, other)
  def __pow__(self, other): return self.aval._pow(self, other)
  def __rpow__(self, other): return self.aval._rpow(self, other)
  def __matmul__(self, other): return self.aval._matmul(self, other)
  def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
  def __and__(self, other): return self.aval._and(self, other)
  def __rand__(self, other): return self.aval._rand(self, other)
  def __or__(self, other): return self.aval._or(self, other)
  def __ror__(self, other): return self.aval._ror(self, other)
  def __xor__(self, other): return self.aval._xor(self, other)
  def __rxor__(self, other): return self.aval._rxor(self, other)
  def __invert__(self): return self.aval._invert(self)
  def __lshift__(self, other): return self.aval._lshift(self, other)
  def __rshift__(self, other): return self.aval._rshift(self, other)
  def __getitem__(self, idx): return self.aval._getitem(self, idx)
  def __nonzero__(self): return self.aval._nonzero(self)
  def __bool__(self): return self.aval._bool(self)
  def __float__(self): return self.aval._float(self)
  def __int__(self): return self.aval._int(self)
  def __long__(self): return self.aval._long(self)
  def __complex__(self): return self.aval._complex(self)
  def __hex__(self): return self.aval._hex(self)
  def __oct__(self): return self.aval._oct(self)

  def __setitem__(self, idx, val):
    raise TypeError("JAX 'Tracer' objects do not support item assignment")

  def __getattr__(self, name):
    """Forward a failed attribute lookup to the abstract value.

    Attributes of the aval wrapped in `aval_property` are evaluated against
    this tracer, those wrapped in `aval_method` are bound to this tracer, and
    anything else is returned as found on the aval.
    """
    # if the aval property raises an AttributeError, gets caught here
    assert skip_checks or name != "aval"

    try:
      attr = getattr(self.aval, name)
    # NOTE(review): a missing attribute normally makes getattr raise
    # AttributeError, which is not caught here and propagates with the aval's
    # own message; only a KeyError is translated. Confirm this is intended.
    except KeyError:
      raise AttributeError(
          "{} has no attribute {}".format(self.__class__.__name__, name))
    else:
      t = type(attr)
      if t is aval_property:
        return attr.fget(self)
      elif t is aval_method:
        if six.PY3:
          return types.MethodType(attr.fun, self)
        else:
          return types.MethodType(attr.fun, self, None)
      else:
        return attr

  def __repr__(self):
    return 'Traced<{}>with<{}>'.format(self.aval, self.trace)

  def __copy__(self):
    # Copying a tracer yields the same tracer.
    return self

  def __deepcopy__(self, unused_memo):
    return self
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# These wrappers set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals: Tracer.__getattr__ evaluates an
# aval_property's `fget` on the tracer and binds an aval_method's `fun` to it.
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
|
|
|
|
|
|
|
|
|
|
|
|
class MasterTrace(object):
  """Identifies one entry on the trace stack by its level and trace type.

  Two master traces are equal, and hash alike, when both `level` and
  `trace_type` match.
  """

  def __init__(self, level, trace_type):
    self.level = level
    self.trace_type = trace_type

  def __repr__(self):
    return "MasterTrace({},{})".format(self.level, self.trace_type.__name__)

  def __hash__(self):
    return hash((self.level, self.trace_type))

  def __eq__(self, other):
    # Objects of other types simply compare unequal instead of raising
    # AttributeError on a missing `level` / `trace_type`.
    return (isinstance(other, MasterTrace) and
            self.level == other.level and self.trace_type == other.trace_type)
|
|
|
|
|
|
|
|
|
|
|
|
class TraceStack(object):
  """Two stacks of traces: `upward` (the default) and `downward` (bottom).

  Levels handed out for `upward` count up from 0; levels for `downward` count
  down from -1.
  """

  def __init__(self):
    self.upward = []
    self.downward = []

  def next_level(self, bottom):
    """Return the level the next pushed entry would get on the chosen stack."""
    if bottom:
      return - (len(self.downward) + 1)
    else:
      return len(self.upward)

  def push(self, val, bottom):
    """Append `val` to `downward` if `bottom` is true, else to `upward`."""
    if bottom:
      self.downward.append(val)
    else:
      self.upward.append(val)

  def pop(self, bottom):
    """Discard the most recent entry of the chosen stack."""
    if bottom:
      self.downward.pop()
    else:
      self.upward.pop()

  def __repr__(self):
    # One entry per line: `upward` newest first, a separator, then `downward`.
    # The lines are joined explicitly so the text itself is shown rather than
    # the repr of a list of lines.
    return 'Trace stack\n{} ---\n{}'.format(
        ''.join(' {}\n'.format(t) for t in self.upward[::-1]),
        ''.join(' {}\n'.format(t) for t in self.downward))
|
|
|
|
|
|
|
|
|
|
|
|
class Sublevel(int):
  """An int subclass marking a position in the tracer's sublevel stack."""
|
2019-07-23 09:53:27 -04:00
|
|
|
|
|
|
|
# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class TraceState(threading.local):
  """Per-thread tracer state.

  Attributes:
    trace_stack: a TraceStack of master traces.
    substack: a list of Sublevels, starting as [Sublevel(0)].
  """

  def __init__(self):
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]

trace_state = TraceState()
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def cur_sublevel():
  """Return the innermost (last) Sublevel on this thread's substack."""
  return trace_state.substack[-1]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def new_master(trace_type, bottom=False):
  """Context manager that pushes a new MasterTrace and yields it.

  Args:
    trace_type: stored on the new MasterTrace.
    bottom: selects which side of the trace stack is used (passed through to
      next_level/push/pop).

  The master is popped again on exit, also when the body raises. With
  `check_leaks` enabled, a normal exit additionally raises Exception if the
  master is still referenced somewhere.
  """
  level = trace_state.trace_stack.next_level(bottom)
  master = MasterTrace(level, trace_type)
  trace_state.trace_stack.push(master, bottom)

  try:
    yield master
  finally:
    trace_state.trace_stack.pop(bottom)

  if check_leaks:
    # Drop our own reference; if the weak reference is still alive, something
    # else kept the master.
    t = ref(master)
    del master
    if t() is not None:
      print(trace_state.trace_stack)
      raise Exception('Leaked trace {}'.format(t()))
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def new_sublevel():
  """Context manager that pushes a fresh Sublevel for the duration of the body.

  The new sublevel equals the current length of the substack and is popped on
  exit, also when the body raises. With `check_leaks` enabled, a normal exit
  additionally raises Exception if the sublevel is still referenced.
  """
  sublevel = Sublevel(len(trace_state.substack))
  trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    trace_state.substack.pop()

  if check_leaks:
    # NOTE(review): the CPython weakref docs state that int subclasses do not
    # support weak references, so ref(sublevel) may raise TypeError when
    # check_leaks is enabled. Confirm before enabling the flag.
    t = ref(sublevel)
    del sublevel
    if t() is not None:
      raise Exception('Leaked sublevel {}'.format(t()))
|
|
|
|
|
|
|
|
# -------------------- abstract values --------------------
|
|
|
|
|
|
|
|
|
|
|
|
class AbstractValue(object):
  """Base class for abstract values (avals)."""
  __slots__ = []

  def at_least_vspace(self):
    """Must be overridden by subclasses; the base implementation asserts."""
    assert False

  def __repr__(self):
    """Return 'ClassName(k=v,...)' from the instance dict, or the bare name.

    Instances without a `__dict__` (slot-only classes) fall back to just the
    class name.
    """
    try:
      kv_pairs = ('{}={}'.format(k, v) for k, v in self.__dict__.items())
      return '{}({})'.format(self.__class__.__name__, ','.join(kv_pairs))
    except AttributeError:
      return self.__class__.__name__

  def strip_weak_type(self):
    """Return this aval unchanged; the base class has nothing to strip."""
    return self
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
class Bot(AbstractValue):
  """An AbstractValue subclass that adds no behaviour of its own."""

# Module-level Bot instance.
bot = Bot()
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
class AbstractUnit(AbstractValue):
  """Abstract value registered for `Unit` in pytype_aval_mappings."""

  # Joining with anything yields this aval again.
  def join(self, other): return self

  # Equal exactly when the other operand's aval is this very object.
  def _eq(self, self_traced, other): return get_aval(other) is self

# Module-level AbstractUnit instance, returned by the Unit aval mapping.
abstract_unit = AbstractUnit()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def lattice_join(x, y):
  """Join two abstract values, treating None as "no value".

  Returns the other operand when one is None. Otherwise `join` is called on
  the operand whose type is the same as, or a base of, the other operand's
  type.

  Raises:
    TypeError: if neither operand is an instance of the other's type.
  """
  if x is None:
    return y
  elif y is None:
    return x
  elif isinstance(x, type(y)):
    return y.join(x)
  elif isinstance(y, type(x)):
    return x.join(y)
  else:
    raise TypeError((x, y))
|
|
|
|
|
|
|
|
|
|
|
|
def valid_jaxtype(x):
  """Return True if concrete_aval accepts `x`, False if it raises TypeError."""
  try:
    concrete_aval(x)
  except TypeError:
    return False
  else:
    return True
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def concrete_aval(x):
  """Return the abstract value for the concrete value `x`.

  The handler registered for exactly `type(x)` in pytype_aval_mappings is
  looked up and applied to `x`.

  Raises:
    TypeError: if no handler is registered for `type(x)`.
  """
  # Only the lookup is guarded, so a KeyError raised inside a handler is not
  # misreported as an unregistered type.
  try:
    handler = pytype_aval_mappings[type(x)]
  except KeyError:
    raise TypeError("{} is not a valid Jax type".format(type(x)))
  return handler(x)
|
|
|
|
|
|
|
|
|
|
|
|
def get_aval(x):
  """Return `x.aval` for a Tracer, and concrete_aval(x) for anything else."""
  if isinstance(x, Tracer):
    return x.aval
  else:
    return concrete_aval(x)
|
|
|
|
|
|
|
|
|
|
|
|
# Maps a Python type to a function producing the abstract value for an
# instance of that type; consulted by concrete_aval.
pytype_aval_mappings = {}


class Unit(object):
  """Type of the `unit` value; prints as '*'."""
  def __repr__(self): return '*'
unit = Unit()
literalable_types.add(Unit)

class UnitVar(object):
  """Type of `unitvar`, the variable eval_jaxpr binds to `unit`; prints as '*'."""
  def __repr__(self): return '*'
unitvar = UnitVar()

# Every Unit instance maps to the abstract_unit aval.
pytype_aval_mappings[Unit] = lambda _: abstract_unit
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# The identity primitive: both its evaluation rule and its custom bind return
# the argument unchanged.
identity_p = Primitive('id')
identity_p.def_impl(lambda x: x)
identity_p.def_custom_bind(lambda x: x)
|
|
|
|
|
|
|
|
# ------------------- Call -------------------
|
|
|
|
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def apply_todos(todos, outs):
  """Apply the callables in `todos` to `outs`, last one first.

  Each callable is popped off the end of `todos` (which is consumed), applied
  to the current outputs, and every result is passed through full_lower.
  Returns the final outputs; `outs` is returned as-is when `todos` is empty.
  """
  while todos:
    outs = map(full_lower, todos.pop()(outs))
  return outs
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-12-06 22:28:41 -08:00
|
|
|
@lu.transformation_with_aux
def process_env_traces(primitive, level, params_tuple, *args):
  """Transformation that post-processes outputs traced above `level`.

  After the wrapped function has run, and for as long as some output is a
  Tracer whose trace level exceeds `level`, all outputs are raised to the
  highest such trace and handed to that trace's `post_process_call`. The
  callables it returns are collected and yielded as the auxiliary output.

  Args:
    primitive: forwarded to `post_process_call`.
    level: outputs at or below this trace level are left alone.
    params_tuple: the call parameters as a tuple of (key, value) pairs.
    *args: passed to the wrapped function unchanged.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  todo = []
  while True:
    tracers = [x for x in outs if isinstance(x, Tracer) and x.trace.level > level]
    if tracers:
      ans = max(tracers, key=lambda x: x.trace.level)
    else:
      break
    trace = type(ans.trace)(ans.trace.master, cur_sublevel())
    outs = map(trace.full_raise, outs)
    outs, cur_todo = trace.post_process_call(primitive, outs, params)
    todo.append(cur_todo)
  yield outs, todo
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-05-03 12:37:14 -07:00
|
|
|
def call_bind(primitive, f, *args, **params):
  """Bind a call-like `primitive` that applies the wrapped function `f` to `args`.

  With no Tracer among `args`, `primitive.impl` runs inside a new sublevel.
  Otherwise the arguments are raised to the top trace, whose `process_call`
  handles the call. In both cases the todos collected by process_env_traces
  are then applied to the outputs.
  """
  top_trace = find_top_trace(args)
  level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
  # The params are passed as a tuple of items; process_env_traces rebuilds the
  # dict on its side.
  params_tuple = tuple(params.items())
  f, env_trace_todo = process_env_traces(f, primitive, level, params_tuple)
  if top_trace is None:
    with new_sublevel():
      outs = primitive.impl(f, *args, **params)
  else:
    tracers = map(top_trace.full_raise, args)
    outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
  return apply_todos(env_trace_todo(), outs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-05-03 12:37:14 -07:00
|
|
|
def call_impl(f, *args, **params):
  """Evaluation rule for `call`: return `f.call_wrapped(*args)`."""
  del params  # params parameterize the call primitive, not the function
  return f.call_wrapped(*args)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
# The `call` primitive: binding goes through call_bind and evaluation through
# call_impl.
call_p = Primitive('call')
call = partial(call_bind, call_p)
call_p.def_custom_bind(call)
call_p.def_impl(call_impl)
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------- Jaxpr printed representation -------------------
|
|
|
|
|
|
|
|
def check_jaxpr(jaxpr):
  """Check that `jaxpr` reads only defined variables and binds each one once.

  Sub-jaxprs bound by equations are checked recursively.

  Raises:
    Exception: if a non-Literal variable is read before it is bound, or a
      variable is bound a second time. The message includes the printed jaxpr.
  """
  def context():
    return "\njaxpr:\n{}\n".format(jaxpr)

  def read_env(env, v):
    if v not in env and type(v) is not Literal:
      raise Exception("Variable '{}' not defined".format(v) + context())

  def write_env(env, v):
    if v in env:
      raise Exception("Variable {} already bound".format(v) + context())
    env.add(v)

  env = set()
  read = partial(read_env, env)
  write = partial(write_env, env)

  write(unitvar)
  # Binding order: constvars, then freevars, then invars.
  for v in jaxpr.constvars:
    write(v)
  for v in jaxpr.freevars:
    write(v)
  for v in jaxpr.invars:
    write(v)
  for eqn in jaxpr.eqns:
    for v in eqn.invars:
      read(v)
    for subjaxpr, constvars, freevars in eqn.bound_subjaxprs:
      for v in freevars:
        read(v)
      for v in constvars:
        read(v)
      check_jaxpr(subjaxpr)
    for v in eqn.outvars:
      write(v)
  for v in jaxpr.outvars:
    read(v)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-10-03 17:56:25 -07:00
|
|
|
def pp_vars(vs):
  """Return the str() of each element of `vs`, separated by single spaces."""
  return ' '.join(map(str, vs))
|
|
|
|
|
2019-10-03 17:56:25 -07:00
|
|
|
def pp_eqn(eqn):
  """Build the pretty-printer document for one equation.

  The first part reads `outvars = primitive<params> invars`, with the params
  sorted by key. Each bound sub-jaxpr is added after it, indented and followed
  by its const and bound variables in brackets.
  """
  lhs = pp_vars(eqn.outvars)
  pp_subexpr = pp('')
  if eqn.bound_subjaxprs:
    for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs:
      pp_subexpr = pp_subexpr + (
          pp_jaxpr(subjaxpr).indent(2)
          >> pp(' [ {} ; {} ]'.format(pp_vars(const_vars),
                                      pp_vars(bound_vars))))
  return (pp('{} = '.format(lhs)) >>
          pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
          >> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr
|
|
|
|
|
|
|
|
def pp_jaxpr(jaxpr):
  """Build the pretty-printer document for a whole jaxpr.

  The header lists constvars, freevars and invars; the body holds one pp_eqn
  document per equation after `let`, then `in` followed by the formatted
  `outvars` list.
  """
  return (pp('{{ lambda {} ; {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                              pp_vars(jaxpr.freevars),
                                              pp_vars(jaxpr.invars))) +
          ((pp('let ') >>
            vcat(map(pp_eqn, jaxpr.eqns))) +
           pp('in {} }}'.format(jaxpr.outvars))).indent(2))
|