2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
import operator
|
2018-11-17 18:03:33 -08:00
|
|
|
from operator import attrgetter
|
|
|
|
from contextlib import contextmanager
|
2020-03-09 20:42:08 +01:00
|
|
|
from collections import namedtuple
|
2020-06-01 13:24:40 -07:00
|
|
|
from functools import total_ordering
|
2019-10-08 10:57:36 -07:00
|
|
|
import itertools as it
|
2018-11-17 18:03:33 -08:00
|
|
|
from weakref import ref
|
2019-07-23 09:53:27 -04:00
|
|
|
import threading
|
2018-11-17 18:03:33 -08:00
|
|
|
import types
|
2020-06-02 10:26:43 -04:00
|
|
|
from typing import (Any, Callable, ClassVar, Dict, Generator,
|
2020-06-01 21:45:36 -04:00
|
|
|
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
|
|
|
|
Type, Union, cast)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
import numpy as onp
|
|
|
|
|
|
|
|
from . import dtypes
|
2020-05-01 09:16:31 +03:00
|
|
|
from .config import FLAGS
|
2018-11-17 18:03:33 -08:00
|
|
|
from . import linear_util as lu
|
2020-03-21 13:54:30 +01:00
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
|
2020-05-26 19:32:29 -07:00
|
|
|
from .pprint_util import pp, vcat, hcat, PrettyPrint
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# TODO(dougalm): the trace cache breaks the leak detector. Consider solving.
check_leaks = False

"""Disables internal invariant checks."""
# When the jax_enable_checks flag is off, the expensive internal assertions
# guarded by `skip_checks` (e.g. in TypedJaxpr.__init__ and Primitive.bind)
# are skipped. Temporarily togglable via the `skipping_checks` context manager.
skip_checks = not FLAGS.jax_enable_checks # not __debug__ # google doesn't use -O
|
|
|
|
|
|
|
|
@contextmanager
def skipping_checks():
  """Context manager that disables internal checks for its dynamic extent.

  Forces the module-level `skip_checks` flag to True on entry and restores
  the previous value on exit, whether the body finished normally or raised.
  """
  global skip_checks
  prev_value = skip_checks
  skip_checks = True
  try:
    yield
  finally:
    skip_checks = prev_value
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Shadow the builtins module-wide with the variants from .util.
# NOTE(review): presumably safe_zip/safe_map enforce equal argument lengths —
# confirm against util.py.
zip = safe_zip
map = safe_map
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------- jaxprs --------------------
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class Jaxpr:
  """Program representation: binders, a list of equations, and outputs.

  Evaluation (see `eval_jaxpr`) binds `constvars` and `invars`, executes
  each equation in `eqns` in order, and returns the values of `outvars`.
  """
  # Variables standing in for constants hoisted out of the program.
  constvars: List['Var']
  # Input variables.
  invars: List['Var']
  # Outputs: variables or inline Literals.
  outvars: List['Atom']
  # The equations (primitive applications) forming the program body.
  eqns: List['JaxprEqn']

  def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
               outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
    """
    Params:
      constvars: list of variables introduced for constants (either literals
        in the Python program, or the result of constant folding during the
        generation of the Jaxpr). Array constants are replaced with such variables
        while scalar constants are kept inline.
      invars: list of input variables. Together, `constvars` and `invars` are
        the inputs to the Jaxpr.
      outvars: list of output variables.
      eqns: list of equations."""
    # Materialize all sequences as fresh lists so later mutation of the
    # caller's sequences cannot alias into this Jaxpr.
    self.constvars = list(constvars)
    self.invars = list(invars)
    self.outvars = list(outvars)
    self.eqns = list(eqns)

  def __str__(self):
    # Delegates to the jaxpr pretty-printer defined elsewhere in this module.
    return str(pp_jaxpr(self))
  __repr__ = __str__
|
2019-02-06 11:49:21 -05:00
|
|
|
|
2020-02-05 15:38:25 +01:00
|
|
|
|
2020-05-26 19:32:29 -07:00
|
|
|
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yield every Jaxpr found in the values of `params`.

  Values may be bare Jaxprs, TypedJaxprs (whose inner jaxpr is yielded), or
  tuples containing either; anything else is ignored.
  """
  for val in params.values():
    if isinstance(val, tuple):
      candidates = val
    else:
      candidates = (val,)
    for candidate in candidates:
      if isinstance(candidate, Jaxpr):
        yield candidate
      elif isinstance(candidate, TypedJaxpr):
        yield candidate.jaxpr
|
|
|
|
|
|
|
|
|
2020-03-21 13:54:30 +01:00
|
|
|
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Generator for all subjaxprs found in the params of jaxpr.eqns.

  Does not descend recursively into the found subjaxprs.
  """
  for eqn in jaxpr.eqns:
    for sub in jaxprs_in_params(eqn.params):
      yield sub
|
2020-02-05 15:38:25 +01:00
|
|
|
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class TypedJaxpr:
  """A Jaxpr together with its constant values and input/output avals.

  Unlike a bare Jaxpr, a TypedJaxpr carries the concrete constants
  (`literals`) bound to `jaxpr.constvars` plus the abstract values of the
  inputs and outputs, so it can be evaluated directly (see `jaxpr_as_fun`).
  """
  jaxpr: Jaxpr
  literals: List['Any']             # concrete constants, one per jaxpr.constvars
  in_avals: List['AbstractValue']   # abstract values of jaxpr.invars
  out_avals: List['AbstractValue']  # abstract values of jaxpr.outvars

  def __init__(self, jaxpr: Jaxpr, literals: Sequence,
               in_avals: Sequence['AbstractValue'],
               out_avals: Sequence['AbstractValue']):
    # Lengths must line up with the jaxpr's binders.
    assert len(literals) == len(jaxpr.constvars)
    assert len(in_avals) == len(jaxpr.invars)

    # Expensive aval-agreement checks, run only when jax_enable_checks is on.
    if not skip_checks:
      in_avals_raised = [raise_to_shaped(v) for v in in_avals]
      out_avals_raised = [raise_to_shaped(v) for v in out_avals]
      exp_in_avals = [v.aval for v in jaxpr.invars]
      exp_out_avals = [v.aval for v in jaxpr.outvars]
      assert in_avals_raised == exp_in_avals, "expected: {}, got: {}".format(exp_in_avals, in_avals_raised)
      assert out_avals_raised == exp_out_avals, "expected: {}, got: {}".format(exp_out_avals, out_avals_raised)

    self.jaxpr = jaxpr
    self.literals = list(literals)
    self.in_avals = list(in_avals)
    self.out_avals = list(out_avals)

  def __iter__(self):
    # Supports tuple-style unpacking:
    #   jaxpr, literals, in_avals, out_avals = typed_jaxpr
    return iter((self.jaxpr, self.literals, self.in_avals, self.out_avals))

  def __str__(self):
    # TODO(mattjj): improve this with type annotations?
    return str(pp_jaxpr(self.jaxpr))
  __repr__ = __str__
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-04-23 09:15:16 -07:00
|
|
|
@curry
def jaxpr_as_fun(typed_jaxpr: TypedJaxpr, *args):
  """Evaluate `typed_jaxpr` as a function of `args`.

  The jaxpr's constants are supplied from `typed_jaxpr.literals`. Decorated
  with `curry` (from .util), so `jaxpr_as_fun(tj)` presumably yields a plain
  callable of the jaxpr's inputs — see util.curry for the exact semantics.
  """
  return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)
