# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import namedtuple
from contextlib import contextmanager
from functools import partial, total_ordering
import gc
import itertools as it
import operator
from operator import attrgetter
import threading
import types
from typing import (Any, Callable, ClassVar, Dict, Generator,
                    Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
                    Type, Union, cast, Iterable, Hashable)
from weakref import ref

import numpy as np

from ._src import dtypes
from ._src import config as jax_config
from ._src.config import FLAGS, config
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
                     TracerIntegerConversionError, UnexpectedTracerError)
from . import linear_util as lu

from jax._src import source_info_util
from ._src.util import (safe_zip, safe_map, curry, prod, partialmethod,
                        tuple_insert, tuple_delete, cache, as_hashable_function,
                        HashableFunction)
import jax._src.pretty_printer as pp

from ._src import traceback_util
traceback_util.register_exclusion(__file__)

zip = safe_zip
map = safe_map

# -------------------- jaxprs --------------------

class Jaxpr:
  constvars: List['Var']
  invars: List['Var']
  outvars: List['Atom']
  eqns: List['JaxprEqn']

  def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
               outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
    """
    Args:
      constvars: list of variables introduced for constants. Array constants
        are replaced with such variables, while scalar constants are kept
        inline.
      invars: list of input variables. Together, `constvars` and `invars` are
        the inputs to the Jaxpr.
      outvars: list of output variables.
      eqns: list of equations.
    """
    self.constvars = list(constvars)
    self.invars = list(invars)
    self.outvars = list(outvars)
    self.eqns = list(eqns)

  def __str__(self):
    return str(pp_jaxpr(self))
  __repr__ = __str__

  def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
    doc = pp_jaxpr(self, source_info=source_info, print_shapes=print_shapes)
    return doc.format(**kw)
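
# For illustration (a sketch; the exact printed form varies across JAX
# versions), `jax.make_jaxpr` builds a ClosedJaxpr whose str() looks like:
#
#   >>> import jax
#   >>> def f(x, y):
#   ...   return jax.numpy.cos(x) * jax.numpy.tanh(y)
#   >>> print(jax.make_jaxpr(f)(7., 9.))
#   { lambda ; a b.
#     let c = cos a
#         d = tanh b
#         e = mul c d
#     in (e,) }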


def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  for val in params.values():
    vals = val if isinstance(val, tuple) else (val,)
    for v in vals:
      if isinstance(v, Jaxpr):
        yield v
      elif isinstance(v, ClosedJaxpr):
        yield v.jaxpr


def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Generator for all subjaxprs found in the params of jaxpr.eqns.

  Does not descend recursively into the found subjaxprs.
  """
  for eqn in jaxpr.eqns:
    yield from jaxprs_in_params(eqn.params)


class ClosedJaxpr:
  jaxpr: Jaxpr
  consts: List[Any]

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    assert len(consts) == len(jaxpr.constvars)
    self.jaxpr = jaxpr
    self.consts = list(consts)

  @property
  def in_avals(self):
    return [v.aval for v in self.jaxpr.invars]

  @property
  def out_avals(self):
    return [v.aval for v in self.jaxpr.outvars]

  @property
  def literals(self):
    return self.consts  # backwards-compatible alias

  @property
  def eqns(self):
    return self.jaxpr.eqns

  def map_jaxpr(self, f):
    return ClosedJaxpr(f(self.jaxpr), self.consts)

  def __str__(self): return str(self.jaxpr)
  def __repr__(self): return repr(self.jaxpr)

  def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
    return pp_jaxpr(self.jaxpr, source_info=source_info,
                    print_shapes=print_shapes).format(**kw)


@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
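
# Usage sketch (hypothetical values; `jax.make_jaxpr` shown for context):
#
#   closed = jax.make_jaxpr(lambda x: x * 2.)(3.)
#   f = jaxpr_as_fun(closed)
#   f(3.)   # evaluates the jaxpr, returning a list of outputs: [6.]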


class JaxprEqn(NamedTuple):
  invars: List['Atom']
  outvars: List['Var']
  primitive: 'Primitive'
  params: Dict[str, Any]
  source_info: Optional[source_info_util.Traceback]

  def __repr__(self): return str(pp_eqn(self)).rstrip()


def new_jaxpr_eqn(invars, outvars, primitive, params, source_info=None):
  if primitive.call_primitive:
    assert len(outvars) == len(params["call_jaxpr"].outvars)
  return JaxprEqn(invars, outvars, primitive, params, source_info)
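
# When printed with source info (e.g. `jaxpr.pretty_print(source_info=True)`),
# each equation is annotated with where it was traced from, roughly:
#
#   c = cos a   [<ipython-input-...>:2 (f)]
#
# and the same locations flow into the compiled HLO as op metadata.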


@total_ordering
class Var:
  # TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison
  # is by object id, but pretty printing might collide.
  count: int
  suffix: str
  aval: 'AbstractValue'

  def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
    self.count = count
    self.suffix = suffix
    self.aval = raise_to_shaped(aval)

  def __lt__(self, other):
    if not isinstance(other, Var):
      return NotImplemented
    else:
      return (self.count, self.suffix) < (other.count, other.suffix)

  def __repr__(self):
    rem = self.count
    s = ''
    while True:
      rem, i = rem // 26, rem % 26
      s = chr(97 + i % 26) + s
      if not rem:
        break
    return s + self.suffix
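
# Printed variable names cycle through the alphabet by `count`: 0..25 print as
# 'a'..'z', and larger counts pick up more letters (count 26 prints as 'ba'
# under the scheme above), followed by the suffix. A sketch:
#
#   >>> [str(Var(i, '', abstract_unit)) for i in (0, 1, 25, 26)]
#   ['a', 'b', 'z', 'ba']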

def _jaxpr_vars(jaxpr):
  return it.chain(
      jaxpr.invars, jaxpr.constvars,
      (v for eqn in jaxpr.eqns for v in eqn.outvars))

def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
           suffix: str = '') -> Callable[['AbstractValue'], Var]:
  """Produce distinct variables, printed with the optional suffix.

  If `jaxprs` is provided, the variables produced will be distinct from those
  in any of the given jaxprs.
  """
  if jaxprs is None:
    start = 0
  else:
    all_vars = it.chain.from_iterable(_jaxpr_vars(j) for j in jaxprs)
    start = 1 + max((v.count for v in all_vars), default=-1)
  counter = it.count(start=start)
  return lambda aval: Var(next(counter), suffix, aval)
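
# Usage sketch (hypothetical avals): each call to the returned function mints
# a fresh, distinct variable.
#
#   newvar = gensym(suffix="'")
#   v0 = newvar(abstract_unit)   # prints as a'
#   v1 = newvar(abstract_unit)   # prints as b'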

# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate
# that the assignment is dropped, i.e. that an expression's output value will
# never be read. In that sense, `dropvar` is not a variable, but it is
# convenient to treat it as a special case of one. Its `aval` is similarly
# inexact.
class DropVar(Var):
  count = -1
  suffix = ''
  def __init__(self): pass
  @property
  def aval(self): return abstract_unit
  def __repr__(self): return '_'
dropvar = DropVar()

class Literal:
  __slots__ = ["val", "hash"]

  val: Any
  hash: Optional[int]

  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash(val)
    except TypeError:
      if type(val) in literalable_types:
        try:
          self.hash = hash((val.item(), val.dtype))
        except (TypeError, AttributeError, ValueError):
          self.hash = None

  @property
  def aval(self):
    return raise_to_shaped(get_aval(self.val))

  def __hash__(self):
    assert False

  def __repr__(self):
    if hasattr(self, 'hash'):
      return '{}'.format(self.val)
    else:
      return 'Literal(val={})'.format(self.val)

literalable_types: Set[type] = set()

Atom = Union[Var, Literal]

class Primitive:
  name: str
  multiple_results = False   # set for multi-output primitives
  call_primitive = False     # set for call primitives processed in final style
  map_primitive = False      # set for map primitives processed in final style
  _dispatch_on_params = False  # whether to include axis names from params in dispatch

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return '{}'.format(self.name)

  def bind(self, *args, **params):
    assert (not config.jax_enable_checks or
            all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    top_trace = find_top_trace(
        args, used_axis_names(self, params) if self._dispatch_on_params else None)
    tracers = map(top_trace.full_raise, args)
    out = top_trace.process_primitive(self, tracers, params)
    return map(full_lower, out) if self.multiple_results else full_lower(out)

  def def_impl(self, impl):
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    self.abstract_eval = abstract_eval
    return abstract_eval

  def def_custom_bind(self, bind):
    self.bind = bind
    return bind

  def impl(self, *args, **params):
    raise NotImplementedError("Evaluation rule for '{}' not implemented"
                              .format(self.name))

  def abstract_eval(self, *args, **params):
    raise NotImplementedError("Abstract evaluation for '{}' not implemented"
                              .format(self.name))

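# Minimal sketch of defining a new primitive (the names here are hypothetical,
# not part of the API): attach an evaluation rule with `def_impl`, then `bind`
# routes applications through the active trace.
#
#   my_sin_p = Primitive('my_sin')
#   my_sin_p.def_impl(np.sin)
#   my_sin_p.def_abstract_eval(lambda aval: aval)  # shape/dtype preserved
#   my_sin_p.bind(0.5)   # under the trivial EvalTrace, calls np.sin(0.5)
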
# -------------------- lifting --------------------

# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
# core.py. Plan to move all these utilities to jaxpr.py.
def extract_call_jaxpr(
    primitive: Primitive,
    params: Dict[str, Any]) -> Tuple[Optional[Jaxpr], Dict[str, Any]]:
  """Extract the call primitive subjaxpr from the params.

  Returns the subjaxpr and the params without the "call_jaxpr" value. If this
  is not a call primitive then returns (None, params).
  """
  if not (primitive.call_primitive or primitive.map_primitive):
    return (None, params)
  else:
    assert "call_jaxpr" in params
    new_params = dict(params)
    del new_params["call_jaxpr"]
    return (params["call_jaxpr"], new_params)


# TODO(mattjj): replace this approach with a primitive-keyed table of rules
def traverse_jaxpr_params(f, params):
  """Applies f to each jaxpr parameter and returns a dict of the results."""
  return {name: f(p)
          for name, param in params.items()
          for p in (param if isinstance(param, (tuple, list)) else [param])
          if type(p) in (Jaxpr, ClosedJaxpr)}


def eval_jaxpr_eqn(eqn, in_vals):
  """Evaluates the jaxpr equation with the provided input values."""
  call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
  if call_jaxpr:
    subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
  else:
    subfuns = []
  if eqn.primitive in initial_to_final_param_rules:
    bind_params = initial_to_final_param_rules[eqn.primitive](params)
  elif eqn.primitive.map_primitive:
    out_axes_thunk = HashableFunction(lambda: params['out_axes'],
                                      closure=params['out_axes'])
    bind_params = dict(params, out_axes_thunk=out_axes_thunk)
    del bind_params['out_axes']
  else:
    bind_params = params
  with source_info_util.user_context(eqn.source_info):
    return eqn.primitive.bind(*(subfuns + in_vals), **bind_params)


def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
  def read(v):
    if type(v) is Literal:
      return v.val
    else:
      return env[v]

  def write(v, val):
    env[v] = val

  env: Dict[Var, Any] = {}
  write(unitvar, unit)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    ans = eval_jaxpr_eqn(eqn, map(read, eqn.invars))
    if eqn.primitive.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
  return map(read, jaxpr.outvars)

initial_to_final_param_rules: Dict[Primitive, Callable] = {}
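
# Usage sketch (hypothetical inputs): `eval_jaxpr` is the reference
# interpreter, so evaluating a traced jaxpr reproduces the original function.
#
#   closed = jax.make_jaxpr(lambda x: x + 1.)(0.)
#   eval_jaxpr(closed.jaxpr, closed.consts, 41.)   # -> [42.0]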


# -------------------- tracing --------------------


class Trace:
  __slots__ = ['main', 'level', 'sublevel']

  main: 'MainTrace'
  level: int
  sublevel: 'Sublevel'

  def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel

  def full_raise(self, val) -> 'Tracer':
    if not isinstance(val, Tracer):
      return self.pure(val)
    val._assert_live()
    level = self.level
    sublevel = self.sublevel
    if val._trace.main is self.main:
      if val._trace.sublevel == sublevel:
        return val
      elif val._trace.sublevel < sublevel:
        return self.sublift(val)
      else:
        raise escaped_tracer_error(
            val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
    elif val._trace.level < level:
      if val._trace.sublevel > sublevel:
        raise escaped_tracer_error(
            val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
      return self.lift(val)
    elif val._trace.level > level:
      raise escaped_tracer_error(
          val, f"Can't lift level {val} to {self}")
    else:  # val._trace.level == self.level
      raise escaped_tracer_error(
          val, f"Different traces at same level: {val}, {self}")

  def pure(self, val):
    raise NotImplementedError("must override")

  def lift(self, tracer):
    raise NotImplementedError("must override")

  def sublift(self, tracer):
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params):
    raise NotImplementedError("must override")

  def __repr__(self):
    return '{}(level={}/{})'.format(
        self.__class__.__name__, self.level, self.sublevel)

  def process_call(self, call_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_call to handle call-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_map(self, map_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_map to handle map-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    msg = (f"{type(self)} must override process_custom_jvp_call "
           "to handle custom_jvp primitives")
    raise NotImplementedError(msg)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    msg = (f"{type(self)} must override process_custom_vjp_call "
           "to handle custom_vjp primitives")
    raise NotImplementedError(msg)

def escaped_tracer_error(tracer, detail=None):
  num_frames = FLAGS.jax_tracer_error_num_traceback_frames
  msg = ('Encountered an unexpected tracer. A function transformed by JAX '
         'had a side effect, allowing for a reference to an intermediate value '
         f'with shape {tracer.shape} and dtype {tracer.dtype} to escape.\n'
         'JAX transformations require that functions explicitly return their '
         'outputs, and disallow saving intermediate values to global state.')
  dbg = getattr(tracer._trace.main, 'debug_info', None)
  if dbg is not None:
    msg += ('\nThe function being traced when the value leaked was '
            f'{dbg.func_src_info} traced for {dbg.traced_for}.')
  line_info = getattr(tracer, '_line_info', None)
  if line_info is not None:
    divider = '\n' + '-'*30 + '\n'
    msg += divider
    msg += ('The leaked intermediate value was created on line '
            f'{source_info_util.summarize(line_info)}. ')
    msg += divider
    if num_frames > 0:
      msg += (f'When the value was created, the final {num_frames} stack '
              'frames (most recent last) excluding JAX-internal frames were:')
      msg += divider + source_info_util.summarize(
          line_info, num_frames=num_frames) + divider
  msg += ('\nTo catch the leak earlier, try setting the environment variable '
          'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
          'manager.')
  if detail:
    msg += f'\nDetail: {detail}'
  return UnexpectedTracerError(msg)
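
# Sketch of the kind of user code this error message describes (hypothetical):
#
#   leaked = []
#   def f(x):
#     leaked.append(x)    # side effect: stores the tracer in global state
#     return x
#   jax.jit(f)(1.)
#   leaked[0] + 1         # using the escaped tracer raises UnexpectedTracerError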

class Tracer:
  __array_priority__ = 1000
  __slots__ = ['_trace', '__weakref__', '_line_info']

  def __array__(self, *args, **kw):
    raise TracerArrayConversionError(self)

  def __index__(self):
    raise TracerIntegerConversionError(self)

  def __init__(self, trace: Trace):
    self._trace = trace

  def __iter__(self):
    return iter(self.aval._iter(self))

  def __len__(self):
    return self.aval._len(self)

  @property
  def aval(self):
    raise NotImplementedError("must override")

  def _assert_live(self) -> None:
    pass  # Override for liveness checking

  # Python looks up special methods only on classes, not instances. This means
  # these methods need to be defined explicitly rather than relying on
  # __getattr__.
  def __neg__(self): return self.aval._neg(self)
  def __pos__(self): return self.aval._pos(self)
  def __eq__(self, other): return self.aval._eq(self, other)
  def __ne__(self, other): return self.aval._ne(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __le__(self, other): return self.aval._le(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __ge__(self, other): return self.aval._ge(self, other)
  def __abs__(self): return self.aval._abs(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __sub__(self, other): return self.aval._sub(self, other)
  def __rsub__(self, other): return self.aval._rsub(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __div__(self, other): return self.aval._div(self, other)
  def __rdiv__(self, other): return self.aval._rdiv(self, other)
  def __truediv__(self, other): return self.aval._truediv(self, other)
  def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
  def __floordiv__(self, other): return self.aval._floordiv(self, other)
  def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
  def __divmod__(self, other): return self.aval._divmod(self, other)
  def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
  def __mod__(self, other): return self.aval._mod(self, other)
  def __rmod__(self, other): return self.aval._rmod(self, other)
  def __pow__(self, other): return self.aval._pow(self, other)
  def __rpow__(self, other): return self.aval._rpow(self, other)
  def __matmul__(self, other): return self.aval._matmul(self, other)
  def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
  def __and__(self, other): return self.aval._and(self, other)
  def __rand__(self, other): return self.aval._rand(self, other)
  def __or__(self, other): return self.aval._or(self, other)
  def __ror__(self, other): return self.aval._ror(self, other)
  def __xor__(self, other): return self.aval._xor(self, other)
  def __rxor__(self, other): return self.aval._rxor(self, other)
  def __invert__(self): return self.aval._invert(self)
  def __lshift__(self, other): return self.aval._lshift(self, other)
  def __rlshift__(self, other): return self.aval._rlshift(self, other)
  def __rshift__(self, other): return self.aval._rshift(self, other)
  def __rrshift__(self, other): return self.aval._rrshift(self, other)
  def __getitem__(self, idx): return self.aval._getitem(self, idx)
  def __nonzero__(self): return self.aval._nonzero(self)
  def __bool__(self): return self.aval._bool(self)
  def __int__(self): return self.aval._int(self)
  def __long__(self): return self.aval._long(self)
  def __hex__(self): return self.aval._hex(self)
  def __oct__(self): return self.aval._oct(self)
  def __float__(self): return self.aval._float(self)
  def __complex__(self): return self.aval._complex(self)

  # raises the better error message from ShapedArray
  def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)

  # NumPy also only looks up special methods on classes.
  def __array_module__(self, types): return self.aval._array_module(self, types)

  def __getattr__(self, name):
    # if the aval property raises an AttributeError, gets caught here
    assert not config.jax_enable_checks or name != "aval"

    try:
      attr = getattr(self.aval, name)
    except KeyError as err:
      raise AttributeError(
          "{} has no attribute {}".format(self.__class__.__name__, name)
      ) from err
    else:
      t = type(attr)
      if t is aval_property:
        return attr.fget(self)
      elif t is aval_method:
        return types.MethodType(attr.fun, self)
      else:
        return attr

  def _pretty_print(self):
    base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
    contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
                 else pp.text(repr(attr))) for name, attr in self._contents()]
    if contents:
      base = pp.group(pp.nest(2, pp.concat([
          base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
              pp.text('{} = '.format(name)) + pp_payload
              for name, pp_payload in contents])
      ])))
    return base

  def __repr__(self):
    return self._pretty_print().format()

  def _contents(self):
    try:
      return [(name, getattr(self, name)) for name in self.__slots__]
    except AttributeError:
      return ()

  def __copy__(self):
    return self

  def __deepcopy__(self, unused_memo):
    return self

  def _origin_msg(self) -> str:
    return ""

# these can be used to set up forwarding of properties and instance methods
# from Tracer instances to the underlying avals
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])


class EvalTrace(Trace):
  # See comments in https://github.com/google/jax/pull/3370
  def pure(self, x): return x
  lift = sublift = pure

  def process_primitive(self, primitive, tracers, params):
    return primitive.impl(*tracers, **params)

  def process_call(self, primitive, f, tracers, params):
    return primitive.impl(f, *tracers, **params)
  process_map = process_call

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    del primitive, jvp  # Unused.
    with new_sublevel():
      return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    del primitive, fwd, bwd, out_trees  # Unused.
    with new_sublevel():
      return fun.call_wrapped(*tracers)

class MainTrace:
  level: int
  trace_type: Type[Trace]
  payload: Dict[str, Any]

  def __init__(self, level, trace_type, **payload) -> None:
    self.level = level
    self.trace_type = trace_type
    self.payload = payload

  def __repr__(self) -> str:
    return "MainTrace({},{})".format(self.level, self.trace_type.__name__)

  def __hash__(self) -> int:
    return hash((self.level, self.trace_type))

  def __eq__(self, other: object) -> bool:
    return (isinstance(other, MainTrace) and
            self.level == other.level and
            self.trace_type == other.trace_type and
            self.payload == other.payload)

  def with_cur_sublevel(self):
    return self.trace_type(self, cur_sublevel(), **self.payload)

class TraceStack:
  # See comments in https://github.com/google/jax/pull/3370
  stack: List[MainTrace]
  dynamic: MainTrace

  def __init__(self):
    eval_trace = MainTrace(0, EvalTrace)
    self.stack = [eval_trace]
    self.dynamic = eval_trace

  def next_level(self) -> int:
    return len(self.stack)

  def push(self, main_trace: MainTrace) -> None:
    self.stack.append(main_trace)

  def pop(self) -> None:
    self.stack.pop()

  def __repr__(self) -> str:
    stack_str = ''.join(map('  {}\n'.format, self.stack[::-1]))
    return f'Trace stack\n{stack_str}\n{self.dynamic}'

  def copy(self):
    new = self.__new__(TraceStack)
    new.stack = self.stack[:]
    new.dynamic = self.dynamic
    return new


@total_ordering
class Sublevel:

  def __init__(self, level: int):
    self.level = level

  def __repr__(self):
    return str(self.level)

  def __eq__(self, other):
    return type(other) is Sublevel and self.level == other.level

  def __lt__(self, other):
    return type(other) is Sublevel and self.level < other.level


AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
AxisName = Hashable

no_axis_name = object()

class TraceState:
  trace_stack: TraceStack
  substack: List[Sublevel]
  axis_env: List[AxisEnvFrame]

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.axis_env = []

  def copy(self):
    new = self.__new__(TraceState)
    new.trace_stack = self.trace_stack.copy()
    new.substack = self.substack[:]
    new.axis_env = self.axis_env[:]
    return new


def _update_thread_local_jit_state(dynamic):
  # Copies the MainTrace instance, removing any .debug_info or .jaxpr_stack
  # fields that should not be kept alive as part of a cache key.
  # TODO(mattjj): split debug_info and jaxpr_stack out of MainTrace.
  # TODO(mattjj): add a test that verifies that JIT-ted functions are not kept
  # alive by the JIT cache, particularly for nested JIT-ted functions.
  copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
  jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)


# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class ThreadLocalState(threading.local):
  def __init__(self):
    self.trace_state = TraceState()
    _update_thread_local_jit_state(self.trace_state.trace_stack.dynamic)
thread_local_state = ThreadLocalState()

def trace_state_clean() -> bool:
  trace_state = thread_local_state.trace_state
  return (trace_state.substack == [Sublevel(0)] and
          trace_state.axis_env == [] and
          trace_state.trace_stack.stack == [MainTrace(0, EvalTrace)] and
          trace_state.trace_stack.dynamic == MainTrace(0, EvalTrace))

def reset_trace_state() -> bool:
  "Reset the global trace state and return True if it was already clean."
  if not trace_state_clean():
    thread_local_state.trace_state.__init__()  # type: ignore
    return False
  else:
    return True

def cur_sublevel() -> Sublevel:
  return thread_local_state.trace_state.substack[-1]

def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
  """Find the leaked tracers holding a reference to the MainTrace or Sublevel.

  It's possible there are none! For example, in some cases JAX itself holds a
  reference to `x` inside of a lambda closure, and no tracers were leaked by
  the user. In this case an empty list is returned.
  """
  traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
  tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
  return tracers

@contextmanager
def new_main(trace_type: Type[Trace],
             dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  level = stack.next_level()
  main = MainTrace(level, trace_type, **payload)
  stack.push(main)
  if dynamic:
    prev_dynamic, stack.dynamic = stack.dynamic, main
    _update_thread_local_jit_state(stack.dynamic)

  try:
    yield main
  finally:
    stack.pop()
    if dynamic:
      stack.dynamic = prev_dynamic
      _update_thread_local_jit_state(stack.dynamic)

  if config.jax_check_tracer_leaks:
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers:
        raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
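
# Typical usage pattern inside a transformation (a sketch; the trace classes
# other than EvalTrace live elsewhere in JAX):
#
#   with new_main(EvalTrace) as main:
#     trace = main.with_cur_sublevel()
#     out = trace.full_raise(x)   # lift x into this trace, then process it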

@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  main = MainTrace(0, trace_type)
  prev_dynamic, stack.dynamic = stack.dynamic, main
  prev_base, stack.stack[0] = stack.stack[0], main
  _update_thread_local_jit_state(stack.dynamic)
  try:
    yield main
  finally:
    stack.dynamic = prev_dynamic
    stack.stack[0] = prev_base
    _update_thread_local_jit_state(stack.dynamic)

  if config.jax_check_tracer_leaks:
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers:
        raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')

@contextmanager
def eval_context():
  with new_base_main(EvalTrace):
    yield
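
# Usage sketch: inside `eval_context()` the base and dynamic traces are the
# trivial EvalTrace, so primitive binds on concrete (non-Tracer) values
# evaluate eagerly even while an outer transformation is in progress.
#
#   with eval_context():
#     y = my_sin_p.bind(0.5)   # hypothetical primitive from the sketch above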

@contextmanager
def new_sublevel() -> Generator[None, None, None]:
  sublevel = Sublevel(len(thread_local_state.trace_state.substack))
  thread_local_state.trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    thread_local_state.trace_state.substack.pop()

  if config.jax_check_tracer_leaks:
    t = ref(sublevel)
    del sublevel
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers:
        raise Exception(f'Leaked sublevel {t()}. Leaked tracer(s): {leaked_tracers}.')

def full_lower(val):
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val

def find_top_trace(xs, axis_names=None) -> Trace:
  top_main: Optional[MainTrace] = None
  if axis_names:
    top_main = max((axis_frame(a).main_trace for a in axis_names),
                   default=None, key=lambda t: getattr(t, 'level', -1))
  top_tracer = max((x for x in xs if isinstance(x, Tracer)),
                   default=None, key=attrgetter('_trace.level'))
  if top_tracer is not None:
    top_tracer._assert_live()
    if top_tracer._trace.main.level > getattr(top_main, 'level', -1):
      top_main = top_tracer._trace.main
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  top_main = (dynamic if top_main is None or dynamic.level > top_main.level
              else top_main)
  return top_main and top_main.with_cur_sublevel()  # type: ignore


# -------------------- abstract values --------------------


class AbstractValue:
  __slots__: List[str] = []
  _num_buffers: int = 1  # number of buffers used to represent the value.

  def at_least_vspace(self):
    raise NotImplementedError("must override")

  def __repr__(self):
    try:
      kv_pairs = ('{}={}'.format(k, v) for k, v in self.__dict__.items())
      return '{}({})'.format(self.__class__.__name__, ','.join(kv_pairs))
    except AttributeError:
      return self.__class__.__name__

  def strip_weak_type(self) -> 'AbstractValue':
|
|
|
return self
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-02-05 10:39:22 -08:00
|
|
|
def strip_named_shape(self) -> 'AbstractValue':
|
|
|
|
return self
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
def join(self, other):
|
|
|
|
raise NotImplementedError("must override")
|
|
|
|
|
2021-01-27 15:13:30 -08:00
|
|
|
def update(self, **kwargs):
|
|
|
|
raise NotImplementedError("must override")
|
|
|
|
|
2021-09-24 22:08:42 -04:00
|
|
|
def str_short(self, short_dtypes=False):
|
2021-02-11 13:23:38 -08:00
|
|
|
raise NotImplementedError("must override")
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class Bot(AbstractValue): pass
|
|
|
|
|
|
|
|
bot = Bot()
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
class AbstractUnit(AbstractValue):
|
2020-09-29 11:53:17 -07:00
|
|
|
# TODO(jakevdp): make it possible to set zero buffers
|
|
|
|
# _num_buffers = 0
|
2020-12-29 11:43:44 -08:00
|
|
|
def at_least_vspace(self): return self
|
2020-05-01 09:16:31 +03:00
|
|
|
def join(self, other):
|
2021-03-19 13:49:38 -07:00
|
|
|
if config.jax_enable_checks:
|
2020-05-01 09:16:31 +03:00
|
|
|
assert other is abstract_unit, other
|
|
|
|
return self
|
2019-08-23 08:17:41 -07:00
|
|
|
def _eq(self, self_traced, other): return get_aval(other) is self
|
2021-09-24 22:08:42 -04:00
|
|
|
def str_short(self, short_dtypes=False): return '*'
|
2019-07-26 16:48:17 -04:00
|
|
|
|
|
|
|
abstract_unit = AbstractUnit()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
def lattice_join(x: Optional[AbstractValue],
|
|
|
|
y: Optional[AbstractValue]) -> AbstractValue:
|
2018-11-17 18:03:33 -08:00
|
|
|
if x is None:
|
2020-06-02 19:10:55 -07:00
|
|
|
return cast(AbstractValue, y)
|
2018-11-17 18:03:33 -08:00
|
|
|
elif y is None:
|
2020-06-02 19:10:55 -07:00
|
|
|
return cast(AbstractValue, x)
|
2018-11-17 18:03:33 -08:00
|
|
|
elif isinstance(x, type(y)):
|
|
|
|
return y.join(x)
|
|
|
|
elif isinstance(y, type(x)):
|
|
|
|
return x.join(y)
|
|
|
|
else:
|
2021-02-05 10:39:22 -08:00
|
|
|
raise TypeError(x, y)
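# Example (illustrative): joining moves up the abstraction lattice. A
# ShapedArray joined with an UnshapedArray of the same dtype yields the less
# specific UnshapedArray, and None acts as the identity:
#
#   lattice_join(ShapedArray((3,), np.float32), UnshapedArray(np.float32))
#   # -> UnshapedArray(float32)
#   lattice_join(None, abstract_unit)  # -> abstract_unit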
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
|
|
|
|
Value = Any
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def valid_jaxtype(x):
|
|
|
|
try:
|
|
|
|
concrete_aval(x)
|
|
|
|
except TypeError:
|
|
|
|
return False
|
2019-05-06 22:43:31 -07:00
|
|
|
else:
|
|
|
|
return True
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 13:24:40 -07:00
|
|
|
def check_valid_jaxtype(x):
|
|
|
|
if not valid_jaxtype(x):
|
2021-02-16 16:46:19 -05:00
|
|
|
raise TypeError(
|
|
|
|
f"Value {repr(x)} of type {type(x)} is not a valid JAX type")
|
2020-06-01 13:24:40 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def concrete_aval(x):
|
2020-05-07 01:46:13 -04:00
|
|
|
for typ in type(x).mro():
|
|
|
|
handler = pytype_aval_mappings.get(typ)
|
|
|
|
if handler: return handler(x)
|
2021-02-05 20:30:14 -08:00
|
|
|
if hasattr(x, '__jax_array__'):
|
|
|
|
return concrete_aval(x.__jax_array__())
|
2021-02-16 16:46:19 -05:00
|
|
|
raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
|
|
|
|
"type")
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def get_aval(x):
|
|
|
|
if isinstance(x, Tracer):
|
|
|
|
return x.aval
|
|
|
|
else:
|
|
|
|
return concrete_aval(x)
|
|
|
|
|
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
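# Example (hypothetical sketch): a module can make its own container type a
# valid jaxtype by registering an aval handler. `Box` here is an assumed
# illustration, not a real JAX type:
#
#   class Box:
#     def __init__(self, value): self.value = value
#   pytype_aval_mappings[Box] = lambda box: concrete_aval(box.value)
#
# after which get_aval(Box(np.float32(1.))) returns a ConcreteArray.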
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-06-02 19:10:55 -07:00
|
|
|
class Unit:
|
2019-07-27 15:46:14 -07:00
|
|
|
def __repr__(self): return '*'
|
2019-07-26 16:48:17 -04:00
|
|
|
unit = Unit()
|
2019-07-27 15:46:14 -07:00
|
|
|
literalable_types.add(Unit)
|
|
|
|
|
2020-04-15 18:01:24 -07:00
|
|
|
class UnitVar(Var):
|
2020-05-21 18:28:09 -07:00
|
|
|
count = -1
|
2020-06-01 21:45:36 -04:00
|
|
|
suffix = ''
|
2020-05-21 18:28:09 -07:00
|
|
|
def __init__(self): pass
|
2020-03-09 09:14:23 +00:00
|
|
|
@property
|
|
|
|
def aval(self): return abstract_unit
|
2019-07-27 15:46:14 -07:00
|
|
|
def __repr__(self): return '*'
|
|
|
|
unitvar = UnitVar()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-26 16:48:17 -04:00
|
|
|
pytype_aval_mappings[Unit] = lambda _: abstract_unit
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
def concretization_function_error(fun, suggest_astype=False):
|
2020-03-09 09:14:23 +00:00
|
|
|
fname = getattr(fun, "__name__", fun)
|
2020-09-15 08:06:46 -07:00
|
|
|
fname_context = f"The problem arose with the `{fname}` function. "
|
|
|
|
if suggest_astype:
|
|
|
|
fname_context += ("If trying to convert the data type of a value, "
|
|
|
|
f"try using `x.astype({fun.__name__})` "
|
|
|
|
f"or `jnp.array(x, {fun.__name__})` instead.")
|
2020-04-22 10:25:06 +03:00
|
|
|
def error(self, arg):
|
2021-03-02 09:29:59 -08:00
|
|
|
raise ConcretizationTypeError(arg, fname_context)
|
2020-03-09 09:14:23 +00:00
|
|
|
return error
|
|
|
|
|
2020-07-03 20:54:25 -07:00
|
|
|
def concrete_or_error(force: Any, val: Any, context=""):
|
|
|
|
"""Like force(val), but gives the context in the error message."""
|
2020-09-25 14:18:46 -07:00
|
|
|
if force is None:
|
|
|
|
force = lambda x: x
|
2020-04-22 10:25:06 +03:00
|
|
|
if isinstance(val, Tracer):
|
|
|
|
if isinstance(val.aval, ConcreteArray):
|
2020-07-03 20:54:25 -07:00
|
|
|
return force(val.aval.val)
|
2020-04-22 10:25:06 +03:00
|
|
|
else:
|
2021-03-02 09:29:59 -08:00
|
|
|
raise ConcretizationTypeError(val, context)
|
2020-04-22 10:25:06 +03:00
|
|
|
else:
|
2020-07-03 20:54:25 -07:00
|
|
|
return force(val)
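# Example (illustrative): callers that need a Python-level value can guard
# with concrete_or_error, in the style of
#
#   axis = concrete_or_error(operator.index, axis,
#                            "axis argument of jnp.squeeze()")
#
# which returns operator.index(axis) for concrete values and raises a
# ConcretizationTypeError carrying the given context for abstract tracers.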
|
2020-04-22 10:25:06 +03:00
|
|
|
|
2021-03-21 13:39:57 -07:00
|
|
|
convert_element_type_p = Primitive('convert_element_type')
|
|
|
|
|
2021-09-24 22:08:42 -04:00
|
|
|
|
|
|
|
def _short_dtype_name(dtype):
|
|
|
|
return (dtype.name.replace('float', 'f').replace('uint', 'u')
|
|
|
|
.replace('int', 'i').replace('complex', 'c'))
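# For example, _short_dtype_name(np.dtype('float32')) == 'f32',
# _short_dtype_name(np.dtype('uint8')) == 'u8', and
# _short_dtype_name(np.dtype('complex64')) == 'c64'.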
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
class UnshapedArray(AbstractValue):
|
|
|
|
__slots__ = ['dtype', 'weak_type']
|
|
|
|
array_abstraction_level = 2
|
|
|
|
|
|
|
|
def __init__(self, dtype, weak_type=False):
|
2020-07-14 13:05:31 -07:00
|
|
|
self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
|
2020-03-09 09:14:23 +00:00
|
|
|
self.weak_type = weak_type
|
|
|
|
|
2021-01-27 15:13:30 -08:00
|
|
|
def update(self, dtype=None, weak_type=None):
|
|
|
|
if dtype is None:
|
|
|
|
dtype = self.dtype
|
|
|
|
if weak_type is None:
|
|
|
|
weak_type = self.weak_type
|
|
|
|
return UnshapedArray(dtype, weak_type)
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
def __eq__(self, other):
|
|
|
|
return (type(self) is type(other) and self.dtype == other.dtype and
|
|
|
|
self.weak_type == other.weak_type)
|
|
|
|
|
|
|
|
def __ne__(self, other):
|
|
|
|
return not self == other
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
|
2020-07-14 13:05:31 -07:00
|
|
|
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
|
2020-03-09 09:14:23 +00:00
|
|
|
# the unique character code via hash(self.dtype.char)
|
|
|
|
return hash((self.dtype, self.weak_type))
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
return '{}({}{})'.format(self.__class__.__name__, self.str_short(),
|
|
|
|
", weak_type=True" if self.weak_type else "")
|
|
|
|
|
|
|
|
_bool = _nonzero = concretization_function_error(bool)
|
2020-09-15 08:06:46 -07:00
|
|
|
_float = concretization_function_error(float, True)
|
|
|
|
_int = concretization_function_error(int, True)
|
|
|
|
_complex = concretization_function_error(complex, True)
|
2020-03-09 09:14:23 +00:00
|
|
|
_hex = concretization_function_error(hex)
|
|
|
|
_oct = concretization_function_error(oct)
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def at_least_vspace(self) -> AbstractValue:
|
2020-09-24 16:29:57 +01:00
|
|
|
return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
|
|
|
|
self.weak_type)
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def join(self, other):
|
|
|
|
if self.dtype == other.dtype:
|
|
|
|
if self.weak_type == other.weak_type:
|
|
|
|
return self
|
|
|
|
else:
|
|
|
|
return UnshapedArray(self.dtype, weak_type=False)
|
|
|
|
else:
|
|
|
|
raise TypeError(self, other)
|
|
|
|
|
2021-09-24 22:08:42 -04:00
|
|
|
def str_short(self, short_dtypes=False) -> str:
|
|
|
|
return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2021-01-27 15:13:30 -08:00
|
|
|
def strip_weak_type(self):
|
2020-03-09 09:14:23 +00:00
|
|
|
"""Returns a copy of the aval with weak_type=False."""
|
2021-01-27 15:13:30 -08:00
|
|
|
return self.update(weak_type=False)
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2020-03-24 20:43:33 -07:00
|
|
|
@property
|
|
|
|
def shape(self):
|
|
|
|
msg = ("UnshapedArray has no shape. Please open an issue at "
|
|
|
|
"https://github.com/google/jax/issues because it's unexpected for "
|
|
|
|
"UnshapedArray instances to ever be produced.")
|
|
|
|
raise TypeError(msg)
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
class ShapedArray(UnshapedArray):
|
2021-02-05 10:39:22 -08:00
|
|
|
__slots__ = ['shape', 'named_shape']
|
2020-03-09 09:14:23 +00:00
|
|
|
array_abstraction_level = 1
|
|
|
|
|
2021-02-05 10:39:22 -08:00
|
|
|
def __init__(self, shape, dtype, weak_type=False, named_shape={}):
|
2021-08-05 13:11:07 -07:00
|
|
|
super().__init__(dtype, weak_type=weak_type)
|
2020-03-09 09:14:23 +00:00
|
|
|
self.shape = canonicalize_shape(shape)
|
2021-04-14 15:29:45 +00:00
|
|
|
self.named_shape = dict(named_shape)
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2021-02-05 10:39:22 -08:00
|
|
|
def update(self, shape=None, dtype=None, weak_type=None, named_shape=None):
|
2021-01-27 15:13:30 -08:00
|
|
|
if shape is None:
|
|
|
|
shape = self.shape
|
|
|
|
if dtype is None:
|
|
|
|
dtype = self.dtype
|
|
|
|
if weak_type is None:
|
|
|
|
weak_type = self.weak_type
|
2021-02-05 10:39:22 -08:00
|
|
|
if named_shape is None:
|
|
|
|
named_shape = self.named_shape
|
|
|
|
return ShapedArray(shape, dtype, weak_type, named_shape)
|
2021-01-27 15:13:30 -08:00
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
ndim = property(lambda self: len(self.shape))
|
|
|
|
size = property(lambda self: prod(self.shape))
|
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
broadcast: ClassVar[Optional[aval_method]] = None
|
|
|
|
transpose: ClassVar[Optional[aval_method]] = None
|
|
|
|
reshape: ClassVar[Optional[aval_method]] = None
|
|
|
|
_iter: ClassVar[Optional[staticmethod]] = None
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
def __eq__(self, other):
|
|
|
|
return (type(self) is type(other)
|
|
|
|
and self.dtype == other.dtype and self.shape == other.shape
|
2021-02-05 10:39:22 -08:00
|
|
|
and self.weak_type == other.weak_type
|
|
|
|
and self.named_shape == other.named_shape)
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
|
2020-07-14 13:05:31 -07:00
|
|
|
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
|
2020-03-09 09:14:23 +00:00
|
|
|
# the unique character code via hash(self.dtype.char)
|
2021-02-05 10:39:22 -08:00
|
|
|
return hash((self.shape, self.dtype, self.weak_type,
|
|
|
|
tuple(self.named_shape.items())))
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def at_least_vspace(self):
|
2020-09-24 16:29:57 +01:00
|
|
|
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
2021-02-05 10:39:22 -08:00
|
|
|
self.weak_type, self.named_shape)
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def join(self, other):
|
2021-06-23 10:52:03 +02:00
|
|
|
if symbolic_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
|
2021-02-05 10:39:22 -08:00
|
|
|
weak_type = self.weak_type and other.weak_type
|
|
|
|
named_shape = join_named_shapes(self.named_shape, other.named_shape)
|
|
|
|
return self.update(weak_type=weak_type, named_shape=named_shape)
|
2020-03-09 09:14:23 +00:00
|
|
|
elif self.dtype == other.dtype:
|
|
|
|
return UnshapedArray(self.dtype)
|
|
|
|
else:
|
|
|
|
raise TypeError(self, other)
|
|
|
|
|
2021-09-24 22:08:42 -04:00
|
|
|
def str_short(self, short_dtypes=False):
|
|
|
|
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
2020-03-09 09:14:23 +00:00
|
|
|
shapestr = ','.join(map(str, self.shape))
|
2021-02-05 10:39:22 -08:00
|
|
|
if self.named_shape:
|
|
|
|
named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items())
|
2021-09-24 22:08:42 -04:00
|
|
|
return f'{dt_str}[{shapestr};{named_shapestr}]'
|
2021-02-05 10:39:22 -08:00
|
|
|
else:
|
2021-09-24 22:08:42 -04:00
|
|
|
return f'{dt_str}[{shapestr}]'
|
2021-02-05 10:39:22 -08:00
|
|
|
|
|
|
|
def strip_named_shape(self):
|
|
|
|
return self.update(named_shape={})
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
try:
|
|
|
|
return self.shape[0]
|
2020-09-30 01:20:00 +09:00
|
|
|
except IndexError as err:
|
|
|
|
raise TypeError("len() of unsized object") from err # same as numpy error
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def _len(self, ignored_tracer):
|
|
|
|
return len(self)
|
|
|
|
|
|
|
|
|
|
|
|
def _forward_to_value(self, fun, ignored_tracer, *args):
|
|
|
|
return fun(self.val, *args)
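# Example (illustrative) of working with a ShapedArray aval:
#
#   aval = ShapedArray((2, 3), np.float32)
#   aval.str_short()                    # -> 'float32[2,3]'
#   aval.str_short(short_dtypes=True)   # -> 'f32[2,3]'
#   aval.update(shape=(6,))             # -> ShapedArray(float32[6])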
|
|
|
|
|
|
|
|
class ConcreteArray(ShapedArray):
|
|
|
|
__slots__ = ['val']
|
|
|
|
array_abstraction_level = 0
|
|
|
|
|
|
|
|
def __init__(self, val, weak_type=False):
|
2021-08-05 13:11:07 -07:00
|
|
|
super().__init__(np.shape(val), np.result_type(val),
|
|
|
|
weak_type=weak_type)
|
2020-03-09 09:14:23 +00:00
|
|
|
# Note: canonicalized self.dtype doesn't necessarily match self.val
|
|
|
|
self.val = val
|
2020-09-24 16:29:57 +01:00
|
|
|
assert self.dtype != np.dtype('O'), val
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2021-01-27 15:13:30 -08:00
|
|
|
def update(self, val=None, weak_type=None):
|
|
|
|
if val is None:
|
|
|
|
val = self.val
|
|
|
|
if weak_type is None:
|
|
|
|
weak_type = self.weak_type
|
|
|
|
return ConcreteArray(val, weak_type)
|
|
|
|
|
2020-03-09 09:14:23 +00:00
|
|
|
def __eq__(self, other):
|
2020-10-16 18:21:01 -07:00
|
|
|
if (type(self) is type(other) and self.dtype == other.dtype
|
|
|
|
and self.shape == other.shape and self.weak_type == other.weak_type):
|
|
|
|
with eval_context(): # in case self.val is a DeviceArray
|
|
|
|
return (self.val == other.val).all()
|
|
|
|
else:
|
|
|
|
return False
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
return id(self.val)
|
|
|
|
|
2021-02-05 10:39:22 -08:00
|
|
|
def join(self, other) -> AbstractValue:
|
2020-03-09 09:14:23 +00:00
|
|
|
if self == other:
|
|
|
|
return self
|
|
|
|
elif self.shape == other.shape and self.dtype == other.dtype:
|
2021-02-05 10:39:22 -08:00
|
|
|
weak_type = self.weak_type and other.weak_type
|
2021-04-12 12:49:35 +00:00
|
|
|
named_shape = join_named_shapes(self.named_shape, other.named_shape)
|
2021-02-05 10:39:22 -08:00
|
|
|
return ShapedArray(
|
|
|
|
self.shape, self.dtype, weak_type=weak_type, named_shape=named_shape)
|
2020-03-09 09:14:23 +00:00
|
|
|
elif self.dtype == other.dtype:
|
|
|
|
return UnshapedArray(self.dtype,
|
|
|
|
weak_type=self.weak_type and other.weak_type)
|
|
|
|
else:
|
|
|
|
raise TypeError(self, other)
|
|
|
|
|
2021-09-24 22:08:42 -04:00
|
|
|
def str_short(self, short_dtypes=False) -> str:
|
|
|
|
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
|
|
|
return f'{self.val}, dtype={dt_str}'
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
_bool = _nonzero = partialmethod(_forward_to_value, bool)
|
2020-09-15 08:06:46 -07:00
|
|
|
_int = partialmethod(_forward_to_value, int)
|
|
|
|
_hex = partialmethod(_forward_to_value, hex)
|
|
|
|
_oct = partialmethod(_forward_to_value, oct)
|
|
|
|
|
|
|
|
_float = concretization_function_error(float, True)
|
|
|
|
_complex = concretization_function_error(complex, True)
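# Example (illustrative): a ConcreteArray wraps an actual value, so Python
# coercions that need the value (bool, int, hex, oct) succeed during tracing,
# and joining two distinct concrete values forgets the values but keeps
# shape/dtype:
#
#   aval = ConcreteArray(np.float32(3.0))
#   aval.str_short()                           # -> '3.0, dtype=float32'
#   aval.join(ConcreteArray(np.float32(4.0)))  # -> ShapedArray(float32[])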
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2020-09-24 16:29:57 +01:00
|
|
|
def primal_dtype_to_tangent_dtype(primal_dtype):
|
|
|
|
if not dtypes.issubdtype(primal_dtype, np.inexact):
|
|
|
|
return dtypes.float0
|
|
|
|
else:
|
|
|
|
return primal_dtype
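# For example (illustrative):
#   primal_dtype_to_tangent_dtype(np.dtype('float32'))  # -> float32
#   primal_dtype_to_tangent_dtype(np.dtype('int32'))    # -> dtypes.float0
# Integer and boolean primals get the special float0 tangent dtype, marking a
# trivial tangent space.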
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
class AbstractToken(AbstractValue):
|
|
|
|
def join(self, other):
|
|
|
|
if isinstance(other, AbstractToken):
|
|
|
|
return self
|
|
|
|
else:
|
|
|
|
assert False, f"Cannot join {self} with {other}"
|
2021-09-24 22:08:42 -04:00
|
|
|
def str_short(self, short_dtypes=False): return 'Tok'
|
2021-01-22 10:57:33 -05:00
|
|
|
def at_least_vspace(self): return self
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2021-01-19 18:38:53 -08:00
|
|
|
abstract_token: AbstractToken = AbstractToken()
|
2020-03-09 09:14:23 +00:00
|
|
|
|
|
|
|
|
2020-10-07 11:41:22 -07:00
|
|
|
def raise_to_shaped(aval: AbstractValue, weak_type=None):
|
|
|
|
if weak_type is None:
|
|
|
|
weak_type = getattr(aval, 'weak_type', False)
|
2020-08-14 11:51:19 -07:00
|
|
|
for typ in type(aval).mro():
|
|
|
|
handler = raise_to_shaped_mappings.get(typ)
|
|
|
|
if handler: return handler(aval, weak_type)
|
|
|
|
raise TypeError(type(aval))
|
|
|
|
|
|
|
|
raise_to_shaped_mappings : Dict[type, Callable] = {
|
|
|
|
AbstractUnit: lambda aval, _: aval,
|
|
|
|
AbstractToken: lambda aval, _: aval,
|
2021-02-05 10:39:22 -08:00
|
|
|
Bot: lambda aval, _: aval,
|
|
|
|
UnshapedArray: lambda aval, _: aval,
|
|
|
|
ShapedArray: lambda aval, weak_type: ShapedArray(
|
|
|
|
aval.shape, aval.dtype, weak_type, aval.named_shape)
|
2020-08-14 11:51:19 -07:00
|
|
|
}
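# Example (illustrative): raising a ConcreteArray forgets the value but keeps
# shape and dtype, since the handler lookup walks the mro up to ShapedArray:
#
#   raise_to_shaped(ConcreteArray(np.arange(3, dtype=np.float32)))
#   # -> ShapedArray(float32[3])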
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
### Operations on shapes and dimension sizes.
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2021-04-01 15:37:01 +03:00
|
|
|
# Shapes are tuples of dimension sizes, which are normally integers. We allow
|
|
|
|
# modules to extend the set of dimension sizes to contain other types, e.g.,
|
2021-04-05 16:37:35 +03:00
|
|
|
# symbolic dimensions in jax2tf.shape_poly.DimVar and masking.Poly.
|
|
|
|
DimSize = Union[int, Any] # extensible
|
2021-04-01 15:37:01 +03:00
|
|
|
Shape = Sequence[DimSize]
|
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
|
|
|
|
class InconclusiveDimensionOperation(Exception):
|
|
|
|
"""Raised when we cannot conclusively compute with symbolic dimensions."""
|
|
|
|
pass
|
|
|
|
|
2021-04-01 15:37:01 +03:00
|
|
|
class DimensionHandler:
|
2021-04-05 16:37:35 +03:00
|
|
|
"""Operations on dimension sizes.
|
2021-04-01 15:37:01 +03:00
|
|
|
|
|
|
|
Dimension sizes are normally integer constants, but can also be symbolic,
|
|
|
|
e.g., masking.Poly or jax2tf.shape_poly.DimVar.
|
|
|
|
|
2021-04-09 13:46:28 +03:00
|
|
|
The base class works for integers only. Subclasses are invoked when at
|
|
|
|
least one of the operands has a type registered in _SPECIAL_DIMENSION_HANDLERS.
|
|
|
|
In that case, all operands are guaranteed to be either the special dimension
|
|
|
|
type, or Python integer scalars.
|
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
Subclasses should raise InconclusiveDimensionOperation if the result cannot
|
|
|
|
be computed in some contexts.
|
2021-04-01 15:37:01 +03:00
|
|
|
"""
|
2021-04-08 17:45:14 +03:00
|
|
|
def is_constant(self, d: DimSize) -> bool:
|
|
|
|
"""The dimension is a constant."""
|
|
|
|
return True
|
|
|
|
|
2021-04-01 15:37:01 +03:00
|
|
|
def symbolic_equal(self, d1: DimSize, d2: DimSize) -> bool:
|
2021-04-05 16:37:35 +03:00
|
|
|
"""True iff the dimension sizes are equal in all contexts; False otherwise.
|
|
|
|
Unlike `d1 == d2` this never raises InconclusiveDimensionOperation.
|
|
|
|
"""
|
2021-04-01 15:37:01 +03:00
|
|
|
return d1 == d2
|
|
|
|
|
|
|
|
def greater_equal(self, d1: DimSize, d2: DimSize) -> bool:
|
2021-04-05 16:37:35 +03:00
|
|
|
"""Computes `d1 >= d2`.
|
|
|
|
Raise InconclusiveDimensionOperation if the result is different in
|
|
|
|
different contexts.
|
|
|
|
"""
|
2021-04-01 15:37:01 +03:00
|
|
|
return d1 >= d2
|
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def sum(self, *ds: DimSize) -> DimSize:
|
|
|
|
"""Sum of dimensions.
|
|
|
|
Raises InconclusiveDimensionOperation if the result cannot be represented
|
|
|
|
by the same DimSize in all contexts.
|
|
|
|
"""
|
|
|
|
return sum(ds)
|
|
|
|
|
|
|
|
def diff(self, d1: DimSize, d2: DimSize) -> DimSize:
|
|
|
|
"""Difference of dimensions.
|
|
|
|
Raises InconclusiveDimensionOperation if the result cannot be represented
|
|
|
|
by the same DimSize in all contexts.
|
|
|
|
"""
|
|
|
|
return d1 - d2
|
|
|
|
|
[jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition, subtraction,
multiplication. They also partially support equality and inequality checking.
Equality and inequality are supported only when the operation result is the
same for all valuations of variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a = b`, `a
>= 2`.
Division is supported only in the cases when either there is no remainder,
or the divisor is a constant.
This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:
```
y = x.reshape((2, -1))
z = ... y ...
return z.reshape(x.shape)
```
2021-05-20 14:07:52 +03:00
|
|
|
def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
|
|
|
|
"""Computes integer "i" such that i * size(s2) == size(s1).
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-05-20 14:07:52 +03:00
|
|
|
Raise InconclusiveDimensionOperation if there is no such integer for all
|
|
|
|
contexts.
|
2021-04-01 15:37:01 +03:00
|
|
|
"""
|
2021-04-08 11:08:45 -04:00
|
|
|
sz1 = int(np.prod(s1))
|
|
|
|
sz2 = int(np.prod(s2))
|
2021-04-01 15:37:01 +03:00
|
|
|
if sz1 == 0 and sz2 == 0:
|
|
|
|
return 1
|
|
|
|
if sz1 % sz2:
|
2021-04-05 16:37:35 +03:00
|
|
|
raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
|
2021-04-01 15:37:01 +03:00
|
|
|
return sz1 // sz2
|
|
|
|
|
2021-04-04 16:23:24 +03:00
|
|
|
def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
|
|
|
|
"""(d - window_size) // window_stride + 1"""
|
|
|
|
return (d - window_size) // window_stride + 1
|
|
|
|
|
|
|
|
def dilate(self, d: DimSize, dilation: int) -> DimSize:
|
|
|
|
"""Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
|
|
|
|
return 0 if d == 0 else 1 + dilation * (d - 1)
|
|
|
|
|
[jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimension needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.
This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,
```
def average(x):
return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```
This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.
Note that this does not change fundamentally the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:
```
def dim_as_value(d):
jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```
We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll-up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
2021-07-16 20:01:22 +03:00
|
|
|
def as_value(self, d: DimSize):
|
|
|
|
"""Turns a dimension size into a JAX value that we can compute with."""
|
|
|
|
return d
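# Example (integer dimensions, handled by the base class): for d=10,
# window_size=3, window_stride=2 the base handler gives
# stride(10, 3, 2) == (10 - 3) // 2 + 1 == 4, and
# dilate(5, 2) == 1 + 2 * (5 - 1) == 9.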
|
2021-04-01 15:37:01 +03:00
|
|
|
|
|
|
|
_dimension_handler_int = DimensionHandler()
|
|
|
|
_SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}
|
|
|
|
|
2021-04-09 13:46:28 +03:00
|
|
|
def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple[DimSize, ...]]:
|
|
|
|
"""Finds the handler for the given dimensions; also returns the canonical dimensions.
|
|
|
|
|
|
|
|
A dimension is canonical if it is a Python integer scalar, or has a type
|
|
|
|
registered in _SPECIAL_DIMENSION_HANDLERS.
|
2021-04-05 16:37:35 +03:00
|
|
|
"""
|
2021-04-01 15:37:01 +03:00
|
|
|
special_handlers = set()
|
2021-04-09 13:46:28 +03:00
|
|
|
canonical = []
|
2021-04-01 15:37:01 +03:00
|
|
|
for d in dlist:
|
|
|
|
handler = _SPECIAL_DIMENSION_HANDLERS.get(type(d))
|
|
|
|
if handler:
|
|
|
|
special_handlers.add(handler)
|
2021-04-09 13:46:28 +03:00
|
|
|
canonical.append(d)
|
|
|
|
else:
|
|
|
|
try:
|
|
|
|
canonical.append(operator.index(d))
|
|
|
|
except TypeError:
|
|
|
|
raise _invalid_shape_error(dlist)
|
2021-04-08 06:21:12 -07:00
|
|
|
|
2021-04-09 13:46:28 +03:00
|
|
|
if len(special_handlers) > 1:
|
|
|
|
msg = (f"Dimension size operation involves multiple special dimension types {dlist}")
|
|
|
|
raise ValueError(msg)
|
|
|
|
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-08 17:45:14 +03:00
|
|
|
def is_constant_dim(d: DimSize) -> bool:
|
2021-04-12 05:50:19 -07:00
|
|
|
handler, ds = _dim_handler_and_canonical(d)
|
|
|
|
return handler.is_constant(*ds)
|
2021-04-08 17:45:14 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def symbolic_equal_dim(d1: DimSize, d2: DimSize) -> bool:
|
2021-04-09 13:46:28 +03:00
|
|
|
handler, ds = _dim_handler_and_canonical(d1, d2)
|
|
|
|
return handler.symbolic_equal(*ds)
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def symbolic_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
|
2021-04-09 13:46:28 +03:00
|
|
|
handler, ds = _dim_handler_and_canonical(d1, *dlist)
|
|
|
|
return any([handler.symbolic_equal(ds[0], d) for d in ds[1:]])
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
|
2021-04-01 15:37:01 +03:00
|
|
|
return (len(s1) == len(s2) and
|
2021-04-06 11:43:06 +03:00
|
|
|
all(map(symbolic_equal_dim, s1, s2)))
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
|
2021-04-09 13:46:28 +03:00
|
|
|
handler, ds = _dim_handler_and_canonical(d1, d2)
|
|
|
|
return handler.greater_equal(*ds)
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def greater_equal_shape(s1: Shape, s2: Shape) -> bool:
|
2021-04-06 11:43:06 +03:00
|
|
|
return all(map(greater_equal_dim, s1, s2))
|
2021-04-04 17:05:18 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def sum_dim(*ds: DimSize) -> DimSize:
|
2021-04-09 13:46:28 +03:00
|
|
|
handler, ds = _dim_handler_and_canonical(*ds)
|
|
|
|
return handler.sum(*ds)
|
2021-04-04 17:05:18 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def sum_shapes(*ss: Shape) -> Shape:
|
2021-04-06 11:43:06 +03:00
|
|
|
return tuple(map(sum_dim, *ss))
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def diff_dim(d1: DimSize, d2: DimSize) -> DimSize:
|
2021-04-09 13:46:28 +03:00
|
|
|
handler, ds = _dim_handler_and_canonical(d1, d2)
|
|
|
|
return handler.diff(*ds)
|
2021-04-05 16:37:35 +03:00
|
|
|
|
|
|
|
def diff_shape(s1: Shape, s2: Shape) -> Shape:
|
2021-04-06 11:43:06 +03:00
|
|
|
return tuple(map(diff_dim, s1, s2))
|
2021-04-05 16:37:35 +03:00
|
|
|
|
2021-05-20 14:07:52 +03:00
|
|
|
def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
|
|
|
|
"""Returns an integer "i" s.t., i * size(s2) == size(s1).
|
|
|
|
Raises if there is no such integer."""
|
2021-04-09 13:46:28 +03:00
|
|
|
s1 = s1 or (1,)
|
|
|
|
s2 = s2 or (1,)
|
|
|
|
handler, ds = _dim_handler_and_canonical(*s1, *s2)
|
|
|
|
return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def same_shape_sizes(s1: Shape, s2: Shape) -> bool:
|
|
|
|
return 1 == divide_shape_sizes(s1, s2)
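# Example (illustrative, integer dimensions only):
#   symbolic_equal_shape((2, 3), (2, 3))   # -> True
#   greater_equal_dim(3, 2)                # -> True
#   divide_shape_sizes((2, 6), (3,))       # -> 4, since 12 == 4 * 3
#   divide_shape_sizes((2, 5), (3,))       # raises InconclusiveDimensionOperation
# With symbolic dimension types registered in _SPECIAL_DIMENSION_HANDLERS the
# same calls dispatch to the corresponding DimensionHandler subclass.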
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2021-04-09 14:02:44 +03:00
|
|
|
def is_empty_shape(s: Shape) -> bool:
|
|
|
|
return any(symbolic_equal_dim(d, 0) for d in s)
|
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
|
2021-04-04 16:23:24 +03:00
|
|
|
"""Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
|
2021-04-09 13:46:28 +03:00
|
|
|
handler, ds = _dim_handler_and_canonical(d, dilation)
|
|
|
|
return handler.dilate(*ds)
|
2021-04-04 16:23:24 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def dilate_shape(s: Shape, dilations: Sequence[int]) -> Shape:
|
2021-04-06 11:43:06 +03:00
|
|
|
return tuple(map(dilate_dim, s, dilations))
|
2021-04-04 16:23:24 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
|
2021-04-09 13:46:28 +03:00
|
|
|
handler, ds = _dim_handler_and_canonical(d, window_size, window_stride)
|
|
|
|
return handler.stride(*ds)
|
2021-04-04 16:23:24 +03:00
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
|
2021-04-04 16:23:24 +03:00
|
|
|
"""(s - window_size) // window_stride + 1"""
|
2021-04-06 11:43:06 +03:00
|
|
|
return tuple(map(stride_dim, s, window_size, window_stride))
|
2021-04-04 16:23:24 +03:00
|
|
|
|
2021-07-16 20:01:22 +03:00
|
|
|
def dimension_as_value(d: DimSize):
|
|
|
|
"""Turns a dimension size into a JAX value that we can compute with.
|
|
|
|
This is the identity function for constant dimensions."""
|
|
|
|
handler, ds = _dim_handler_and_canonical(d)
|
|
|
|
return handler.as_value(*ds)
|
2021-04-01 15:37:01 +03:00
|
|
|
|
|
|
|
def _canonicalize_dimension(dim: DimSize) -> DimSize:
|
|
|
|
if type(dim) in _SPECIAL_DIMENSION_HANDLERS:
|
2020-03-09 09:14:23 +00:00
|
|
|
return dim
|
|
|
|
else:
|
|
|
|
return operator.index(dim)
|
|
|
|
|
2021-08-03 09:12:04 +03:00
|
|
|
def canonicalize_shape(shape: Shape, context: str="") -> Shape:
|
2020-03-09 09:14:23 +00:00
|
|
|
"""Canonicalizes and checks for errors in a user-provided shape value.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
shape: a Python value that represents a shape.
|
|
|
|
|
|
|
|
Returns:
|
2021-08-03 09:12:04 +03:00
|
|
|
A tuple of canonical dimension values.
|
2020-03-09 09:14:23 +00:00
|
|
|
"""
|
|
|
|
try:
|
|
|
|
return tuple(map(_canonicalize_dimension, shape))
|
|
|
|
except TypeError:
|
|
|
|
pass
|
2021-08-03 09:12:04 +03:00
|
|
|
raise _invalid_shape_error(shape, context)
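# For example, canonicalize_shape((2, np.int64(3))) == (2, 3), with numpy
# integers converted via operator.index, while canonicalize_shape((2.0, 3))
# raises the TypeError built by _invalid_shape_error below.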
|
2021-04-09 13:46:28 +03:00
|
|
|
|
2021-08-03 09:12:04 +03:00
|
|
|
def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
|
|
|
|
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
d: a Python value that represents a dimension.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A canonical dimension value.
|
|
|
|
"""
|
|
|
|
return canonicalize_shape((d,), context)[0]
|
|
|
|
|
|
|
|
def _invalid_shape_error(shape: Shape, context: str=""):
|
2020-03-09 09:14:23 +00:00
|
|
|
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
|
2021-08-03 09:12:04 +03:00
|
|
|
f"got {shape}.")
|
|
|
|
if context:
|
|
|
|
msg += f" {context}."
|
2020-03-09 09:14:23 +00:00
|
|
|
if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
|
|
|
|
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
|
|
|
|
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
|
|
|
|
"smaller subfunctions.")
|
2021-08-03 09:12:04 +03:00
|
|
|
return TypeError(msg)
|
2020-03-09 09:14:23 +00:00
|
|
|
|
2021-02-04 12:38:12 +00:00
|
|
|
# ------------------- Named shapes -------------------
|
|
|
|
|
|
|
|
|
|
|
|
class NamedShape:
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
self.__positional = canonicalize_shape(args)
|
|
|
|
# TODO: Assert that kwargs match axis env?
|
|
|
|
self.__named = dict(kwargs)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def rank(self):
|
|
|
|
return len(self.__positional) + len(self.__named)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def positional_rank(self):
|
|
|
|
return len(self.__positional)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def named_rank(self):
|
|
|
|
return len(self.__named)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def positional(self):
|
|
|
|
return self.__positional
|
|
|
|
|
|
|
|
@property
|
|
|
|
def names(self):
|
|
|
|
return self.__named.keys()
|
|
|
|
|
|
|
|
@property
|
|
|
|
def named_sizes(self):
|
2021-04-16 14:20:25 +01:00
|
|
|
return self.__named.values()
|
2021-02-04 12:38:12 +00:00
|
|
|
|
|
|
|
@property
|
|
|
|
def named_items(self):
|
|
|
|
return self.__named.items()
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
try:
|
|
|
|
idx = operator.index(idx)
|
|
|
|
return self.__positional[idx]
|
|
|
|
except TypeError:
|
|
|
|
pass
|
|
|
|
return self.__named[idx]
|
|
|
|
|
|
|
|
@property
|
|
|
|
def total(self):
|
|
|
|
total = 1
|
|
|
|
for s in self.__positional: total *= s
|
|
|
|
for s in self.__named.values(): total *= s
|
|
|
|
return total
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
|
|
|
|
f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
if isinstance(other, NamedShape):
|
|
|
|
return (self.__positional, self.__named) == (other.__positional, other.__named)
|
|
|
|
if isinstance(other, tuple):
|
|
|
|
return not self.__named and self.__positional == other
|
|
|
|
raise TypeError(f"NamedShape doesn't support comparisons with {type(other)}")
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
return hash((self.__positional, tuple(self.__named.items())))
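# Example (illustrative):
#   s = NamedShape(2, 3, i=8)
#   s.rank        # -> 3 (two positional axes plus one named axis)
#   s.positional  # -> (2, 3)
#   s.total       # -> 48
#   s['i']        # -> 8
#   str(s)        # -> '(2, 3, i=8)'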
|
|
|
|
|
2021-04-05 16:37:35 +03:00
|
|
|
def join_named_shapes(*named_shapes):
|
2021-04-14 15:29:45 +00:00
|
|
|
result = {}
|
|
|
|
for named_shape in named_shapes:
|
|
|
|
for name, size in named_shape.items():
|
|
|
|
if result.setdefault(name, size) != size:
|
|
|
|
raise TypeError(
|
|
|
|
f"Axis name {name} used with inconsistent sizes: {result[name]} != {size}")
|
|
|
|
return result
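# Example (illustrative):
#   join_named_shapes({'i': 8}, {'j': 4})  # -> {'i': 8, 'j': 4}
#   join_named_shapes({'i': 8}, {'i': 4})  # raises TypeError (inconsistent sizes)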
|
2021-04-05 16:37:35 +03:00
|
|
|
|
2021-02-04 12:38:12 +00:00
|
|
|
# TODO: Make canonicalize_shape return named shapes?
|
|
|
|
def as_named_shape(shape) -> NamedShape:
|
|
|
|
if isinstance(shape, NamedShape):
|
|
|
|
return shape
|
|
|
|
return NamedShape(*shape)
|
|
|
|
|
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
# ------------------- Call -------------------
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def apply_todos(todos, outs):
|
2020-01-05 04:32:48 +01:00
|
|
|
todos_list = list(todos)
|
|
|
|
while todos_list:
|
|
|
|
outs = map(full_lower, todos_list.pop()(outs))
|
2019-07-27 15:46:14 -07:00
|
|
|
return outs
|
2018-11-17 18:03:33 -08:00
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
class _IgnoreElemList(list):
|
|
|
|
"""Compares equal to all other _ignore_elem_lists."""
|
|
|
|
def __hash__(self): return 0
|
|
|
|
def __eq__(self, other):
|
|
|
|
return type(other) is _IgnoreElemList
|
|
|
|
|
2019-12-06 22:28:41 -08:00
|
|
|
@lu.transformation_with_aux
|
2020-06-23 09:39:45 -07:00
|
|
|
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
2020-11-09 17:23:16 +00:00
|
|
|
level: int, params_tuple: tuple, out_axes_transforms, *args):
|
2019-07-27 15:46:14 -07:00
|
|
|
outs = yield args, {}
|
|
|
|
params = dict(params_tuple)
|
2018-11-17 18:03:33 -08:00
|
|
|
todo = []
|
2020-11-09 17:23:16 +00:00
|
|
|
assert not out_axes_transforms
|
2019-07-27 15:46:14 -07:00
|
|
|
while True:
|
2020-07-30 12:59:36 -07:00
|
|
|
tracers = [x for x in outs if isinstance(x, Tracer)
|
|
|
|
and (level is None or x._trace.level > level)]
|
2019-07-27 15:46:14 -07:00
|
|
|
if tracers:
|
2020-01-29 16:23:27 -05:00
|
|
|
ans = max(tracers, key=lambda x: x._trace.level)
|
2019-07-27 15:46:14 -07:00
|
|
|
else:
|
|
|
|
break
|
2020-10-26 10:11:13 +00:00
|
|
|
trace = ans._trace.main.with_cur_sublevel()
|
2019-07-27 15:46:14 -07:00
|
|
|
outs = map(trace.full_raise, outs)
|
2020-06-23 09:39:45 -07:00
|
|
|
outs, cur_todo = primitive.post_process(trace, outs, params)
|
    if isinstance(primitive, MapPrimitive):
      cur_todo, out_axes_transform = cur_todo
      out_axes_transforms.append(out_axes_transform)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable


def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
              fun, *args, **params):
  out_axes_transforms = _IgnoreElemList()
  if primitive.map_primitive:
    out_axes_thunk = params['out_axes_thunk']
    # The new thunk depends deterministically on the old thunk and the wrapped
    # function. Any caching already has to include the wrapped function as part
    # of the key, so we only use the previous thunk for equality checks.
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      for t in out_axes_transforms:
        out_axes = t(out_axes)
      return out_axes
    params = dict(params, out_axes_thunk=new_out_axes_thunk)
  params_tuple = tuple(params.items())
  top_trace = find_top_trace(args)
  fun, env_trace_todo = process_env_traces(
      fun, primitive, top_trace and top_trace.level,
      params_tuple, out_axes_transforms)
  tracers = map(top_trace.full_raise, args)
  outs = primitive.process(top_trace, fun, tracers, params)
  return map(full_lower, apply_todos(env_trace_todo(), outs))

class CallPrimitive(Primitive):
  multiple_results = True
  call_primitive = True

  def bind(self, fun, *args, **params):
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    return trace.process_call(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_call(self, out_tracers, params)


def call_impl(f: lu.WrappedFun, *args, **params):
  del params  # params parameterize the call primitive, not the function
  with new_sublevel():
    return f.call_wrapped(*args)


call_p = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)

named_call_p = CallPrimitive('named_call')
named_call_p.def_impl(call_impl)
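As a usage sketch (illustrative only, not executed at import time): a
`lu.WrappedFun` bound through `call_p` must return a flat list of outputs,
and with no tracers in scope the bind dispatches straight to `call_impl`:
```py
def _example_call_usage():
  f = lu.wrap_init(lambda x: [x + 1.])  # wrapped funs return output lists
  return call(f, 2.)  # -> [3.0], via find_top_trace / process_call
```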
# ------------------- Map -------------------

def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
  handler, _ = aval_mapping_handlers.get(type(aval), (None, None))
  if handler is not None:
    return handler(size, axis, aval)
  else:
    raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")

def unmapped_aval(size: int, axis_name, axis: int, aval: AbstractValue) -> AbstractValue:
  _, handler = aval_mapping_handlers.get(type(aval), (None, None))
  if handler is not None:
    return handler(size, axis_name, axis, aval)
  else:
    raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")


def _map_unit(*_) -> AbstractUnit:
  return abstract_unit


def _map_shaped_array(size: int, axis: int, aval: ShapedArray) -> ShapedArray:
  assert aval.shape[axis] == size
  # TODO: Extend the named shape
  return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
                     named_shape=aval.named_shape)


def _unmap_shaped_array(size: int, axis_name, axis: int, aval: ShapedArray) -> ShapedArray:
  named_shape = dict(aval.named_shape)
  named_shape.pop(axis_name, None)  # TODO: Make this mandatory
  return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
                     named_shape=named_shape)


AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
    AbstractUnit: (_map_unit, _map_unit),
    ShapedArray: (_map_shaped_array, _unmap_shaped_array),
    ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
}
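For example (a sketch assuming the surrounding module context), mapping
removes the mapped positional axis from an aval and unmapping reinstates it:
```py
def _example_aval_mapping():
  aval = ShapedArray((8, 3), np.float32)
  mapped = mapped_aval(8, 0, aval)             # ShapedArray((3,), float32)
  restored = unmapped_aval(8, 'i', 0, mapped)  # ShapedArray((8, 3), float32)
  return mapped, restored
```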

class MapPrimitive(Primitive):
  multiple_results = True
  map_primitive = True

  def bind(self, fun, *args, **params):
    assert len(params['in_axes']) == len(args)
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    return trace.process_map(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_map(self, out_tracers, params)

@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
  frame = AxisEnvFrame(axis_name, size, tag)
  thread_local_state.trace_state.axis_env.append(frame)
  try:
    yield
  finally:
    thread_local_state.trace_state.axis_env.pop()


@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
  frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
  thread_local_state.trace_state.axis_env.extend(frames)
  try:
    yield
  finally:
    for _ in frames:
      thread_local_state.trace_state.axis_env.pop()
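A small usage sketch: axis frames are visible through `axis_frame` only
inside the context manager, and are popped on exit:
```py
def _example_extend_axis_env():
  with extend_axis_env_nd([('i', 4), ('j', 2)]):
    assert axis_frame('i').size == 4
  # both frames are popped here, so 'i' and 'j' are unbound again
```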

# When a mapped function is given no axis name, we generate a name object based
# on the id of the function object. Collisions aren't important because this
# name can't be used in collectives, as user code never gets a ref to this
# object. We don't want to use the function object itself because that might
# persist references to the function object.
# TODO(mattjj): revisit this unique axis name strategy
@total_ordering
class _TempAxisName:

  def __init__(self, obj):
    self.id = id(obj)

  def __repr__(self):
    return f'<axis {hex(self.id)}>'

  def __hash__(self):
    return hash(self.id)

  def __eq__(self, other):
    return type(other) is _TempAxisName and self.id == other.id

  def __lt__(self, other):
    return type(other) is _TempAxisName and self.id < other.id

def axis_frame(axis_name):
  frames = thread_local_state.trace_state.axis_env
  for frame in reversed(frames):
    if frame.name == axis_name:
      return frame
  named_axes = [frame.name for frame in reversed(frames)
                if not isinstance(frame.name, _TempAxisName)]
  raise NameError(
      f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
      f'by pmap) are available to collective operations: {named_axes}')

ParamDict = Dict[str, Any]
AxisSubst = Callable[[AxisName], Tuple[AxisName, ...]]

class NameGatheringSubst:
  def __init__(self):
    self.axis_names = set()
  def __call__(self, axis_name):
    self.axis_names.add(axis_name)
    return (axis_name,)

def used_axis_names(primitive: Primitive, params: ParamDict) -> Set[AxisName]:
  subst = NameGatheringSubst()
  subst_axis_names(primitive, params, subst)
  return subst.axis_names
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict:
  if primitive in axis_substitution_rules:
    return axis_substitution_rules[primitive](params, subst, traverse)
  if not traverse:
    return params
  # Default implementation: substitute names in all jaxpr parameters
  if isinstance(primitive, MapPrimitive):
    def shadowed_subst(name):
      return (name,) if name == params['axis_name'] else subst(name)
  else:
    shadowed_subst = subst
  jaxpr_params = [(n, v) for n, v in params.items() if isinstance(v, (Jaxpr, ClosedJaxpr))]
  if not jaxpr_params:
    return params
  new_params = dict(params)
  for name, jaxpr in jaxpr_params:
    new_params[name] = subst_axis_names_jaxpr(jaxpr, shadowed_subst)
  return new_params
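Concretely, the shadowing rule above means a map primitive's own binder is
never rewritten, while free axis names are (a minimal sketch of the same
logic, with a hypothetical renaming substitution):
```py
def _example_shadowed_subst():
  params = {'axis_name': 'i'}
  subst = lambda name: ('renamed_' + name,)
  shadowed = lambda name: (name,) if name == params['axis_name'] else subst(name)
  assert shadowed('i') == ('i',)           # the binder is left alone
  assert shadowed('j') == ('renamed_j',)   # free names are substituted
```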
class DuplicateAxisNameError(Exception):
  def __init__(self, var):
    self.var = var
    self.eqn = None

def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> Var:
  # Var identity is load-bearing, so we can't have duplicates!
  if v is unitvar: return v
  if v is dropvar: return v
  assert v not in var_map
  if not hasattr(v.aval, 'named_shape'):
    var_map[v] = v
    return v
  names = tuple(it.chain.from_iterable(subst(name) for name in v.aval.named_shape))
  named_shape = {name: axis_frame(name).size for name in names}
  if len(named_shape) != len(names):
    raise DuplicateAxisNameError(v)
  new_v = Var(v.count, v.suffix, v.aval.update(named_shape=named_shape))
  var_map[v] = new_v
  return new_v

def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: Dict[Var, Var]) -> JaxprEqn:
  invars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
  try:
    outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars]
  except DuplicateAxisNameError as e:
    e.eqn = eqn
    raise
  params = subst_axis_names(eqn.primitive, eqn.params, subst)
  return new_jaxpr_eqn(invars, outvars, eqn.primitive, params, eqn.source_info)
def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
  consts = None
  if isinstance(jaxpr, ClosedJaxpr):
    consts = jaxpr.consts
    jaxpr = jaxpr.jaxpr
  var_map: Dict[Var, Var] = {unitvar: unitvar}
  invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars]
  constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]
  eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]
  outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]
  new_jaxpr = Jaxpr(constvars, invars, outvars, eqns)
  if consts is not None:
    return ClosedJaxpr(new_jaxpr, consts)
  return new_jaxpr
@cache()
def used_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr]):
  subst = NameGatheringSubst()
  do_subst_axis_names_jaxpr(jaxpr, subst)
  return frozenset(subst.axis_names)

def subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
  if isinstance(subst, NameGatheringSubst):  # This is a common case, so we optimize it!
    subst.axis_names |= used_axis_names_jaxpr(jaxpr)
    return jaxpr
  return do_subst_axis_names_jaxpr(jaxpr, subst)

axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
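A usage sketch of the gathering substitution: calling it records a name and
returns it unchanged, which is what lets `used_axis_names_jaxpr` reuse the
substitution machinery purely for collection:
```py
def _example_name_gathering():
  subst = NameGatheringSubst()
  assert subst('i') == ('i',)
  assert subst.axis_names == {'i'}
```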
# ------------------- AxisPrimitive -------------------
# Primitives that store axis names in params and want those axis names to
# participate in dispatch should subclass AxisPrimitive.

class AxisPrimitive(Primitive):
  _dispatch_on_params = True


# ------------------- Jaxpr checking -------------------

def typecheck(aval: AbstractValue, x) -> bool:
  return typecompat(aval, get_aval(x))

def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
  """Determine whether `aval` conforms to `aval_ref`.

  Ignores weak_type and named_shape, other than to check that an axis name
  isn't used with different sizes.
  """
  try:
    return typematch(aval_ref, lattice_join(aval_ref, aval))
  except TypeError:
    return False
def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
  """Determine whether `aval1` and `aval2` are equivalent.

  Ignores weak_type and named_shape, other than to check that an axis name
  isn't used with different sizes.
  """
  if aval1 == aval2: return True
  # unequal avals may still represent the same type, because type is represented
  # by avals at the shaped level, and because weak type tags and (for now) named
  # shape components aren't considered part of the type
  if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray):
    # a bonus check for whether any named axes have inconsistent sizes
    join_named_shapes(aval1.named_shape, aval2.named_shape)
  return (raise_to_shaped(aval1, weak_type=False).strip_named_shape() ==
          raise_to_shaped(aval2, weak_type=False).strip_named_shape())
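For instance (a sketch assuming the surrounding module context), two avals
that differ only in `weak_type` still match:
```py
def _example_typematch():
  a = ShapedArray((2,), np.float32, weak_type=True)
  b = ShapedArray((2,), np.float32)
  assert typematch(a, b)  # weak_type is not part of the type
```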
class JaxprTypeError(TypeError): pass

def typecheck_assert(pred, msg):
  if not pred:
    raise JaxprTypeError(msg)

custom_typechecks: Dict[Primitive, Callable] = {}
def check_jaxpr(jaxpr: Jaxpr):
  """Checks well-formedness of a jaxpr.

  Specifically, check that:
  - variables that are read are bound beforehand
  - variables are typed equally throughout a jaxpr
  - variable type annotations are compatible with their binding expression

  Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
  otherwise.
  """
  try:
    _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
  except JaxprTypeError as e:
    if len(e.args) == 2:
      msg, eqn_idx = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqn_idx - 10, eqn_idx + 10))
    else:
      msg, = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20))
    msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
    raise JaxprTypeError(msg) from None
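A usage sketch (assuming `jax.make_jaxpr` is importable at call time):
```py
def _example_check_jaxpr():
  import jax
  closed = jax.make_jaxpr(lambda x: x * 2.)(3.)
  check_jaxpr(closed.jaxpr)  # raises JaxprTypeError if malformed
```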
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):

  def read(v: Atom) -> AbstractValue:
    if isinstance(v, Literal):
      return raise_to_shaped(get_aval(v.val))
    else:
      typecheck_assert(v in env, f"Variable '{v}' not defined")
      return env[v]

  def write(v: Var, a: AbstractValue) -> None:
    typecheck_assert(v not in env, f"Variable '{v}' already bound")
    if v is not dropvar:
      typecheck_assert(typecompat(v.aval, a),
                       f"Variable '{v}' inconsistently typed as {a}, "
                       f"bound as {v.aval}")
      env[v] = a

  env: Dict[Var, AbstractValue] = {}

  write(unitvar, abstract_unit)
  map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
  map(write, jaxpr.invars, in_avals)

  for eqn_idx, eqn in enumerate(jaxpr.eqns):
    prim = eqn.primitive
    try:
      in_avals = map(read, eqn.invars)
      typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
                       "Equation given ConcreteArray type inputs")
      if prim in custom_typechecks:
        out_avals = custom_typechecks[prim](*in_avals, **eqn.params)
        if out_avals is None:
          out_avals = [v.aval for v in eqn.outvars]
      elif prim.call_primitive:
        out_avals = check_call(prim, in_avals, eqn.params)
      elif prim.map_primitive:
        out_avals = check_map(prim, in_avals, eqn.params)
      else:
        out_avals = check_eqn(prim, in_avals, eqn.params)
      map(write, eqn.outvars, out_avals)
    except JaxprTypeError as e:
      msg, = e.args
      src = source_info_util.summarize(eqn.source_info)
      msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn))),
                         f"from source: {src}"])
      raise JaxprTypeError(msg, eqn_idx) from None

  map(read, jaxpr.outvars)

def check_eqn(prim, in_avals, params):
  for jaxpr in jaxprs_in_params(params):
    check_jaxpr(jaxpr)

  out_avals = prim.abstract_eval(*in_avals, **params)
  if not prim.multiple_results:
    out_avals = [out_avals]
  return out_avals
def check_call(prim, in_avals, params):
  typecheck_assert("call_jaxpr" in params,
                   f"Call primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]

  # These checks also happen in recursive call, but give better errors here.
  typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
                   f"Call primitive {prim} with {len(in_avals)} "
                   f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
                   f"inputs")
  binder_avals = [v.aval for v in call_jaxpr.invars]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Call primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  _check_jaxpr(call_jaxpr, in_avals)

  out_avals = [v.aval for v in call_jaxpr.outvars]
  return out_avals
def check_map(prim, in_avals, params):
  typecheck_assert("call_jaxpr" in params,
                   f"Map primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]
  typecheck_assert("axis_size" in params,
                   f"Map primitive {prim} missing 'axis_size' parameter")
  axis_size = params["axis_size"]
  typecheck_assert("axis_name" in params,
                   f"Map primitive {prim} missing 'axis_name' parameter")
  axis_name = params["axis_name"]
  typecheck_assert("in_axes" in params,
                   f"Map primitive {prim} missing 'in_axes' parameter")
  in_axes = params["in_axes"]
  typecheck_assert("out_axes" in params,
                   f"Map primitive {prim} missing 'out_axes' parameter")
  out_axes = params["out_axes"]

  binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
                  if in_axis is not None else v.aval
                  for v, in_axis in zip(call_jaxpr.invars, in_axes)]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Map primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  mapped_avals = [mapped_aval(axis_size, in_axis, aval)
                  if in_axis is not None else aval
                  for aval, in_axis in zip(in_avals, in_axes)]
  with extend_axis_env(params['axis_name'], axis_size, None):
    _check_jaxpr(call_jaxpr, mapped_avals)

  mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
  out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
               if out_axis is not None else aval
               for aval, out_axis in zip(mapped_out_avals, out_axes)]
  return out_avals

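To summarize the contract enforced above, a map primitive's params must
carry at least the following entries (values here are purely illustrative):
```py
_example_map_params = dict(
    call_jaxpr=None,  # the mapped Jaxpr (placeholder in this sketch)
    axis_size=8,
    axis_name='i',
    in_axes=(0,),
    out_axes=(0,),
)
```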
# ------------------- Jaxpr printed representation -------------------

def pp_vars(vs: Sequence[Any], *, print_shapes: bool = False) -> pp.Doc:
  if print_shapes:
    return pp.nest(2, pp.group(
      pp.join(pp.brk(), [
        pp.text(str(v)) +
        pp.dim(pp.text(":" + v.aval.str_short(short_dtypes=True)))
        for v in vs
      ])
    ))
  else:
    return pp.nest(2, pp.group(
      pp.join(pp.brk(), [pp.text(str(v)) for v in vs])
    ))
def pp_kv_pair(k: str, v: Any) -> pp.Doc:
  if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
    pp_v = pp_jaxprs(v)
  elif isinstance(v, Jaxpr):
    pp_v = pp_jaxpr(v)
  elif isinstance(v, ClosedJaxpr):
    pp_v = pp_jaxpr(v.jaxpr)
  else:
    pp_v = pp.text(str(v))
  return pp.text(f'{k}=') + pp_v
def pp_kv_pairs(kv_pairs) -> pp.Doc:
|
|
|
|
if not kv_pairs:
|
|
|
|
return pp.nil()
|
|
|
|
return pp.group(
|
|
|
|
pp.nest(2, pp.concat([
|
|
|
|
pp.text("["), pp.brk(""),
|
|
|
|
pp.join(pp.brk(), [pp_kv_pair(k, v) for k, v in kv_pairs])
|
|
|
|
]))
|
|
|
|
+ pp.brk("") + pp.text("]")
|
|
|
|
)
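

# A hedged rendering sketch (the params here are made up): when the document
# fits on one line this renders flat as `[axis=0 reverse=False]`; otherwise
# pp.brk("") breaks without a space and nest(2) indents the continuation.
def _pp_kv_pairs_example() -> str:
  return pp_kv_pairs(sorted({'axis': 0, 'reverse': False}.items())).format()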


def pp_eqn(eqn, *, print_shapes=True, source_info=False) -> pp.Doc:
  lhs = pp_vars(eqn.outvars, print_shapes=print_shapes)
  annotation = (source_info_util.summarize(eqn.source_info)
                if source_info else None)
  return pp.concat([
    lhs, pp.text(" = ", annotation=annotation), pp.text(eqn.primitive.name),
    pp_kv_pairs(sorted(eqn.params.items())),
    pp.text(" ") + pp_vars(eqn.invars)
  ])
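

# Rendered form is `outvars = primitive[params] invars`, e.g. `c = cos a`.
# With source_info=True, the " = " carries an annotation summarizing the
# equation's provenance, which the renderer displays alongside the line:
#   c = cos a    [<ipython-input-...>:2 (f)]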


def pp_eqns(eqns, *, print_shapes=True, source_info=False) -> pp.Doc:
  return pp.join(
    pp.brk("; "),
    map(partial(pp_eqn, print_shapes=print_shapes, source_info=source_info),
        eqns))


def pp_eqn_compact(primitive_name: str, params: Dict) -> pp.Doc:
  # Omit bulky jaxpr-valued params (e.g. cond's `branches`) from the output.
  filtered_params = {k: v for k, v in params.items()
                     if (k != 'branches' and
                         not isinstance(v, (Jaxpr, ClosedJaxpr)))}
  return pp.text(primitive_name) + pp_kv_pairs(sorted(filtered_params.items()))


def pp_jaxpr_skeleton(jaxpr, eqns_pp, *, print_shapes=True) -> pp.Doc:
  str_outvars = str(tuple(jaxpr.outvars))
  return pp.group(pp.nest(2, pp.concat([
    pp.text("{ "), pp.bright(pp.text("lambda ")),
    pp_vars(jaxpr.constvars, print_shapes=print_shapes),
    pp.text("; "), pp_vars(jaxpr.invars, print_shapes=print_shapes),
    pp.text(". "), pp.bright(pp.text("let")),
    pp.nest(2, pp.brk() + eqns_pp), pp.brk(),
    pp.bright(pp.text("in")),
    pp.text(f" {str_outvars}")
  ])) + pp.text(" }"))
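

# The skeleton supplies the binder structure around the equations, e.g.
#   { lambda ; a b. let c = cos a in (c,) }
# with `lambda`, `let`, and `in` highlighted via pp.bright and the equation
# block indented by the inner pp.nest(2, ...).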


def pp_jaxpr(jaxpr, *, print_shapes=True, source_info=False) -> pp.Doc:
  pps = pp_eqns(jaxpr.eqns, print_shapes=print_shapes, source_info=source_info)
  return pp_jaxpr_skeleton(jaxpr, pps, print_shapes=print_shapes)


def pp_jaxprs(jaxprs) -> pp.Doc:
  jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
  return pp.group(pp.nest(2, pp.concat([
      pp.text('('), pp.brk(""), pp.join(pp.brk(), map(pp_jaxpr, jaxprs))]))
    + pp.brk("") + pp.text(')')
  )


def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, print_shapes=True,
                       source_info: bool = False) -> pp.Doc:
  lo = max(lo, 0)
  hi = max(lo, min(hi, len(jaxpr.eqns)))
  eqns = jaxpr.eqns[lo:hi]
  pps = []
  if len(eqns) == 0 and len(jaxpr.eqns) != 0:
    pps.append(pp.text('...'))
  else:
    if lo != 0:
      pps.append(pp.text('...'))
    pps.extend(map(partial(pp_eqn, print_shapes=print_shapes,
                           source_info=source_info), eqns))
    if hi != len(jaxpr.eqns):
      pps.append(pp.text('...'))
  return pp_jaxpr_skeleton(jaxpr, pp.join(pp.brk("; "), pps),
                           print_shapes=print_shapes)
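

# Hedged usage sketch (assumes `jaxpr` has more than four equations): print
# only equations [2, 4), with the elided prefix and suffix each shown as
# '...' inside the let-binding.
def _pp_eqn_range_example(jaxpr: Jaxpr) -> str:
  return pp_jaxpr_eqn_range(jaxpr, 2, 4, source_info=True).format()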