|
2019-04-23 09:15:16 -07:00
|
|
|
|
|
|
|
|
2020-02-05 15:38:25 +01:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
class JaxprEqn(NamedTuple):
  """One equation in a Jaxpr: `outvars = primitive(*invars, **params)`."""
  invars: List['Atom']      # inputs: variables or inline Literals
  outvars: List['Var']      # variables bound by this equation
  primitive: 'Primitive'    # the operation being applied
  params: Dict[str, Any]    # static parameters (may embed subjaxprs)

  def __repr__(self): return str(pp_eqn(self)).rstrip()


# Equations are constructed through this alias elsewhere in the codebase.
new_jaxpr_eqn = JaxprEqn
|
2019-07-26 18:01:38 -04:00
|
|
|
|
2019-10-03 17:56:25 -07:00
|
|
|
|
2020-01-06 13:29:21 +00:00
|
|
|
@total_ordering
class Var:
  """A jaxpr variable, ordered by (count, suffix) and printed in base 26."""
  # TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison is
  # by object id, but pretty printing might collide.
  count: int             # index of this variable (see gensym's counter)
  suffix: str            # printed after the base-26 name
  aval: 'AbstractValue'  # the variable's abstract value, raised to shaped

  def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
    self.count = count
    self.suffix = suffix
    self.aval = raise_to_shaped(aval)

  def __lt__(self, other):
    # total_ordering derives the remaining comparisons from __lt__ and __eq__.
    if not isinstance(other, Var):
      return NotImplemented
    else:
      return (self.count, self.suffix) < (other.count, other.suffix)

  def __repr__(self):
    # Spell `count` with digits a..z: 0 -> 'a', 25 -> 'z', 26 -> 'ba'.
    rem = self.count
    s = ''
    while True:
      rem, i = rem // 26, rem % 26
      s = chr(97 + i % 26) + s  # NOTE(review): `i % 26` is redundant; i < 26 here
      if not rem:
        break
    return s + self.suffix
|
|
|
|
|
2020-05-26 11:21:49 -07:00
|
|
|
def _jaxpr_vars(jaxpr):
|
|
|
|
return it.chain(
|
|
|
|
jaxpr.invars, jaxpr.constvars,
|
|
|
|
(v for eqn in jaxpr.eqns for v in eqn.outvars))
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
           suffix: str = '') -> Callable[['AbstractValue'], Var]:
  """Produce distinct variables, printed with the optional suffix.

  If `jaxprs` is provided, the variables produced will be distinct from those in
  any of the given jaxprs.
  """
  if jaxprs is None:
    first = 0
  else:
    used = it.chain.from_iterable(_jaxpr_vars(j) for j in jaxprs)
    first = 1 + max((v.count for v in used), default=-1)
  counter = it.count(start=first)

  def fresh(aval):
    return Var(next(counter), suffix, aval)
  return fresh
|
2019-10-08 10:57:36 -07:00
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class Literal:
  """A constant that appears inline in a jaxpr (rather than via a constvar).

  The hash of the wrapped value is computed eagerly and cached in `hash`;
  `hash is None` marks an unhashable literal, which also changes how the
  literal prints (see __repr__). __hash__/__eq__ deliberately assert False:
  Literal objects themselves are not meant to be hashed or compared.
  """
  __slots__ = ["val", "hash"]

  val: Any
  hash: Optional[int]

  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash(val)
    except TypeError:
      if type(val) in literalable_types:
        try:
          # Assumes val is an array-like scalar exposing .item()/.dtype —
          # e.g. a numpy scalar; falls through to unhashable otherwise.
          self.hash = hash((val.item(), val.dtype))
        except (TypeError, AttributeError):
          self.hash = None
      else:
        # BUG FIX: previously `self.hash` was left unassigned on this path,
        # so any later access (e.g. in __repr__) raised AttributeError
        # because of __slots__. Mark the literal as unhashable instead.
        self.hash = None

  @property
  def aval(self):
    # Computed lazily from the value; not cached.
    return raise_to_shaped(get_aval(self.val))

  def __hash__(self):
    assert False

  def __eq__(self, other):
    assert False

  def __repr__(self):
    if self.hash is None:
      return 'Literal(val={})'.format(self.val)
    else:
      return '{}'.format(self.val)
|
2019-05-29 08:12:05 -07:00
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
# Types whose instances may be represented as inline Literals (consulted by
# Literal.__init__ when the value is not directly hashable). Empty here;
# populated by registration code not visible in this chunk.
literalable_types: Set[type] = set()


# An Atom is anything usable as an equation input: a variable or an inline
# literal constant.
Atom = Union[Var, Literal]
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class Primitive:
  """A named operation; applications are dispatched to traces via `bind`."""
  name: str
  multiple_results = False # set for multi-output primitives
  call_primitive = False # set for call primitives processed in final style
  map_primitive = False # set for map primitives processed in final style

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return '{}'.format(self.name)

  def bind(self, *args, **kwargs):
    """Apply this primitive to `args`, dispatching to the innermost trace.

    If no argument is being traced (`find_top_trace` returns None), the
    registered `impl` rule runs directly on the concrete values. Otherwise
    every argument is lifted into the top trace and that trace's
    `process_primitive` handles the application.
    """
    # Cheap argument sanity check, disabled when skip_checks is set.
    assert skip_checks or all(isinstance(arg, Tracer)
                              or valid_jaxtype(arg) for arg in args), args
    top_trace = find_top_trace(args)
    if top_trace is None:
      return self.impl(*args, **kwargs)

    tracers = map(top_trace.full_raise, args)
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    # Multi-output primitives produce a sequence of tracers; lower each one.
    if self.multiple_results:
      return map(full_lower, out_tracer)
    else:
      return full_lower(out_tracer)

  def def_impl(self, impl):
    """Register the concrete evaluation rule; returns `impl` unchanged."""
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    """Register the abstract evaluation rule; returns it unchanged."""
    self.abstract_eval = abstract_eval
    return abstract_eval

  def def_custom_bind(self, bind):
    """Replace `bind` wholesale, for primitives needing custom dispatch."""
    self.bind = bind
    return bind

  def impl(self, *args, **kwargs):
    # Default: no concrete evaluation rule has been registered.
    raise NotImplementedError("Evaluation rule for '{}' not implemented"
                              .format(self.name))

  def abstract_eval(self, *args, **kwargs):
    # Default: no abstract evaluation rule has been registered.
    raise NotImplementedError("Abstract evaluation for '{}' not implemented"
                              .format(self.name))
|
2019-02-21 11:47:26 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# -------------------- lifting --------------------
|
|
|
|
|
2020-02-05 15:38:25 +01:00
|
|
|
# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
|
|
|
|
# core.py. Plan to move all these utilities to jaxpr.py.
|
2020-06-01 21:45:36 -04:00
|
|
|
def extract_call_jaxpr(
    primitive: Primitive,
    params: Dict[str, Any]) -> Tuple[Optional[Jaxpr], Dict[str, Any]]:
  """Extract the call primitive subjaxpr from the params.

  Returns the subjaxpr and the params without the "call_jaxpr" value. If this is
  not a call primitive then returns (None, params).
  """
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
    remaining = {k: v for k, v in params.items() if k != "call_jaxpr"}
    return (params["call_jaxpr"], remaining)
  else:
    return (None, params)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
  """Evaluate `jaxpr` on concrete `args`, with `consts` bound to its constvars.

  Executes the equations in order by calling each primitive's `bind`, and
  returns the list of output values, one per entry of `jaxpr.outvars`.
  """
  def read(v):
    # Literals carry their value inline; variables are looked up in env.
    if type(v) is Literal:
      return v.val
    else:
      return env[v]

  def write(v, val):
    env[v] = val

  env: Dict[Var, Any] = {}
  write(unitvar, unit)  # the unit value is always in scope
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.invars)
    # Call/map primitives carry a subjaxpr in their params; it is pulled out
    # and passed to bind as a wrapped callable, recursively evaluated.
    call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
    if call_jaxpr:
      subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
    else:
      subfuns = []
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
    if eqn.primitive.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
  return map(read, jaxpr.outvars)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
# -------------------- tracing --------------------
|
|
|
|
|
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
class Trace:
  """Base class for trace interpreters.

  A Trace is identified by its `master`, the master's `level` in the trace
  stack, and a `sublevel`. Subclasses override `pure`, `lift`, `sublift`,
  and `process_primitive` to implement a transformation.
  """
  master: 'MasterTrace'
  level: int
  sublevel: 'Sublevel'

  def __init__(self, master: 'MasterTrace', sublevel: 'Sublevel') -> None:
    self.master = master
    self.level = master.level  # level is always taken from the master
    self.sublevel = sublevel

  def full_raise(self, val) -> 'Tracer':
    """Lift `val` into a tracer of this trace, or raise if impossible.

    Non-tracers are boxed with `pure`. Tracers of the same master are
    returned as-is or sublifted; tracers from strictly lower levels are
    lifted. A tracer at a higher level or an incompatible sublevel can only
    mean it escaped its trace, which is reported as an escaped-tracer error.
    """
    if not isinstance(val, Tracer):
      return self.pure(val)
    level = self.level
    sublevel = self.sublevel
    if val._trace.master is self.master:
      if val._trace.sublevel == sublevel:
        return val
      elif val._trace.sublevel < sublevel:
        return self.sublift(val)
      else:
        raise escaped_tracer_error("Can't lift sublevels {} to {}"
                                   .format(val._trace.sublevel, sublevel))
    elif val._trace.level < level:
      if val._trace.sublevel > sublevel:
        raise escaped_tracer_error("Incompatible sublevel: {}, {}"
                                   .format(val._trace, (level, sublevel)))
      return self.lift(val)
    elif val._trace.level > level:
      raise escaped_tracer_error("Can't lift level {} to {}"
                                 .format(val, self))
    else:  # val._trace.level == self.level:
      raise escaped_tracer_error("Different traces at same level: {}, {}"
                                 .format(val, self))

  def pure(self, val):
    """Box a non-tracer value as a tracer of this trace (must override)."""
    raise NotImplementedError("must override")

  def lift(self, tracer):
    """Lift a tracer from a lower level into this trace (must override)."""
    raise NotImplementedError("must override")

  def sublift(self, tracer):
    """Lift a tracer from a lower sublevel of this master (must override)."""
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params):
    """Interpret one primitive application under this trace (must override)."""
    raise NotImplementedError("must override")

  def __repr__(self):
    return '{}(level={}/{})'.format(
        self.__class__.__name__, self.level, self.sublevel)

  def process_call(self, call_primitive, f, tracers, params):
    raise NotImplementedError("must override to handle call-like primitives")

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    # As a default implementation, drop the custom differentiation rule. This
    # behavior is desirable when staging out of the JAX system, but not when
    # there are further differentiation transformations to be applied. Override
    # this method to allow differentiation to be performed downstream.
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    # See comment in the above process_custom_jvp_call method.
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
def escaped_tracer_error(detail):
  """Construct (but do not raise) an UnexpectedTracerError.

  `detail` is interpolated into a fixed explanatory message about tracers
  leaking through global state.
  """
  template = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
              "through global state from a previously traced function.\n"
              "The functions being transformed should not save traced values to "
              "global state.\nDetails: {}.")
  return UnexpectedTracerError(template.format(detail))


class UnexpectedTracerError(Exception):
  """Raised when a tracer escapes the transformation that created it."""
  pass
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class Tracer:
|
2018-11-17 18:03:33 -08:00
|
|
|
__array_priority__ = 1000
|
2020-01-29 16:23:27 -05:00
|
|
|
__slots__ = ['_trace', '__weakref__']
|
2018-11-17 18:03:33 -08:00
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
def __array__(self, *args, **kw):
|
2020-05-20 19:09:44 -07:00
|
|
|
msg = ("The numpy.ndarray conversion method __array__() was called on "
|
|
|
|
f"the JAX Tracer object {self}.\n\n"
|
|
|
|
"This error can occur when a JAX Tracer object is passed to a raw "
|
|
|
|
"numpy function, or a method on a numpy.ndarray object. You might "
|
|
|
|
"want to check that you are using `jnp` together with "
|
|
|
|
"`import jax.numpy as jnp` rather than using `np` via "
|
|
|
|
"`import numpy as np`. If this error arises on a line that involves "
|
|
|
|
"array indexing, like `x[idx]`, it may be that the array being "
|
|
|
|
"indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
|
|
|
|
"JAX Tracer instance; in that case, you can instead write "
|
|
|
|
"`jax.device_put(x)[idx]`.")
|
|
|
|
raise Exception(msg)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def __init__(self, trace: Trace):
|
2020-01-29 16:23:27 -05:00
|
|
|
self._trace = trace
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
return iter(self.aval._iter(self))
|
|
|
|
|
|
|
|
def __len__(self):
|
2018-12-15 20:00:10 -08:00
|
|
|
return self.aval._len(self)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@property
|
|
|
|
def aval(self):
|
2020-01-15 15:00:38 -08:00
|
|
|
raise NotImplementedError("must override")
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def __neg__(self): return self.aval._neg(self)
|
2019-11-18 22:00:32 -05:00
|
|
|
def __pos__(self): return self.aval._pos(self)
|
2018-11-17 18:03:33 -08:00
|
|
|
def __eq__(self, other): return self.aval._eq(self, other)
|
|
|
|
def __ne__(self, other): return self.aval._ne(self, other)
|
|
|
|
def __lt__(self, other): return self.aval._lt(self, other)
|
|
|
|
def __le__(self, other): return self.aval._le(self, other)
|
|
|
|
def __gt__(self, other): return self.aval._gt(self, other)
|
|
|
|
def __ge__(self, other): return self.aval._ge(self, other)
|
|
|
|
def __abs__(self): return self.aval._abs(self)
|
|
|
|
def __add__(self, other): return self.aval._add(self, other)
|
|
|
|
def __radd__(self, other): return self.aval._radd(self, other)
|
|
|
|
def __sub__(self, other): return self.aval._sub(self, other)
|
|
|
|
def __rsub__(self, other): return self.aval._rsub(self, other)
|
|
|
|
def __mul__(self, other): return self.aval._mul(self, other)
|
|
|
|
def __rmul__(self, other): return self.aval._rmul(self, other)
|
|
|
|
def __div__(self, other): return self.aval._div(self, other)
|
|
|
|
def __rdiv__(self, other): return self.aval._rdiv(self, other)
|
|
|
|
def __truediv__(self, other): return self.aval._truediv(self, other)
|
2018-11-21 14:31:25 -08:00
|
|
|
def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
|
2018-11-17 18:03:33 -08:00
|
|
|
  # Forward Python's numeric, bitwise, and container protocols to the
  # underlying aval. The aval classes install `_floordiv`, `_mod`, etc.
  # (via aval_property/aval_method below), so tracers behave like arrays.
  def __floordiv__(self, other): return self.aval._floordiv(self, other)
  def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
  def __divmod__(self, other): return self.aval._divmod(self, other)
  def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
  def __mod__(self, other): return self.aval._mod(self, other)
  def __rmod__(self, other): return self.aval._rmod(self, other)
  def __pow__(self, other): return self.aval._pow(self, other)
  def __rpow__(self, other): return self.aval._rpow(self, other)
  def __matmul__(self, other): return self.aval._matmul(self, other)
  def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
  def __and__(self, other): return self.aval._and(self, other)
  def __rand__(self, other): return self.aval._rand(self, other)
  def __or__(self, other): return self.aval._or(self, other)
  def __ror__(self, other): return self.aval._ror(self, other)
  def __xor__(self, other): return self.aval._xor(self, other)
  def __rxor__(self, other): return self.aval._rxor(self, other)
  def __invert__(self): return self.aval._invert(self)
  def __lshift__(self, other): return self.aval._lshift(self, other)
  def __rshift__(self, other): return self.aval._rshift(self, other)
  def __getitem__(self, idx): return self.aval._getitem(self, idx)
  def __nonzero__(self): return self.aval._nonzero(self)
  def __bool__(self): return self.aval._bool(self)
  def __int__(self): return self.aval._int(self)
  def __long__(self): return self.aval._long(self)
  def __hex__(self): return self.aval._hex(self)
  def __oct__(self): return self.aval._oct(self)

  # Conversions to Python scalars are disallowed on tracers: there is no
  # concrete value during tracing, so raise with a remediation hint.
  def __float__(self):
    raise TypeError("JAX Tracer object cannot be interpreted as a float. "
                    "Try using `x.astype(float)` instead.")

  def __complex__(self):
    raise TypeError("JAX Tracer object cannot be interpreted as a complex. "
                    "Try using `x.astype(complex)` instead.")

  # Tracers are immutable; in-place indexed assignment is rejected.
  def __setitem__(self, idx, val):
    raise TypeError("JAX 'Tracer' objects do not support item assignment")
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
  def __getattr__(self, name):
    """Forward unknown attribute lookups to the aval.

    Attributes wrapped by the aval class in `aval_property` are evaluated
    against this tracer; those wrapped in `aval_method` are bound to it as
    instance methods; anything else is returned as-is.
    """
    # if the aval property raises an AttributeError, gets caught here
    assert skip_checks or name != "aval"

    try:
      attr = getattr(self.aval, name)
    # NOTE(review): KeyError (not AttributeError) is caught here — presumably
    # aval attribute lookup can go through a dict-backed registry; confirm.
    except KeyError as err:
      raise AttributeError(
          "{} has no attribute {}".format(self.__class__.__name__, name)
      ) from err
    else:
      t = type(attr)
      if t is aval_property:
        # forwarded property: evaluate its getter against this tracer
        return attr.fget(self)
      elif t is aval_method:
        # forwarded method: bind it to this tracer as `self`
        return types.MethodType(attr.fun, self)
      else:
        return attr
|
|
|
|
|
|
|
|
  def __repr__(self):
    """Pretty-print the tracer as Traced<aval>with<trace> plus slot contents."""
    base = pp('Traced<{}>with<{}>'.format(self.aval, self._trace))
    contents = self._contents()
    if contents:
      # append each (name, payload) pair from the subclass's __slots__
      base += pp(' with ') >> vcat(pp('{} = '.format(name)) >> pp_payload
                                   for name, pp_payload in contents)
    return str(base)
|
|
|
|
|
|
|
|
  def _contents(self):
    """Return (name, pretty-printed value) pairs for this tracer's __slots__.

    Returns an empty tuple when the subclass defines no __slots__ (or a slot
    is unset, in which case getattr raises AttributeError).
    """
    try:
      return [(name, pp(repr(getattr(self, name)))) for name in self.__slots__]
    except AttributeError:
      return ()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-12-11 02:48:51 +00:00
|
|
|
  # Tracers are treated as immutable handles: both shallow and deep copies
  # return the tracer itself rather than duplicating it.
  def __copy__(self):
    return self

  def __deepcopy__(self, unused_memo):
    return self
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals: Tracer.__getattr__ evaluates an
# aval_property's `fget` against the tracer, and binds an aval_method's `fun`
# to the tracer as an instance method.
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
|
|
|
|
|
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
class MasterTrace:
  """Identity of one trace: a stack level paired with a Trace subclass.

  Two MasterTrace objects are interchangeable iff they have the same level
  and trace type; equality and hashing reflect that.
  """
  level: int
  trace_type: Type[Trace]

  def __init__(self, level, trace_type) -> None:
    self.level = level
    self.trace_type = trace_type

  def __repr__(self) -> str:
    return "MasterTrace({},{})".format(self.level, self.trace_type.__name__)

  def __hash__(self) -> int:
    return hash((self.level, self.trace_type))

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, MasterTrace):
      return False
    return (self.level, self.trace_type) == (other.level, other.trace_type)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
class TraceStack:
  """Two stacks of MasterTraces: `upward` for ordinary (non-negative-level)
  traces and `downward` for bottom traces (negative levels).
  """
  upward: List[MasterTrace]
  downward: List[MasterTrace]

  def __init__(self):
    self.upward = []
    self.downward = []

  def next_level(self, bottom: bool) -> int:
    """Level the next pushed trace would get: negative for bottom traces."""
    if bottom:
      return - (len(self.downward) + 1)
    else:
      return len(self.upward)

  def push(self, master_trace: MasterTrace, bottom: bool) -> None:
    """Push onto the downward stack if `bottom`, else the upward stack."""
    if bottom:
      self.downward.append(master_trace)
    else:
      self.upward.append(master_trace)

  def pop(self, bottom: bool) -> None:
    """Pop from the downward stack if `bottom`, else the upward stack."""
    if bottom:
      self.downward.pop()
    else:
      self.upward.pop()

  def __repr__(self) -> str:
    # BUG FIX: previously the two `map` objects were interpolated directly,
    # so repr printed "<map object at 0x...>" instead of the stack entries.
    # Join the formatted entries into strings before formatting.
    return 'Trace stack\n{} ---\n{}'.format(
        ''.join(map(' {}\n'.format, self.upward[::-1])),
        ''.join(map(' {}\n'.format, self.downward)))

  def copy(self):
    """Shallow copy: new TraceStack sharing the MasterTrace entries."""
    new = TraceStack()
    new.upward = self.upward[:]
    new.downward = self.downward[:]
    return new
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# A Sublevel distinguishes dynamic nesting depth within a single trace level;
# it is just an int with a distinct type.
class Sublevel(int): pass
|
2019-07-23 09:53:27 -04:00
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
|
2019-07-23 09:53:27 -04:00
|
|
|
# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class TraceState(threading.local):
  # trace_stack: the per-thread stack of active MasterTraces
  # substack: stack of Sublevels, innermost last
  # initial_style: whether initial-style staging is active (see
  # initial_style_staging below)
  trace_stack: TraceStack
  substack: List[Sublevel]
  initial_style: bool

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.initial_style = False

  def copy(self):
    # Copy the stacks so the copy can be mutated independently of this state.
    new = TraceState()
    new.trace_stack = self.trace_stack.copy()
    new.substack = self.substack[:]
    new.initial_style = self.initial_style
    return new

# Module-level, thread-local tracing state.
trace_state = TraceState()
|
|
|
|
|
2020-04-02 18:03:58 -07:00
|
|
|
def reset_trace_state() -> bool:
  "Reset the global trace state and return True if it was already clean."
  was_clean = (trace_state.substack == [Sublevel(0)] and
               not trace_state.trace_stack.downward and
               not trace_state.trace_stack.upward)
  if not was_clean:
    trace_state.__init__()  # type: ignore
  return was_clean
|
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
def cur_sublevel() -> Sublevel:
  """Return the current (innermost) sublevel of this thread's trace state."""
  return trace_state.substack[-1]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@contextmanager
def new_master(trace_type: Type[Trace], bottom=False) -> Generator[MasterTrace, None, None]:
  """Push a new MasterTrace of `trace_type` onto the thread's trace stack.

  Yields the MasterTrace and pops it on exit. With `bottom=True` the trace
  goes on the downward (negative-level) stack. When `check_leaks` is set,
  verifies via a weakref that no references to the master survive the
  context, printing the stack and raising on a leak.
  """
  level = trace_state.trace_stack.next_level(bottom)
  master = MasterTrace(level, trace_type)
  trace_state.trace_stack.push(master, bottom)

  try:
    yield master
  finally:
    trace_state.trace_stack.pop(bottom)

  if check_leaks:
    # drop our own reference so only leaked references keep `master` alive
    t = ref(master)
    del master
    if t() is not None:
      print(trace_state.trace_stack)
      raise Exception('Leaked trace {}'.format(t()))
|
|
|
|
|
|
|
|
@contextmanager
def new_sublevel() -> Generator[None, None, None]:
  """Push a fresh Sublevel for the dynamic extent of the context.

  When `check_leaks` is set, verifies via a weakref that no references to
  the sublevel survive the context.
  """
  sublevel = Sublevel(len(trace_state.substack))
  trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    trace_state.substack.pop()

  if check_leaks:
    t = ref(sublevel)
    del sublevel
    if t() is not None:
      raise Exception('Leaked sublevel {}'.format(t()))
|
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
def full_lower(val):
  """Lower `val` as far as possible: delegate to Tracer.full_lower for
  tracers, and return non-tracer values unchanged."""
  return val.full_lower() if isinstance(val, Tracer) else val
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def find_top_trace(xs) -> Optional[Trace]:
  """Return the highest-level trace among any Tracer elements of `xs`,
  re-instantiated at the current sublevel, or None if no element is a Tracer.
  """
  top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
                  key=attrgetter('level'), default=None)
  # rebuild the trace so it carries the *current* sublevel
  return top_trace and type(top_trace)(top_trace.master, cur_sublevel())
|
2020-03-28 14:55:58 -07:00
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
@contextmanager
def initial_style_staging():
  """Set the thread-local `initial_style` flag for the context's extent,
  restoring the previous value on exit."""
  prev, trace_state.initial_style = trace_state.initial_style, True
  try:
    yield
  finally:
    trace_state.initial_style = prev
|
2020-03-28 14:15:46 -07:00
|
|
|
|
2020-03-28 14:55:58 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# -------------------- abstract values --------------------
|
|
|
|
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class AbstractValue:
  """Base class for abstract values (avals) in the aval lattice.

  Subclasses describe sets of concrete values; `join` combines two avals
  into their least upper bound and must be overridden.
  """
  __slots__: List[str] = []

  def at_least_vspace(self):
    assert False

  def __repr__(self):
    try:
      # Instances of slotted subclasses may lack __dict__, in which case we
      # fall back to just the class name.
      body = ','.join('{}={}'.format(k, v) for k, v in self.__dict__.items())
      return '{}({})'.format(self.__class__.__name__, body)
    except AttributeError:
      return self.__class__.__name__

  def strip_weak_type(self) -> 'AbstractValue':
    return self

  def join(self, other):
    raise NotImplementedError("must override")
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# Bottom element of the aval lattice.
class Bot(AbstractValue): pass

bot = Bot()  # singleton
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
class AbstractUnit(AbstractValue):
  # Abstract value of the unit placeholder; joins only with itself.
  def join(self, other):
    if not skip_checks:
      assert other is abstract_unit, other
    return self
  # equality against a traced unit: any value whose aval is this singleton
  def _eq(self, self_traced, other): return get_aval(other) is self

abstract_unit = AbstractUnit()  # singleton
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
def lattice_join(x: Optional[AbstractValue],
                 y: Optional[AbstractValue]) -> AbstractValue:
  """Least upper bound of two avals; None acts as an identity element.

  Dispatches `join` to whichever argument is the more specific subclass;
  raises TypeError when the types are unrelated.
  """
  if x is None:
    return cast(AbstractValue, y)
  if y is None:
    return cast(AbstractValue, x)
  if isinstance(x, type(y)):
    return y.join(x)
  if isinstance(y, type(x)):
    return x.join(y)
  raise TypeError((x, y))
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`
# (i.e. anything `get_aval` accepts).
Value = Any
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def valid_jaxtype(x):
  """Whether `x` has a registered abstract value (see `concrete_aval`)."""
  try:
    concrete_aval(x)
  except TypeError:
    return False
  return True
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 13:24:40 -07:00
|
|
|
def check_valid_jaxtype(x):
  """Raise TypeError if `x` is not a valid JAX type (see `valid_jaxtype`)."""
  if not valid_jaxtype(x):
    raise TypeError(f"{x} of type {type(x)} is not a valid JAX type")
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def concrete_aval(x):
  """Return the abstract value for concrete `x`.

  Walks the MRO of `type(x)` so subclasses of registered types are handled;
  raises TypeError when no handler is registered.
  """
  for klass in type(x).mro():
    to_aval = pytype_aval_mappings.get(klass)
    if to_aval:
      return to_aval(x)
  raise TypeError(f"{type(x)} is not a valid JAX type")
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def get_aval(x):
  """Abstract value of `x`: a Tracer carries its own aval, otherwise look
  up a concrete aval via the pytype registry."""
  return x.aval if isinstance(x, Tracer) else concrete_aval(x)
|
|
|
|
|
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class Unit:
  """Singleton placeholder value whose aval is `abstract_unit`."""
  def __repr__(self): return '*'
unit = Unit()
literalable_types.add(Unit)
|
|
|
|
|
2020-04-15 18:01:24 -07:00
|
|
|
class UnitVar(Var):
  """Singleton jaxpr variable standing for the unit value."""
  count = -1
  suffix = ''
  def __init__(self): pass
  @property
  def aval(self): return abstract_unit
  def __repr__(self): return '*'
unitvar = UnitVar()

# concrete Unit values map to the abstract_unit singleton
pytype_aval_mappings[Unit] = lambda _: abstract_unit
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Identity primitive: both impl and bind return the argument unchanged.
identity_p = Primitive('id')
identity_p.def_impl(lambda x: x)
identity_p.def_custom_bind(lambda x: x)
|
|
|
|
|
2020-04-22 10:25:06 +03:00
|
|
|
class ConcretizationTypeError(TypeError): pass


def raise_concretization_error(val, context=""):
  """Raise a ConcretizationTypeError explaining that an abstract tracer
  showed up where a concrete value is required, with remediation hints."""
  msg = "\n".join([
      f"Abstract tracer value encountered where concrete value is expected ({context}).",
      "Use transformation parameters such as `static_argnums` for `jit` "
      "to avoid tracing input values.",
      "See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.",
      f"Encountered value: {val}"])
  raise ConcretizationTypeError(msg)
|
|
|
|
|
|
|
|
|
|
|
|
def concretization_function_error(fun, context=""):
  """Build a method that raises a concretization error naming `fun`.

  The returned `error(self, arg)` is installed on abstract array classes in
  place of conversions (bool, int, ...) that require a concrete value.
  """
  fname = getattr(fun, "__name__", fun)
  ctx = f"in `{fname}`"
  if context:
    ctx = f"{ctx} {context}"
  def error(self, arg):
    raise_concretization_error(arg, ctx)
  return error
|
|
|
|
|
|
|
|
|
2020-04-22 10:25:06 +03:00
|
|
|
def concrete_or_error(typ: Type, val: Any, context=""):
  """Like typ(val), but gives the context in the error message.

  Use with typ either `int`, or `bool`.
  """
  if isinstance(val, Tracer):
    if isinstance(val.aval, ConcreteArray):
      # the tracer carries a concrete payload; convert that
      return typ(val.aval.val)
    else:
      raise_concretization_error(val, context)
  else:
    return typ(val)
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
class UnshapedArray(AbstractValue):
  """Abstract array characterized only by dtype (and weak_type), no shape."""
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 2  # least specific of the array aval levels

  def __init__(self, dtype, weak_type=False):
    self.dtype = onp.dtype(dtypes.canonicalize_dtype(dtype))
    self.weak_type = weak_type

  def __eq__(self, other):
    return (type(self) is type(other) and self.dtype == other.dtype and
            self.weak_type == other.weak_type)

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.dtype, self.weak_type))

  def __repr__(self):
    return '{}({}{})'.format(self.__class__.__name__, self.str_short(),
                             ", weak_type=True" if self.weak_type else "")

  # Conversions to concrete Python scalars are errors on abstract arrays.
  _bool = _nonzero = concretization_function_error(bool)
  _float = concretization_function_error(
      float, "Try using `x.astype(float)` instead.")
  _int = concretization_function_error(
      int, "Try using `x.astype(int)` instead.")
  _complex = concretization_function_error(
      complex, "Try using `x.astype(complex)` instead.")
  _hex = concretization_function_error(hex)
  _oct = concretization_function_error(oct)

  def at_least_vspace(self) -> AbstractValue:
    return self

  def join(self, other):
    # join within the same dtype; differing weak_type resolves to strong
    if self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      else:
        return UnshapedArray(self.dtype, weak_type=False)
    else:
      raise TypeError(self, other)

  def str_short(self) -> str:
    return self.dtype.name

  def strip_weak_type(self) -> 'UnshapedArray':
    """Returns a copy of the aval with weak_type=False."""
    return UnshapedArray(self.dtype) if self.weak_type else self

  @property
  def shape(self):
    msg = ("UnshapedArray has no shape. Please open an issue at "
           "https://github.com/google/jax/issues because it's unexpected for "
           "UnshapedArray instances to ever be produced.")
    raise TypeError(msg)
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
class ShapedArray(UnshapedArray):
  """Abstract array with both shape and dtype (but no concrete value)."""
  __slots__ = ['shape']
  array_abstraction_level = 1

  def __init__(self, shape, dtype, weak_type=False):
    super(ShapedArray, self).__init__(dtype, weak_type=weak_type)
    self.shape = canonicalize_shape(shape)

  ndim = property(lambda self: len(self.shape))
  size = property(lambda self: prod(self.shape))

  # Forwarded numpy-like methods, installed later (e.g. by lax_numpy);
  # None until registered.
  broadcast: ClassVar[Optional[aval_method]] = None
  transpose: ClassVar[Optional[aval_method]] = None
  reshape: ClassVar[Optional[aval_method]] = None
  _iter: ClassVar[Optional[staticmethod]] = None

  def __eq__(self, other):
    return (type(self) is type(other)
            and self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `onp.zeros(3).dtype is onp.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.shape, self.dtype, self.weak_type))

  def at_least_vspace(self):
    return self

  def join(self, other):
    # same shape+dtype joins at this level; same dtype only drops the shape
    if self.shape == other.shape and self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      else:
        return ShapedArray(self.shape, self.dtype, weak_type=False)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    else:
      raise TypeError(self, other)

  def str_short(self):
    shapestr = ','.join(map(str, self.shape))
    return '{}[{}]'.format(self.dtype.name, shapestr)

  def __len__(self):
    try:
      return self.shape[0]
    except IndexError:
      raise TypeError("len() of unsized object")  # same as numpy error

  def _len(self, ignored_tracer):
    # forwarded len() for tracers
    return len(self)

  def strip_weak_type(self):
    return ShapedArray(self.shape, self.dtype) if self.weak_type else self
|
|
|
|
|
|
|
|
|
|
|
|
def _forward_to_value(self, fun, ignored_tracer, *args):
|
|
|
|
return fun(self.val, *args)
|
|
|
|
|
|
|
|
class ConcreteArray(ShapedArray):
  """Abstract array that additionally carries a concrete value `val`."""
  __slots__ = ['val']
  array_abstraction_level = 0  # most specific array aval level

  def __init__(self, val, weak_type=False):
    super(ConcreteArray, self).__init__(onp.shape(val), onp.result_type(val),
                                        weak_type=weak_type)
    # Note: canonicalized self.dtype doesn't necessarily match self.val
    self.val = val
    assert self.dtype != onp.dtype('O')

  def __eq__(self, other):
    return (type(self) is type(other) and self.dtype == other.dtype
            and self.shape == other.shape and self.weak_type == other.weak_type
            and onp.all(self.val == other.val))

  def __hash__(self):
    # hash by payload identity; consistent with __eq__ only for identical vals
    return id(self.val)

  def at_least_vspace(self):
    return ShapedArray(self.shape, self.dtype, weak_type=self.weak_type)

  def join(self, other) -> UnshapedArray:
    # equal concrete avals stay concrete; otherwise drop to shaped/unshaped
    if self == other:
      return self
    elif self.shape == other.shape and self.dtype == other.dtype:
      return ShapedArray(self.shape, self.dtype,
                         weak_type=self.weak_type and other.weak_type)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype,
                           weak_type=self.weak_type and other.weak_type)
    else:
      raise TypeError(self, other)

  def str_short(self) -> str:
    return str(self.val)

  def strip_weak_type(self) -> 'ConcreteArray':
    return ConcreteArray(self.val) if self.weak_type else self

  # Conversions succeed here by forwarding to the concrete payload.
  _bool = _nonzero = partialmethod(_forward_to_value, bool)
  _int = partialmethod(_forward_to_value, int)
  _hex = partialmethod(_forward_to_value, hex)
  _oct = partialmethod(_forward_to_value, oct)
|
|
|
|
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
class AbstractToken(AbstractValue):
  # Abstract value for ordering tokens; any two tokens join to the same aval.
  def join(self, other):
    if isinstance(other, AbstractToken):
      return self
    else:
      assert False, f"Cannot join {self} with {other}"

abstract_token = AbstractToken()  # singleton
|
|
|
|
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def raise_to_shaped(aval: AbstractValue, weak_type=False):
  """Raise `aval` to ShapedArray abstraction level (dropping concreteness);
  the unit and token singletons pass through unchanged."""
  if isinstance(aval, ShapedArray):
    return ShapedArray(aval.shape, aval.dtype, weak_type=weak_type)
  elif aval is abstract_unit:
    return abstract_unit
  elif aval is abstract_token:
    return abstract_token
  else:
    raise TypeError(type(aval))
|
|
|
|
|
|
|
|
# Registry for valid dimension types. This is used by masking.Poly.
|
2020-03-18 17:06:05 -04:00
|
|
|
_DIMENSION_TYPES: Set[type] = {int}
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def _canonicalize_dimension(dim):
|
|
|
|
if type(dim) in _DIMENSION_TYPES:
|
|
|
|
return dim
|
|
|
|
else:
|
|
|
|
return operator.index(dim)
|
|
|
|
|
|
|
|
def canonicalize_shape(shape):
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.

  Returns:
    A tuple of integers.
  """
  try:
    return tuple(map(_canonicalize_dimension, shape))
  except TypeError:
    # fall through to build a helpful error message below
    pass
  msg = ("Shapes must be 1D sequences of concrete values of integer type, "
         "got {}.")
  # a non-concrete tracer in the shape usually means a value jit traced away
  if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
         and not isinstance(get_aval(x), ConcreteArray) for x in shape):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  raise TypeError(msg.format(shape))
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-21 18:12:02 -07:00
|
|
|
# ------------------- Call and map -------------------

def apply_todos(todos, outs):
  """Apply environment-trace post-processing callbacks to `outs`.

  Callbacks are applied last-to-first (popped from the end), lowering the
  results after each. NOTE(review): `map` here is presumably the module's
  safe_map alias (strict-length list map) — confirm against file top.
  """
  todos_list = list(todos)
  while todos_list:
    outs = map(full_lower, todos_list.pop()(outs))
  return outs
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-12-06 22:28:41 -08:00
|
|
|
@lu.transformation_with_aux
def process_env_traces(post_processor: str, primitive: Primitive,
                       level: int, params_tuple: tuple, *args):
  """Transformation post-processing outputs that escape above `level`.

  After the wrapped computation runs, repeatedly finds the highest-level
  output tracer above `level`, raises all outputs to that trace, and applies
  the trace's `post_processor` method (e.g. 'post_process_call'). The
  collected per-trace todos are returned as auxiliary output for
  `apply_todos`.
  """
  outs = yield args, {}
  params = dict(params_tuple)
  todo = []
  while True:
    tracers = [x for x in outs if isinstance(x, Tracer) and x._trace.level > level]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    # retarget the escaping trace at the current sublevel
    trace = type(ans._trace)(ans._trace.master, cur_sublevel())
    outs = map(trace.full_raise, outs)
    post_process = getattr(trace, post_processor)
    outs, cur_todo = post_process(primitive, outs, params)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-21 18:12:02 -07:00
|
|
|
def _call_bind(processor: str, post_processor: str, primitive: Primitive,
               f: lu.WrappedFun, *args, **params):
  """Shared bind logic for call-like primitives (call, map).

  With no active trace among `args`, runs the primitive's impl under a fresh
  sublevel; otherwise raises args to the top trace and dispatches to that
  trace's `processor` method (e.g. 'process_call'). Escaping environment
  traces are handled afterwards via `apply_todos`.
  """
  top_trace = find_top_trace(args)
  level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
  params_tuple = tuple(params.items())
  f, env_trace_todo = process_env_traces(f, post_processor, primitive, level, params_tuple)
  if top_trace is None:
    with new_sublevel():
      outs = primitive.impl(f, *args, **params)
  else:
    tracers = map(top_trace.full_raise, args)
    process = getattr(top_trace, processor)
    outs = map(full_lower, process(primitive, f, tracers, params))
  return apply_todos(env_trace_todo(), outs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-21 18:12:02 -07:00
|
|
|
call_bind = partial(_call_bind, 'process_call', 'post_process_call')
|
|
|
|
map_bind = partial(_call_bind, 'process_map', 'post_process_map')
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-09 20:41:01 +01:00
|
|
|
def call_impl(f: lu.WrappedFun, *args, **params):
  """Default impl for call primitives: run the wrapped function directly."""
  del params  # params parameterize the call primitive, not the function
  return f.call_wrapped(*args)

# The `call` primitive: a call-class primitive bound via call_bind.
call_p = Primitive('call')
call_p.multiple_results = True
call_p.call_primitive = True
call = partial(call_bind, call_p)
call_p.def_custom_bind(call)
call_p.def_impl(call_impl)
|
|
|
|
|
|
|
|
|
2020-04-15 11:05:32 -07:00
|
|
|
# ------------------- Jaxpr checking -------------------

def mapped_aval(size: int, aval: AbstractValue) -> AbstractValue:
  """Strip the leading axis (asserted to have length `size`) from `aval`."""
  if aval is abstract_unit:
    return aval
  elif isinstance(aval, ShapedArray):
    # might be raising abstraction level from Concrete here
    assert aval.shape[0] == size
    return ShapedArray(aval.shape[1:], aval.dtype)
  else:
    raise TypeError(f"Mapped operand {aval}")
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def unmapped_aval(size: int, aval: AbstractValue) -> AbstractValue:
  """Prepend an axis of length `size` to `aval` (inverse of mapped_aval)."""
  if aval is abstract_unit:
    return aval
  elif isinstance(aval, ShapedArray):
    return ShapedArray((size,) + aval.shape, aval.dtype)
  else:
    raise TypeError(f"Mapped output {aval}")
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
def typecheck(aval: AbstractValue, x) -> bool:
  """Whether the concrete value `x` conforms to `aval`."""
  return typecompat(aval, get_aval(x))

def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
  """Determine whether `aval` conforms to `aval_ref`"""
  # conformance ignores weak_type and is decided via the lattice join
  aval_ref = raise_to_shaped(aval_ref).strip_weak_type()
  try:
    return aval_ref == lattice_join(aval_ref, aval).strip_weak_type()
  except TypeError:
    return False

def typematch(aval1: UnshapedArray, aval2: UnshapedArray) -> bool:
  """Whether two avals are equal after raising to shaped and ignoring
  weak_type."""
  return (raise_to_shaped(aval1).strip_weak_type() ==
          raise_to_shaped(aval2).strip_weak_type())
|
|
|
|
|
2020-03-21 13:54:30 +01:00
|
|
|
def check_jaxpr(jaxpr: Jaxpr):
  """Checks well-formedness of a jaxpr.

  Specifically, check that:
  - variables that are read are bound beforehand
  - variables are typed equally throughout a jaxpr
  - variable type annotations are compatible with their binding expression

  Raises `TypeError` if `jaxpr` is determined invalid. Returns `None` otherwise.
  """
  try:
    _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
  except Exception as e:
    # re-raise the same exception type with the offending jaxpr appended to
    # the first message argument for context
    exception_type = type(e)
    msg_context = f"while checking jaxpr:\n\n{jaxpr}\n"
    if len(e.args) == 0:
      exception_args = [msg_context]
    else:
      msg = f"{e.args[0]}\n\n" + msg_context
      exception_args = [msg, *e.args[1:]]
    raise exception_type(*exception_args) from e
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
  """Typecheck `jaxpr` against input avals; raises TypeError on violations."""

  def read(v: Atom) -> AbstractValue:
    # literals carry their own value; variables must already be bound
    if isinstance(v, Literal):
      return get_aval(v.val)
    else:
      if v not in env:
        raise TypeError(f"Variable '{v}' not defined")
      return env[v]

  def write(v: Var, a: AbstractValue) -> None:
    # single-assignment: a variable may be bound only once
    if v in env:
      raise TypeError(f"Variable '{v}' already bound")
    # TODO(frostig): we'd rather check equality or just typecompat here, but
    # partial_eval.tracers_to_jaxpr types eqn outvars as abstract_unit if the
    # outvars are unused
    if not typecompat(v.aval, a) and v.aval is not abstract_unit:
      raise TypeError(f"Variable '{v}' inconsistently typed as {a}, "
                      f"bound as {v.aval}")
    env[v] = a

  env : Dict[Var, AbstractValue] = {}

  write(unitvar, abstract_unit)
  map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
  map(write, jaxpr.invars, in_avals)

  for eqn in jaxpr.eqns:
    in_avals = map(read, eqn.invars)
    # call/map primitives carry a sub-jaxpr and get dedicated checks
    if eqn.primitive.call_primitive:
      out_avals = check_call(eqn.primitive, in_avals, eqn.params)
    elif eqn.primitive.map_primitive:
      out_avals = check_map(eqn.primitive, in_avals, eqn.params)
    else:
      out_avals = check_eqn(eqn.primitive, in_avals, eqn.params)
    try:
      map(write, eqn.outvars, out_avals)
    except TypeError as e:
      # annotate binding errors with the equation they occurred in
      msg, = e.args
      raise TypeError(msg + f" in '{eqn}'") from None

  # finally, every jaxpr output must be readable
  map(read, jaxpr.outvars)
|
|
|
|
|
|
|
|
def check_eqn(prim, in_avals, params):
  """Type-check one first-order primitive application.

  Recursively checks any jaxprs embedded in the params, then runs the
  primitive's abstract evaluation rule. Returns the list of output avals.
  """
  for param_jaxpr in jaxprs_in_params(params):
    check_jaxpr(param_jaxpr)

  result = prim.abstract_eval(*in_avals, **params)
  # Single-result primitives return a bare aval; normalize to a list.
  return result if prim.multiple_results else [result]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
def check_call(prim, in_avals, params):
  """Type-check a call primitive (e.g. xla_call) against its sub-jaxpr.

  Verifies the params carry a 'call_jaxpr', that the operand count and avals
  match the sub-jaxpr's binders, then recursively checks the sub-jaxpr.
  Returns the sub-jaxpr's output avals. Raises TypeError on mismatch.
  """
  if "call_jaxpr" not in params:
    raise TypeError(f"Call primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]

  # These checks also happen in recursive call, but give better errors here.
  if len(in_avals) != len(call_jaxpr.invars):
    # Bug fix: the operand count previously printed len(call_jaxpr.invars)
    # for both numbers, making the message self-contradictory.
    raise TypeError(f"Call primitive {prim} with {len(in_avals)} "
                    f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
                    f"inputs")
  binder_avals = [v.aval for v in call_jaxpr.invars]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    if not typecompat(binder_aval, in_aval):
      raise TypeError(f"Call primitive {prim} passes operand {in_aval} "
                      f"to jaxpr expecting {binder_aval}")

  _check_jaxpr(call_jaxpr, in_avals)

  out_avals = [v.aval for v in call_jaxpr.outvars]
  return out_avals
|
2020-05-21 13:11:58 -07:00
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
def check_map(prim, in_avals, params):
  """Type-check a map primitive (e.g. xla_pmap) against its sub-jaxpr.

  Requires 'call_jaxpr', 'axis_size', and 'mapped_invars' params. Operands
  whose invar is mapped must have an extra leading axis of size axis_size;
  the sub-jaxpr is checked on per-element (mapped) avals, and its outputs are
  re-wrapped with the mapped axis. Returns the output avals; raises TypeError
  on any mismatch.
  """
  if "call_jaxpr" not in params:
    raise TypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]
  if "axis_size" not in params:
    raise TypeError(f"Map primitive {prim} missing 'axis_size' parameter")
  axis_size = params["axis_size"]
  if "mapped_invars" not in params:
    raise TypeError(f"Map primitive {prim} missing 'mapped_invars' parameter")
  mapped_invars = params["mapped_invars"]

  # Expected caller-side avals: mapped binders gain a leading axis_size axis.
  binder_avals = [unmapped_aval(axis_size, v.aval) if mapped else v.aval
                  for v, mapped in zip(call_jaxpr.invars, mapped_invars)]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    if not typecompat(binder_aval, in_aval):
      # Bug fix: message previously said "Call primitive" in this map checker.
      raise TypeError(f"Map primitive {prim} passes operand {in_aval} "
                      f"to jaxpr expecting {binder_aval}")

  # Check the sub-jaxpr on per-element avals (mapped axis stripped).
  mapped_avals = [mapped_aval(axis_size, aval) if mapped else aval
                  for aval, mapped in zip(in_avals, mapped_invars)]
  _check_jaxpr(call_jaxpr, mapped_avals)

  mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
  out_avals = [unmapped_aval(axis_size, aval) for aval in mapped_out_avals]
  return out_avals
|
2020-04-15 11:05:32 -07:00
|
|
|
|
|
|
|
|
|
|
|
# ------------------- Jaxpr printed representation -------------------
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-02 10:26:43 -04:00
|
|
|
def pp_vars(vs: Sequence[Any]) -> str:
  """Render a sequence of variables as a single space-separated string."""
  return ' '.join(str(v) for v in vs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-21 13:54:30 +01:00
|
|
|
def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
  """Pretty-print a primitive name with its params, in compact form.

  Omits jaxpr-valued params (and 'branches') to keep the output short.
  """
  kept = {}
  for key, val in params.items():
    if key == 'branches' or isinstance(val, (Jaxpr, TypedJaxpr)):
      continue  # sub-jaxprs are too verbose for the compact form
    kept[key] = val
  return pp(primitive_name) >> pp_kv_pairs(sorted(kept.items()))
|
|
|
|
|
2020-03-21 13:54:30 +01:00
|
|
|
def pp_eqn(eqn: JaxprEqn) -> PrettyPrint:
  """Pretty-print one equation as 'outvars = primitive[params] invars'."""
  lhs_str = pp_vars(eqn.outvars)
  rhs = (pp(eqn.primitive.name)
         >> pp_kv_pairs(sorted(eqn.params.items()))
         >> pp(' ')
         >> pp(pp_vars(eqn.invars)))
  # The trailing empty block mirrors the historical sub-expression slot.
  return (pp('{} = '.format(lhs_str)) >> rhs) + pp('')
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def pp_jaxpr(jaxpr: Jaxpr) -> PrettyPrint:
  """Pretty-print a whole jaxpr as '{ lambda consts ; invars. let ... in outs }'."""
  outvars_str = str(tuple(jaxpr.outvars))
  header = pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                          pp_vars(jaxpr.invars)))
  eqns_block = pp('let ') >> vcat(map(pp_eqn, jaxpr.eqns))
  footer = pp('in {} }}'.format(outvars_str))
  # Body (equations + footer) is indented two spaces under the header.
  return header + (eqns_block + footer).indent(2)
|
2020-05-26 19:32:29 -07:00
|
|
|
|
|
|
|
def pp_jaxprs(jaxprs) -> PrettyPrint:
  """Pretty-print several jaxprs (unwrapping TypedJaxprs), parenthesized."""
  rendered = [pp_jaxpr(j.jaxpr if isinstance(j, TypedJaxpr) else j)
              for j in jaxprs]
  return pp('( ') >> vcat(rendered) >> pp(' )')
|
|
|
|
|
|
|
|
def pp_kv_pair(k, v):
  """Pretty-print one 'key=value' param; 'branches' values are jaxprs."""
  rendered = pp_jaxprs(v) if k == 'branches' else pp(v)
  return pp(f'{k}=') >> rendered
|
|
|
|
|
|
|
|
def pp_kv_pairs(kv_pairs):
  """Pretty-print a param list as '[ k=v ... ]', or nothing when empty."""
  if not kv_pairs:
    return pp('')
  entries = [pp_kv_pair(k, v) for k, v in kv_pairs]
  return pp('[ ') >> vcat(entries) >> pp(' ]')
|