# rocm_jax/jax/core.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from contextlib import contextmanager
from functools import partial, total_ordering
import gc
import itertools as it
import operator
from operator import attrgetter
import threading
import types
from typing import (Any, Callable, ClassVar, Dict, Generator,
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
Type, Union, cast, Iterable, Hashable)
from weakref import ref
import numpy as np
from ._src import dtypes
from ._src import config as jax_config
from ._src.config import FLAGS, config
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from . import linear_util as lu
from jax._src import source_info_util
from ._src.util import (safe_zip, safe_map, curry, prod, partialmethod,
tuple_insert, tuple_delete, cache, as_hashable_function,
HashableFunction)
import jax._src.pretty_printer as pp
from ._src import traceback_util
traceback_util.register_exclusion(__file__)
zip = safe_zip
map = safe_map
# -------------------- jaxprs --------------------
class Jaxpr:
constvars: List['Var']
invars: List['Var']
outvars: List['Atom']
eqns: List['JaxprEqn']
def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
replaced with such variables while scalar constants are kept inline.
invars: list of input variables. Together, `constvars` and `invars` are
the inputs to the Jaxpr.
outvars: list of output variables.
eqns: list of equations.
"""
self.constvars = list(constvars)
self.invars = list(invars)
self.outvars = list(outvars)
self.eqns = list(eqns)
def __str__(self):
return str(pp_jaxpr(self))
__repr__ = __str__
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
doc = pp_jaxpr(self, source_info=source_info, print_shapes=print_shapes)
return doc.format(**kw)
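
# Usage sketch: jaxprs are usually produced by tracing a Python function with
# the public `jax.make_jaxpr` API, which returns a ClosedJaxpr (defined below)
# whose `.jaxpr` attribute is a Jaxpr as above:
#
#   import jax, jax.numpy as jnp
#   closed = jax.make_jaxpr(lambda x: jnp.sin(x) + 1.)(2.0)
#   closed.jaxpr.invars    # one input variable
#   closed.jaxpr.eqns      # two equations: sin and add
#   closed.jaxpr.outvars   # one output variable
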
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
for val in params.values():
vals = val if isinstance(val, tuple) else (val,)
for v in vals:
if isinstance(v, Jaxpr):
yield v
elif isinstance(v, ClosedJaxpr):
yield v.jaxpr
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
"""Generator for all subjaxprs found in the params of jaxpr.eqns.
Does not descend recursively into the found subjaxprs.
"""
for eqn in jaxpr.eqns:
yield from jaxprs_in_params(eqn.params)
class ClosedJaxpr:
jaxpr: Jaxpr
consts: List['Any']
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
assert len(consts) == len(jaxpr.constvars)
self.jaxpr = jaxpr
self.consts = list(consts)
@property
def in_avals(self):
return [v.aval for v in self.jaxpr.invars]
@property
def out_avals(self):
return [v.aval for v in self.jaxpr.outvars]
@property
def literals(self):
return self.consts # backwards compatible alias
@property
def eqns(self):
return self.jaxpr.eqns
def map_jaxpr(self, f):
return ClosedJaxpr(f(self.jaxpr), self.consts)
def __str__(self): return str(self.jaxpr)
def __repr__(self): return repr(self.jaxpr)
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
return pp_jaxpr(self.jaxpr, source_info=source_info,
print_shapes=print_shapes).format(**kw)
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
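
# Usage sketch: `jaxpr_as_fun` turns a ClosedJaxpr back into a callable. Note
# that the result is always a list of outputs, since jaxprs are multi-result:
#
#   import jax
#   closed = jax.make_jaxpr(lambda x: x * 2.)(3.0)
#   jaxpr_as_fun(closed)(3.0)   # -> [6.0], a one-element list of outputs
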
class JaxprEqn(NamedTuple):
invars: List['Atom']
outvars: List['Var']
primitive: 'Primitive'
params: Dict[str, Any]
source_info: Optional[source_info_util.Traceback]
def __repr__(self): return str(pp_eqn(self)).rstrip()
def new_jaxpr_eqn(invars, outvars, primitive, params, source_info=None):
if primitive.call_primitive:
assert len(outvars) == len(params["call_jaxpr"].outvars)
return JaxprEqn(invars, outvars, primitive, params, source_info)
@total_ordering
class Var:
# TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison is
# by object id, but pretty printing might collide.
count: int
suffix: str
aval: 'AbstractValue'
def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
self.count = count
self.suffix = suffix
self.aval = raise_to_shaped(aval)
def __lt__(self, other):
if not isinstance(other, Var):
return NotImplemented
else:
return (self.count, self.suffix) < (other.count, other.suffix)
def __repr__(self):
rem = self.count
s = ''
while True:
rem, i = rem // 26, rem % 26
s = chr(97 + i % 26) + s
if not rem:
break
return s + self.suffix
def _jaxpr_vars(jaxpr):
return it.chain(
jaxpr.invars, jaxpr.constvars,
(v for eqn in jaxpr.eqns for v in eqn.outvars))
def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
suffix: str = '') -> Callable[['AbstractValue'], Var]:
"""Produce distinct variables, printed with the optional suffix.
If `jaxprs` is provided, the variables produced will be distinct from those in
any of the given jaxprs.
"""
if jaxprs is None:
start = 0
else:
all_vars = it.chain.from_iterable(_jaxpr_vars(j) for j in jaxprs)
start = 1 + max((v.count for v in all_vars), default=-1)
counter = it.count(start=start)
return lambda aval: Var(next(counter), suffix, aval)
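
# Usage sketch: `gensym` returns a factory for fresh variables. Counts start at
# zero (printed as 'a', 'b', ...) unless existing jaxprs are passed, in which
# case the counter starts past their largest variable count:
#
#   newvar = gensym(suffix="'")
#   v = newvar(ShapedArray((3,), np.float32))  # ShapedArray is defined below
#   repr(v)   # -> "a'"
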
# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
count = -1
suffix = ''
def __init__(self): pass
@property
def aval(self): return abstract_unit
def __repr__(self): return '_'
dropvar = DropVar()
class Literal:
__slots__ = ["val", "hash"]
val: Any
hash: Optional[int]
def __init__(self, val):
self.val = val
try:
self.hash = hash(val)
except TypeError:
if type(val) in literalable_types:
try:
self.hash = hash((val.item(), val.dtype))
except (TypeError, AttributeError, ValueError):
self.hash = None
@property
def aval(self):
return raise_to_shaped(get_aval(self.val))
def __hash__(self):
assert False
def __repr__(self):
if hasattr(self, 'hash'):
return '{}'.format(self.val)
else:
return 'Literal(val={})'.format(self.val)
literalable_types: Set[type] = set()
Atom = Union[Var, Literal]
class Primitive:
name: str
multiple_results = False # set for multi-output primitives
call_primitive = False # set for call primitives processed in final style
map_primitive = False # set for map primitives processed in final style
_dispatch_on_params = False # whether to include axis names from params in dispatch
def __init__(self, name: str):
self.name = name
def __repr__(self):
return '{}'.format(self.name)
def bind(self, *args, **params):
assert (not config.jax_enable_checks or
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
top_trace = find_top_trace(
args, used_axis_names(self, params) if self._dispatch_on_params else None)
tracers = map(top_trace.full_raise, args)
out = top_trace.process_primitive(self, tracers, params)
return map(full_lower, out) if self.multiple_results else full_lower(out)
def def_impl(self, impl):
self.impl = impl
return impl
def def_abstract_eval(self, abstract_eval):
self.abstract_eval = abstract_eval
return abstract_eval
def def_custom_bind(self, bind):
self.bind = bind
return bind
def impl(self, *args, **params):
raise NotImplementedError("Evaluation rule for '{}' not implemented"
.format(self.name))
def abstract_eval(self, *args, **params):
raise NotImplementedError("Abstract evaluation for '{}' not implemented"
.format(self.name))
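
# Usage sketch: a new primitive is defined by instantiating Primitive and
# registering rules on it. With only an impl rule, `bind` dispatches to the
# EvalTrace (defined below) and evaluates eagerly; an abstract_eval rule is
# needed as well for the primitive to be traceable, e.g. under jit:
#
#   sq_p = Primitive('sq')
#   sq_p.def_impl(lambda x: x * x)
#   sq_p.def_abstract_eval(lambda x: ShapedArray(x.shape, x.dtype))
#   sq_p.bind(3.)   # -> 9.0
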
# -------------------- lifting --------------------
# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
# core.py. Plan to move all these utilities to jaxpr.py.
def extract_call_jaxpr(
primitive: Primitive,
params: Dict[str, Any]) -> Tuple[Optional[Jaxpr], Dict[str, Any]]:
"""Extract the call primitive subjaxpr from the params.
Returns the subjaxpr and the params without the "call_jaxpr" value. If this is
not a call primitive then returns (None, params).
"""
if not (primitive.call_primitive or primitive.map_primitive):
return (None, params)
else:
assert "call_jaxpr" in params
new_params = dict(params)
del new_params["call_jaxpr"]
return (params["call_jaxpr"], new_params)
# TODO(mattjj): replace this approach with a primitive-keyed table of rules
def traverse_jaxpr_params(f, params):
"""Applies f to each jaxpr parameter and returns a tuple of returned values."""
return {name: f(p)
for name, param in params.items()
for p in (param if isinstance(param, (tuple, list)) else [param])
if type(p) in (Jaxpr, ClosedJaxpr)}
def eval_jaxpr_eqn(eqn, in_vals):
"""Evaluates the jaxpr equation with the provided input values."""
call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
if call_jaxpr:
subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
else:
subfuns = []
if eqn.primitive in initial_to_final_param_rules:
bind_params = initial_to_final_param_rules[eqn.primitive](params)
elif eqn.primitive.map_primitive:
out_axes_thunk = HashableFunction(lambda: params['out_axes'],
closure=params['out_axes'])
bind_params = dict(params, out_axes_thunk=out_axes_thunk)
del bind_params['out_axes']
else:
bind_params = params
with source_info_util.user_context(eqn.source_info):
return eqn.primitive.bind(*(subfuns + in_vals), **bind_params)
def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
def read(v):
if type(v) is Literal:
return v.val
else:
return env[v]
def write(v, val):
env[v] = val
env: Dict[Var, Any] = {}
write(unitvar, unit)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
ans = eval_jaxpr_eqn(eqn, map(read, eqn.invars))
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
return map(read, jaxpr.outvars)
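
# Usage sketch: `eval_jaxpr` is the reference interpreter for jaxprs, so
# tracing a function and evaluating the resulting jaxpr reproduces the
# original computation:
#
#   import jax, jax.numpy as jnp
#   closed = jax.make_jaxpr(lambda x: jnp.cos(x) * 2.)(1.0)
#   eval_jaxpr(closed.jaxpr, closed.consts, 1.0)   # -> [cos(1.0) * 2.0]
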
initial_to_final_param_rules: Dict[Primitive, Callable] = {}
# -------------------- tracing --------------------
class Trace:
__slots__ = ['main', 'level', 'sublevel']
main: 'MainTrace'
level: int
sublevel: 'Sublevel'
def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
self.main = main
self.level = main.level
self.sublevel = sublevel
def full_raise(self, val) -> 'Tracer':
if not isinstance(val, Tracer):
return self.pure(val)
val._assert_live()
level = self.level
sublevel = self.sublevel
if val._trace.main is self.main:
if val._trace.sublevel == sublevel:
return val
elif val._trace.sublevel < sublevel:
return self.sublift(val)
else:
raise escaped_tracer_error(
val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
elif val._trace.level < level:
if val._trace.sublevel > sublevel:
raise escaped_tracer_error(
val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
return self.lift(val)
elif val._trace.level > level:
raise escaped_tracer_error(
val, f"Can't lift level {val} to {self}")
else: # val._trace.level == self.level:
raise escaped_tracer_error(
val, f"Different traces at same level: {val}, {self}")
def pure(self, val):
raise NotImplementedError("must override")
def lift(self, tracer):
raise NotImplementedError("must override")
def sublift(self, tracer):
raise NotImplementedError("must override")
def process_primitive(self, primitive, tracers, params):
raise NotImplementedError("must override")
def __repr__(self):
return '{}(level={}/{})'.format(
self.__class__.__name__, self.level, self.sublevel)
def process_call(self, call_primitive, f, tracers, params):
msg = (f"{type(self)} must override process_call to handle call-like "
"primitives")
raise NotImplementedError(msg)
def process_map(self, map_primitive, f, tracers, params):
msg = (f"{type(self)} must override process_map to handle map-like "
"primitives")
raise NotImplementedError(msg)
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
msg = (f"{type(self)} must override process_custom_jvp_call "
"to handle custom_jvp primitives")
raise NotImplementedError(msg)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
msg = (f"{type(self)} must override process_custom_vjp_call "
"to handle custom_vjp primitives")
raise NotImplementedError(msg)
def escaped_tracer_error(tracer, detail=None):
num_frames = FLAGS.jax_tracer_error_num_traceback_frames
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
'had a side effect, allowing for a reference to an intermediate value '
f'with shape {tracer.shape} and dtype {tracer.dtype} to escape.\n'
'JAX transformations require that functions explicitly return their '
'outputs, and disallow saving intermediate values to global state.')
dbg = getattr(tracer._trace.main, 'debug_info', None)
if dbg is not None:
msg += ('\nThe function being traced when the value leaked was '
f'{dbg.func_src_info} traced for {dbg.traced_for}.')
line_info = getattr(tracer, '_line_info', None)
if line_info is not None:
divider = '\n' + '-'*30 + '\n'
msg += divider
msg += ('The leaked intermediate value was created on line '
f'{source_info_util.summarize(line_info)}. ')
msg += divider
if num_frames > 0:
msg += (f'When the value was created, the final {num_frames} stack '
'frames (most recent last) excluding JAX-internal frames were:')
msg += divider + source_info_util.summarize(
line_info, num_frames=num_frames) + divider
msg += ('\nTo catch the leak earlier, try setting the environment variable '
'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
'manager.')
if detail:
msg += f'Detail: {detail}'
return UnexpectedTracerError(msg)
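
# Usage sketch: this error typically surfaces when a traced function has a side
# effect that stashes a tracer, and the tracer is used after its trace is done:
#
#   import jax
#   leaked = []
#   def f(x):
#     leaked.append(x)   # side effect: stores the tracer for x
#     return x * 2.
#   jax.jit(f)(1.0)
#   leaked[0] + 1   # raises UnexpectedTracerError via escaped_tracer_error
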
class Tracer:
__array_priority__ = 1000
__slots__ = ['_trace', '__weakref__', '_line_info']
def __array__(self, *args, **kw):
raise TracerArrayConversionError(self)
def __index__(self):
raise TracerIntegerConversionError(self)
def __init__(self, trace: Trace):
self._trace = trace
def __iter__(self):
return iter(self.aval._iter(self))
def __len__(self):
return self.aval._len(self)
@property
def aval(self):
raise NotImplementedError("must override")
def _assert_live(self) -> None:
pass # Override for liveness checking
  # Python looks up special methods only on classes, not instances. This means
  # these methods need to be defined explicitly rather than relying on
  # __getattr__.
def __neg__(self): return self.aval._neg(self)
def __pos__(self): return self.aval._pos(self)
def __eq__(self, other): return self.aval._eq(self, other)
def __ne__(self, other): return self.aval._ne(self, other)
def __lt__(self, other): return self.aval._lt(self, other)
def __le__(self, other): return self.aval._le(self, other)
def __gt__(self, other): return self.aval._gt(self, other)
def __ge__(self, other): return self.aval._ge(self, other)
def __abs__(self): return self.aval._abs(self)
def __add__(self, other): return self.aval._add(self, other)
def __radd__(self, other): return self.aval._radd(self, other)
def __sub__(self, other): return self.aval._sub(self, other)
def __rsub__(self, other): return self.aval._rsub(self, other)
def __mul__(self, other): return self.aval._mul(self, other)
def __rmul__(self, other): return self.aval._rmul(self, other)
def __div__(self, other): return self.aval._div(self, other)
def __rdiv__(self, other): return self.aval._rdiv(self, other)
def __truediv__(self, other): return self.aval._truediv(self, other)
def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
def __floordiv__(self, other): return self.aval._floordiv(self, other)
def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
def __divmod__(self, other): return self.aval._divmod(self, other)
def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
def __mod__(self, other): return self.aval._mod(self, other)
def __rmod__(self, other): return self.aval._rmod(self, other)
def __pow__(self, other): return self.aval._pow(self, other)
def __rpow__(self, other): return self.aval._rpow(self, other)
def __matmul__(self, other): return self.aval._matmul(self, other)
def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
def __and__(self, other): return self.aval._and(self, other)
def __rand__(self, other): return self.aval._rand(self, other)
def __or__(self, other): return self.aval._or(self, other)
def __ror__(self, other): return self.aval._ror(self, other)
def __xor__(self, other): return self.aval._xor(self, other)
def __rxor__(self, other): return self.aval._rxor(self, other)
def __invert__(self): return self.aval._invert(self)
def __lshift__(self, other): return self.aval._lshift(self, other)
def __rlshift__(self, other): return self.aval._rlshift(self, other)
def __rshift__(self, other): return self.aval._rshift(self, other)
def __rrshift__(self, other): return self.aval._rrshift(self, other)
def __getitem__(self, idx): return self.aval._getitem(self, idx)
def __nonzero__(self): return self.aval._nonzero(self)
def __bool__(self): return self.aval._bool(self)
def __int__(self): return self.aval._int(self)
def __long__(self): return self.aval._long(self)
def __hex__(self): return self.aval._hex(self)
def __oct__(self): return self.aval._oct(self)
def __float__(self): return self.aval._float(self)
def __complex__(self): return self.aval._complex(self)
# raises the better error message from ShapedArray
def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)
# NumPy also only looks up special methods on classes.
def __array_module__(self, types): return self.aval._array_module(self, types)
def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert not config.jax_enable_checks or name != "aval"
try:
attr = getattr(self.aval, name)
except KeyError as err:
raise AttributeError(
"{} has no attribute {}".format(self.__class__.__name__, name)
) from err
else:
t = type(attr)
if t is aval_property:
return attr.fget(self)
elif t is aval_method:
return types.MethodType(attr.fun, self)
else:
return attr
def _pretty_print(self):
base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
else pp.text(repr(attr))) for name, attr in self._contents()]
if contents:
base = pp.group(pp.nest(2, pp.concat([
base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
pp.text('{} = '.format(name)) + pp_payload
for name, pp_payload in contents])
])))
return base
def __repr__(self):
return self._pretty_print().format()
def _contents(self):
try:
return [(name, getattr(self, name)) for name in self.__slots__]
except AttributeError:
return ()
def __copy__(self):
return self
def __deepcopy__(self, unused_memo):
return self
def _origin_msg(self) -> str:
return ""
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
class EvalTrace(Trace):
# See comments in https://github.com/google/jax/pull/3370
def pure(self, x): return x
lift = sublift = pure
def process_primitive(self, primitive, tracers, params):
return primitive.impl(*tracers, **params)
def process_call(self, primitive, f, tracers, params):
return primitive.impl(f, *tracers, **params)
process_map = process_call
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
del primitive, jvp # Unused.
with new_sublevel():
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
del primitive, fwd, bwd, out_trees # Unused.
with new_sublevel():
return fun.call_wrapped(*tracers)
class MainTrace:
level: int
trace_type: Type[Trace]
payload: Dict[str, Any]
def __init__(self, level, trace_type, **payload) -> None:
self.level = level
self.trace_type = trace_type
self.payload = payload
def __repr__(self) -> str:
return "MainTrace({},{})".format(self.level, self.trace_type.__name__)
def __hash__(self) -> int:
return hash((self.level, self.trace_type))
def __eq__(self, other: object) -> bool:
return (isinstance(other, MainTrace) and
self.level == other.level and
self.trace_type == other.trace_type and
self.payload == other.payload)
def with_cur_sublevel(self):
return self.trace_type(self, cur_sublevel(), **self.payload)
class TraceStack:
# See comments in https://github.com/google/jax/pull/3370
stack: List[MainTrace]
dynamic: MainTrace
def __init__(self):
eval_trace = MainTrace(0, EvalTrace)
self.stack = [eval_trace]
self.dynamic = eval_trace
def next_level(self) -> int:
return len(self.stack)
def push(self, main_trace: MainTrace) -> None:
self.stack.append(main_trace)
def pop(self) -> None:
self.stack.pop()
def __repr__(self) -> str:
    stack_str = ''.join(map('  {}\n'.format, self.stack[::-1]))
    return f'Trace stack\n{stack_str}\n{self.dynamic}'
def copy(self):
new = self.__new__(TraceStack)
new.stack = self.stack[:]
new.dynamic = self.dynamic
return new
@total_ordering
class Sublevel:
def __init__(self, level: int):
self.level = level
def __repr__(self):
return str(self.level)
def __eq__(self, other):
return type(other) is Sublevel and self.level == other.level
def __lt__(self, other):
return type(other) is Sublevel and self.level < other.level
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
AxisName = Hashable
no_axis_name = object()
class TraceState:
trace_stack: TraceStack
substack: List[Sublevel]
axis_env: List[AxisEnvFrame]
def __init__(self) -> None:
self.trace_stack = TraceStack()
self.substack = [Sublevel(0)]
self.axis_env = []
def copy(self):
new = self.__new__(TraceState)
new.trace_stack = self.trace_stack.copy()
new.substack = self.substack[:]
new.axis_env = self.axis_env[:]
return new
def _update_thread_local_jit_state(dynamic):
# Copies the MainTrace instance, removing any .debug_info or .jaxpr_stack
# fields that should not be kept alive as part of a cache key.
# TODO(mattjj): split debug_info and jaxpr_stack out of MainTrace.
# TODO(mattjj): add a test that verifies that JIT-ted functions are not kept
# alive by the JIT cache, particularly for nested JIT-ted functions.
copy = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
jax_config.update_thread_local_jit_state(dynamic_trace_state=copy)
# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class ThreadLocalState(threading.local):
def __init__(self):
self.trace_state = TraceState()
_update_thread_local_jit_state(self.trace_state.trace_stack.dynamic)
thread_local_state = ThreadLocalState()
def trace_state_clean() -> bool:
trace_state = thread_local_state.trace_state
return (trace_state.substack == [Sublevel(0)] and
trace_state.axis_env == [] and
trace_state.trace_stack.stack == [MainTrace(0, EvalTrace)] and
trace_state.trace_stack.dynamic == MainTrace(0, EvalTrace))
def reset_trace_state() -> bool:
"Reset the global trace state and return True if it was already clean."
if not trace_state_clean():
thread_local_state.trace_state.__init__() # type: ignore
return False
else:
return True
def cur_sublevel() -> Sublevel:
return thread_local_state.trace_state.substack[-1]
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
It's possible there's none! eg. there's some cases where JAX itself holds a
reference to `x` inside of a lambda closure, and no tracers were leaked
by the user. In this case an empty list is returned.
"""
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
return tracers
@contextmanager
def new_main(trace_type: Type[Trace],
dynamic: bool = False,
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
level = stack.next_level()
main = MainTrace(level, trace_type, **payload)
stack.push(main)
if dynamic:
prev_dynamic, stack.dynamic = stack.dynamic, main
_update_thread_local_jit_state(stack.dynamic)
try:
yield main
finally:
stack.pop()
if dynamic:
stack.dynamic = prev_dynamic
_update_thread_local_jit_state(stack.dynamic)
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
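
# Usage sketch (with `MyTrace` standing in for a hypothetical Trace subclass):
# transformations typically push a fresh MainTrace with `new_main`, lift their
# inputs into tracers, and read results back out before the context exits:
#
#   with new_main(MyTrace) as main:
#     trace = main.with_cur_sublevel()
#     in_tracers = [trace.full_raise(x) for x in args]
#     ...   # run the wrapped function on in_tracers, unpack output tracers
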
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type)
prev_dynamic, stack.dynamic = stack.dynamic, main
prev_base, stack.stack[0] = stack.stack[0], main
_update_thread_local_jit_state(stack.dynamic)
try:
yield main
finally:
stack.dynamic = prev_dynamic
stack.stack[0] = prev_base
_update_thread_local_jit_state(stack.dynamic)
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
@contextmanager
def eval_context():
with new_base_main(EvalTrace):
yield
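
# Usage sketch: `eval_context` temporarily installs an EvalTrace at the base of
# the trace stack so that primitives bind eagerly even when entered from inside
# a traced region; see e.g. ConcreteArray.__eq__ below, which compares possibly
# device-resident values under this context:
#
#   with eval_context():
#     result = some_primitive.bind(x)   # some_primitive is hypothetical here
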
@contextmanager
def new_sublevel() -> Generator[None, None, None]:
sublevel = Sublevel(len(thread_local_state.trace_state.substack))
thread_local_state.trace_state.substack.append(sublevel)
try:
yield
finally:
thread_local_state.trace_state.substack.pop()
if config.jax_check_tracer_leaks:
t = ref(sublevel)
del sublevel
if t() is not None:
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked sublevel {t()}. Leaked tracer(s): {leaked_tracers}.')
def full_lower(val):
if isinstance(val, Tracer):
return val.full_lower()
else:
return val
def find_top_trace(xs, axis_names=None) -> Trace:
top_main: Optional[MainTrace] = None
if axis_names:
top_main = max((axis_frame(a).main_trace for a in axis_names),
default=None, key=lambda t: getattr(t, 'level', -1))
top_tracer = max((x for x in xs if isinstance(x, Tracer)),
default=None, key=attrgetter('_trace.level'))
if top_tracer is not None:
top_tracer._assert_live()
if top_tracer._trace.main.level > getattr(top_main, 'level', -1):
top_main = top_tracer._trace.main
dynamic = thread_local_state.trace_state.trace_stack.dynamic
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
else top_main)
return top_main and top_main.with_cur_sublevel() # type: ignore
# -------------------- abstract values --------------------
class AbstractValue:
__slots__: List[str] = []
_num_buffers: int = 1 # number of buffers used to represent the value.
def at_least_vspace(self):
raise NotImplementedError("must override")
def __repr__(self):
try:
kv_pairs = ('{}={}'.format(k, v) for k, v in self.__dict__.items())
return '{}({})'.format(self.__class__.__name__, ','.join(kv_pairs))
except AttributeError:
return self.__class__.__name__
def strip_weak_type(self) -> 'AbstractValue':
return self
def strip_named_shape(self) -> 'AbstractValue':
return self
def join(self, other):
raise NotImplementedError("must override")
def update(self, **kwargs):
raise NotImplementedError("must override")
def str_short(self, short_dtypes=False):
raise NotImplementedError("must override")
class Bot(AbstractValue): pass
bot = Bot()
class AbstractUnit(AbstractValue):
# TODO(jakevdp): make it possible to set zero buffers
# _num_buffers = 0
def at_least_vspace(self): return self
def join(self, other):
if config.jax_enable_checks:
assert other is abstract_unit, other
return self
def _eq(self, self_traced, other): return get_aval(other) is self
def str_short(self, short_dtypes=False): return '*'
abstract_unit = AbstractUnit()
def lattice_join(x: Optional[AbstractValue],
y: Optional[AbstractValue]) -> AbstractValue:
if x is None:
return cast(AbstractValue, y)
elif y is None:
return cast(AbstractValue, x)
elif isinstance(x, type(y)):
return y.join(x)
elif isinstance(y, type(x)):
return x.join(y)
else:
raise TypeError(x, y)
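
# Worked example: joining a ConcreteArray with a compatible ShapedArray (both
# defined below) yields the ShapedArray, their least upper bound in the
# abstraction lattice:
#
#   a = ShapedArray((3,), np.float32)
#   b = ConcreteArray(np.ones(3, np.float32))
#   lattice_join(a, b)   # -> ShapedArray((3,), float32)
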
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
def valid_jaxtype(x):
try:
concrete_aval(x)
except TypeError:
return False
else:
return True
def check_valid_jaxtype(x):
if not valid_jaxtype(x):
raise TypeError(
f"Value {repr(x)} of type {type(x)} is not a valid JAX type")
def concrete_aval(x):
for typ in type(x).mro():
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return concrete_aval(x.__jax_array__())
raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
"type")
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
return concrete_aval(x)
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
class Unit:
def __repr__(self): return '*'
unit = Unit()
literalable_types.add(Unit)
class UnitVar(Var):
count = -1
suffix = ''
def __init__(self): pass
@property
def aval(self): return abstract_unit
def __repr__(self): return '*'
unitvar = UnitVar()
pytype_aval_mappings[Unit] = lambda _: abstract_unit
def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
fname_context = f"The problem arose with the `{fname}` function. "
if suggest_astype:
fname_context += ("If trying to convert the data type of a value, "
f"try using `x.astype({fun.__name__})` "
f"or `jnp.array(x, {fun.__name__})` instead.")
def error(self, arg):
raise ConcretizationTypeError(arg, fname_context)
return error
def concrete_or_error(force: Any, val: Any, context=""):
"""Like force(val), but gives the context in the error message."""
if force is None:
force = lambda x: x
if isinstance(val, Tracer):
if isinstance(val.aval, ConcreteArray):
return force(val.aval.val)
else:
raise ConcretizationTypeError(val, context)
else:
return force(val)
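
# Usage sketch: `concrete_or_error` applies `force` when a concrete value is
# available, and otherwise raises with the supplied context string:
#
#   concrete_or_error(int, 3.0, "while processing a shape")   # -> 3
#   # With an abstract Tracer argument (e.g. under jit), the same call raises
#   # ConcretizationTypeError carrying the context string.
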
convert_element_type_p = Primitive('convert_element_type')
def _short_dtype_name(dtype):
return (dtype.name.replace('float', 'f').replace('uint', 'u')
.replace('int', 'i').replace('complex', 'c'))
class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 2
def __init__(self, dtype, weak_type=False):
self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
self.weak_type = weak_type
def update(self, dtype=None, weak_type=None):
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
return UnshapedArray(dtype, weak_type)
def __eq__(self, other):
return (type(self) is type(other) and self.dtype == other.dtype and
self.weak_type == other.weak_type)
def __ne__(self, other):
return not self == other
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.dtype, self.weak_type))
def __repr__(self):
return '{}({}{})'.format(self.__class__.__name__, self.str_short(),
", weak_type=True" if self.weak_type else "")
_bool = _nonzero = concretization_function_error(bool)
_float = concretization_function_error(float, True)
_int = concretization_function_error(int, True)
_complex = concretization_function_error(complex, True)
_hex = concretization_function_error(hex)
_oct = concretization_function_error(oct)
def at_least_vspace(self) -> AbstractValue:
return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
def join(self, other):
if self.dtype == other.dtype:
if self.weak_type == other.weak_type:
return self
else:
return UnshapedArray(self.dtype, weak_type=False)
else:
raise TypeError(self, other)
def str_short(self, short_dtypes=False) -> str:
return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
def strip_weak_type(self):
"""Returns a copy of the aval with weak_type=False."""
return self.update(weak_type=False)
@property
def shape(self):
msg = ("UnshapedArray has no shape. Please open an issue at "
"https://github.com/google/jax/issues because it's unexpected for "
"UnshapedArray instances to ever be produced.")
raise TypeError(msg)
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'named_shape']
array_abstraction_level = 1
def __init__(self, shape, dtype, weak_type=False, named_shape={}):
super().__init__(dtype, weak_type=weak_type)
self.shape = canonicalize_shape(shape)
self.named_shape = dict(named_shape)
def update(self, shape=None, dtype=None, weak_type=None, named_shape=None):
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
if named_shape is None:
named_shape = self.named_shape
return ShapedArray(shape, dtype, weak_type, named_shape)
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
broadcast: ClassVar[Optional[aval_method]] = None
transpose: ClassVar[Optional[aval_method]] = None
reshape: ClassVar[Optional[aval_method]] = None
_iter: ClassVar[Optional[staticmethod]] = None
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type
and self.named_shape == other.named_shape)
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type,
tuple(self.named_shape.items())))
def at_least_vspace(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, self.named_shape)
def join(self, other):
if symbolic_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
named_shape = join_named_shapes(self.named_shape, other.named_shape)
return self.update(weak_type=weak_type, named_shape=named_shape)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
raise TypeError(self, other)
def str_short(self, short_dtypes=False):
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
shapestr = ','.join(map(str, self.shape))
if self.named_shape:
named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items())
return f'{dt_str}[{shapestr};{named_shapestr}]'
else:
return f'{dt_str}[{shapestr}]'
def strip_named_shape(self):
return self.update(named_shape={})
def __len__(self):
try:
return self.shape[0]
except IndexError as err:
raise TypeError("len() of unsized object") from err # same as numpy error
def _len(self, ignored_tracer):
return len(self)
def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)
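
# Usage sketch for ShapedArray, the workhorse abstract value:
#
#   a = ShapedArray((2, 3), np.float32)
#   a.ndim, a.size                   # -> (2, 6)
#   a.str_short()                    # -> 'float32[2,3]'
#   a.str_short(short_dtypes=True)   # -> 'f32[2,3]'
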
class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
def __init__(self, val, weak_type=False):
super().__init__(np.shape(val), np.result_type(val),
weak_type=weak_type)
# Note: canonicalized self.dtype doesn't necessarily match self.val
self.val = val
assert self.dtype != np.dtype('O'), val
def update(self, val=None, weak_type=None):
if val is None:
val = self.val
if weak_type is None:
weak_type = self.weak_type
return ConcreteArray(val, weak_type)
def __eq__(self, other):
if (type(self) is type(other) and self.dtype == other.dtype
and self.shape == other.shape and self.weak_type == other.weak_type):
with eval_context(): # in case self.val is a DeviceArray
return (self.val == other.val).all()
else:
return False
def __hash__(self):
return id(self.val)
def join(self, other) -> AbstractValue:
if self == other:
return self
elif self.shape == other.shape and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
named_shape = join_named_shapes(self.named_shape, other.named_shape)
return ShapedArray(
self.shape, self.dtype, weak_type=weak_type, named_shape=named_shape)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype,
weak_type=self.weak_type and other.weak_type)
else:
raise TypeError(self, other)
def str_short(self, short_dtypes=False) -> str:
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
return f'{self.val}, dtype={dt_str}'
_bool = _nonzero = partialmethod(_forward_to_value, bool)
_int = partialmethod(_forward_to_value, int)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
def primal_dtype_to_tangent_dtype(primal_dtype):
if not dtypes.issubdtype(primal_dtype, np.inexact):
return dtypes.float0
else:
return primal_dtype
class AbstractToken(AbstractValue):
def join(self, other):
if isinstance(other, AbstractToken):
return self
else:
assert False, f"Cannot join {self} with {other}"
def str_short(self, short_dtypes=False): return 'Tok'
def at_least_vspace(self): return self
abstract_token: AbstractToken = AbstractToken()
def raise_to_shaped(aval: AbstractValue, weak_type=None):
if weak_type is None:
weak_type = getattr(aval, 'weak_type', False)
for typ in type(aval).mro():
handler = raise_to_shaped_mappings.get(typ)
if handler: return handler(aval, weak_type)
raise TypeError(type(aval))
raise_to_shaped_mappings : Dict[type, Callable] = {
AbstractUnit: lambda aval, _: aval,
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type, aval.named_shape)
}
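
# Usage sketch: `raise_to_shaped` forgets concreteness, mapping a ConcreteArray
# up to the ShapedArray level of the lattice while leaving already-shaped avals
# unchanged:
#
#   raise_to_shaped(ConcreteArray(np.ones((2,), np.float32)))
#   # -> ShapedArray((2,), float32)
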
### Operations on shapes and dimension sizes.
# Shapes are tuples of dimension sizes, which are normally integers. We allow
# modules to extend the set of dimension sizes to contain other types, e.g.,
# symbolic dimensions in jax2tf.shape_poly.DimVar and masking.Poly.
DimSize = Union[int, Any] # extensible
Shape = Sequence[DimSize]
class InconclusiveDimensionOperation(Exception):
"""Raised when we cannot conclusively compute with symbolic dimensions."""
pass
class DimensionHandler:
"""Operations on dimension sizes.
Dimension sizes are normally integer constants, but can also be symbolic,
e.g., masking.Poly or jax2tf.shape_poly.DimVar.
The base class works for integers only. Subclasses are invoked when at
least one of the operands has a type registered in _SPECIAL_DIMENSION_HANDLERS.
In that case, all operands are guaranteed to be either the special dimension
type, or Python integer scalars.
Subclasses should raise InconclusiveDimensionOperation if the result cannot
be computed in some contexts.
"""
def is_constant(self, d: DimSize) -> bool:
"""The dimension is a constant."""
return True
def symbolic_equal(self, d1: DimSize, d2: DimSize) -> bool:
"""True iff the dimension sizes are equal in all contexts; False otherwise.
Unlike `d1 == d2` this never raises InconclusiveDimensionOperation.
"""
return d1 == d2
def greater_equal(self, d1: DimSize, d2: DimSize) -> bool:
"""Computes `d1 >= d2`.
    Raises InconclusiveDimensionOperation if the result is different in
different contexts.
"""
return d1 >= d2
def sum(self, *ds: DimSize) -> DimSize:
"""Sum of dimensions.
Raises InconclusiveDimensionOperation if the result cannot be represented
by the same DimSize in all contexts.
"""
return sum(ds)
def diff(self, d1: DimSize, d2: DimSize) -> DimSize:
"""Difference of dimensions.
Raises InconclusiveDimensionOperation if the result cannot be represented
by the same DimSize in all contexts.
"""
return d1 - d2
def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
"""Computes integer "i" such that i * size(s2) == size(s1).
    Raises InconclusiveDimensionOperation if there is no such integer for all
    contexts.
"""
sz1 = int(np.prod(s1))
sz2 = int(np.prod(s2))
if sz1 == 0 and sz2 == 0:
return 1
if sz1 % sz2:
raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
return sz1 // sz2
def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
"""(d - window_size) // window_stride + 1"""
return (d - window_size) // window_stride + 1
def dilate(self, d: DimSize, dilation: int) -> DimSize:
"""Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
return 0 if d == 0 else 1 + dilation * (d - 1)
def as_value(self, d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with."""
return d
_dimension_handler_int = DimensionHandler()
_SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}
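# A hedged, hypothetical sketch of how a module can plug a symbolic dimension
# type into this machinery; real registrations live in e.g. jax2tf.shape_poly.
# `_UnknownSize` and its handler are illustrative only and unused elsewhere.
class _UnknownSize:
  """Toy symbolic dimension standing for a size unknown at trace time."""

class _UnknownSizeHandler(DimensionHandler):
  def is_constant(self, d: DimSize) -> bool:
    return False
  def symbolic_equal(self, d1: DimSize, d2: DimSize) -> bool:
    # Two unknown sizes are known-equal only when they are the same object.
    return d1 is d2
  def greater_equal(self, d1: DimSize, d2: DimSize) -> bool:
    raise InconclusiveDimensionOperation(
        f"cannot decide {d1} >= {d2} for unknown sizes")
# A real module would then register it:
#   _SPECIAL_DIMENSION_HANDLERS[_UnknownSize] = _UnknownSizeHandler()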
def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple[DimSize, ...]]:
"""Finds the handler for the given dimensions; also returns the canonical dimensions.
A dimension is canonical if it is a Python integer scalar, or has a type
registered in _SPECIAL_DIMENSION_HANDLERS.
"""
special_handlers = set()
canonical = []
for d in dlist:
handler = _SPECIAL_DIMENSION_HANDLERS.get(type(d))
if handler:
special_handlers.add(handler)
canonical.append(d)
else:
try:
canonical.append(operator.index(d))
except TypeError:
raise _invalid_shape_error(dlist)
if len(special_handlers) > 1:
msg = (f"Dimension size operation involves multiple special dimension types {dlist}")
raise ValueError(msg)
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
def is_constant_dim(d: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d)
return handler.is_constant(*ds)
def symbolic_equal_dim(d1: DimSize, d2: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.symbolic_equal(*ds)
def symbolic_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
handler, ds = _dim_handler_and_canonical(d1, *dlist)
  return any(handler.symbolic_equal(ds[0], d) for d in ds[1:])
def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
return (len(s1) == len(s2) and
all(map(symbolic_equal_dim, s1, s2)))
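# Hedged illustration on plain Python integers (the default DimensionHandler);
# unused elsewhere.
def _symbolic_equal_example():
  assert symbolic_equal_dim(2, 2)
  assert symbolic_equal_shape((2, 3), (2, 3))
  assert not symbolic_equal_shape((2, 3), (2, 4))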
def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.greater_equal(*ds)
def greater_equal_shape(s1: Shape, s2: Shape) -> bool:
return all(map(greater_equal_dim, s1, s2))
def sum_dim(*ds: DimSize) -> DimSize:
handler, ds = _dim_handler_and_canonical(*ds)
return handler.sum(*ds)
def sum_shapes(*ss: Shape) -> Shape:
return tuple(map(sum_dim, *ss))
def diff_dim(d1: DimSize, d2: DimSize) -> DimSize:
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.diff(*ds)
def diff_shape(s1: Shape, s2: Shape) -> Shape:
return tuple(map(diff_dim, s1, s2))
def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
"""Returns an integer "i" s.t., i * size(s2) == size(s1).
Raises if there is no such integer."""
s1 = s1 or (1,)
s2 = s2 or (1,)
handler, ds = _dim_handler_and_canonical(*s1, *s2)
return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
def same_shape_sizes(s1: Shape, s2: Shape) -> bool:
return 1 == divide_shape_sizes(s1, s2)
def is_empty_shape(s: Shape) -> bool:
return any(symbolic_equal_dim(d, 0) for d in s)
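# Hedged illustration, unused elsewhere: reshaping (2, 6) to (4, 3) conserves
# the element count, so the size quotient is 1; against (3,) a factor of 4
# remains.
def _divide_shape_sizes_example():
  assert divide_shape_sizes((2, 6), (4, 3)) == 1
  assert divide_shape_sizes((2, 6), (3,)) == 4
  assert same_shape_sizes((2, 6), (4, 3))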
def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
"""Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
handler, ds = _dim_handler_and_canonical(d, dilation)
return handler.dilate(*ds)
def dilate_shape(s: Shape, dilations: Sequence[int]) -> Shape:
return tuple(map(dilate_dim, s, dilations))
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
handler, ds = _dim_handler_and_canonical(d, window_size, window_stride)
return handler.stride(*ds)
def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
"""(s - window_size) // window_stride + 1"""
return tuple(map(stride_dim, s, window_size, window_stride))
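# Hedged illustration, unused elsewhere: a length-10 axis with window 3 and
# stride 2 yields (10 - 3) // 2 + 1 == 4 positions, and dilating a length-5
# axis by 2 gives 1 + 2 * (5 - 1) == 9.
def _window_dim_example():
  assert stride_dim(10, 3, 2) == 4
  assert dilate_dim(5, 2) == 9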
def dimension_as_value(d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with.
This is the identity function for constant dimensions."""
handler, ds = _dim_handler_and_canonical(d)
return handler.as_value(*ds)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
if type(dim) in _SPECIAL_DIMENSION_HANDLERS:
return dim
else:
return operator.index(dim)
def canonicalize_shape(shape: Shape, context: str="") -> Shape:
"""Canonicalizes and checks for errors in a user-provided shape value.
Args:
    shape: a Python value that represents a shape.
    context: an optional string to include in any error message.
Returns:
A tuple of canonical dimension values.
"""
try:
return tuple(map(_canonicalize_dimension, shape))
except TypeError:
pass
raise _invalid_shape_error(shape, context)
def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
Args:
    d: a Python value that represents a dimension.
Returns:
A canonical dimension value.
"""
return canonicalize_shape((d,), context)[0]
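# Hedged illustration, unused elsewhere: anything accepted by operator.index
# (e.g. numpy integer scalars) canonicalizes to a Python int; floats raise.
def _canonicalize_shape_example():
  assert canonicalize_shape((np.int64(2), 3)) == (2, 3)
  assert canonicalize_dim(np.int32(7)) == 7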
def _invalid_shape_error(shape: Shape, context: str=""):
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
f"got {shape}.")
if context:
msg += f" {context}."
if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
return TypeError(msg)
# ------------------- Named shapes -------------------
class NamedShape:
def __init__(self, *args, **kwargs):
self.__positional = canonicalize_shape(args)
# TODO: Assert that kwargs match axis env?
self.__named = dict(kwargs)
@property
def rank(self):
return len(self.__positional) + len(self.__named)
@property
def positional_rank(self):
return len(self.__positional)
@property
def named_rank(self):
return len(self.__named)
@property
def positional(self):
return self.__positional
@property
def names(self):
return self.__named.keys()
@property
def named_sizes(self):
return self.__named.values()
@property
def named_items(self):
return self.__named.items()
def __getitem__(self, idx):
try:
idx = operator.index(idx)
return self.__positional[idx]
except TypeError:
pass
return self.__named[idx]
@property
def total(self):
total = 1
for s in self.__positional: total *= s
for s in self.__named.values(): total *= s
return total
def __str__(self):
return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")
def __eq__(self, other):
if isinstance(other, NamedShape):
return (self.__positional, self.__named) == (other.__positional, other.__named)
if isinstance(other, tuple):
return not self.__named and self.__positional == other
raise TypeError(f"NamedShape doesn't support comparisons with {type(other)}")
def __hash__(self):
return hash((self.__positional, tuple(self.__named.items())))
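# Hedged illustration of the NamedShape API, unused elsewhere: positional dims
# come from args, named dims from kwargs, and `total` multiplies across both.
def _named_shape_example():
  s = NamedShape(2, 3, batch=4)
  assert s.rank == 3 and s.positional_rank == 2 and s.named_rank == 1
  assert s[0] == 2 and s['batch'] == 4
  assert s.total == 24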
def join_named_shapes(*named_shapes):
result = {}
for named_shape in named_shapes:
for name, size in named_shape.items():
if result.setdefault(name, size) != size:
raise TypeError(
f"Axis name {name} used with inconsistent sizes: {result[name]} != {size}")
return result
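# Hedged illustration, unused elsewhere: consistent axis names merge, while a
# size mismatch raises TypeError.
def _join_named_shapes_example():
  assert join_named_shapes({'i': 2}, {'j': 3}) == {'i': 2, 'j': 3}
  assert join_named_shapes({'i': 2}, {'i': 2, 'j': 3}) == {'i': 2, 'j': 3}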
# TODO: Make canonicalize_shape return named shapes?
def as_named_shape(shape) -> NamedShape:
if isinstance(shape, NamedShape):
return shape
return NamedShape(*shape)
# ------------------- Call -------------------
def apply_todos(todos, outs):
todos_list = list(todos)
while todos_list:
outs = map(full_lower, todos_list.pop()(outs))
return outs
class _IgnoreElemList(list):
"""Compares equal to all other _ignore_elem_lists."""
def __hash__(self): return 0
def __eq__(self, other):
return type(other) is _IgnoreElemList
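# Hedged illustration, unused elsewhere: every instance hashes and compares
# alike, so a cache key containing one is insensitive to its (mutable)
# contents.
def _ignore_elem_list_example():
  assert _IgnoreElemList([1, 2]) == _IgnoreElemList()
  assert hash(_IgnoreElemList([1, 2])) == hash(_IgnoreElemList([3]))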
@lu.transformation_with_aux
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
level: int, params_tuple: tuple, out_axes_transforms, *args):
outs = yield args, {}
params = dict(params_tuple)
todo = []
assert not out_axes_transforms
while True:
tracers = [x for x in outs if isinstance(x, Tracer)
and (level is None or x._trace.level > level)]
if tracers:
ans = max(tracers, key=lambda x: x._trace.level)
else:
break
trace = ans._trace.main.with_cur_sublevel()
outs = map(trace.full_raise, outs)
outs, cur_todo = primitive.post_process(trace, outs, params)
if isinstance(primitive, MapPrimitive):
cur_todo, out_axes_transform = cur_todo
out_axes_transforms.append(out_axes_transform)
todo.append(cur_todo)
yield outs, tuple(todo) # Ensure the aux output is immutable
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun, *args, **params):
out_axes_transforms = _IgnoreElemList()
if primitive.map_primitive:
out_axes_thunk = params['out_axes_thunk']
# The new thunk depends deterministically on the old thunk and the wrapped function.
# Any caching already has to include the wrapped function as part of the key, so we
# only use the previous thunk for equality checks.
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
out_axes = out_axes_thunk()
for t in out_axes_transforms:
out_axes = t(out_axes)
return out_axes
params = dict(params, out_axes_thunk=new_out_axes_thunk)
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces(
fun, primitive, top_trace and top_trace.level,
params_tuple, out_axes_transforms)
tracers = map(top_trace.full_raise, args)
outs = primitive.process(top_trace, fun, tracers, params)
return map(full_lower, apply_todos(env_trace_todo(), outs))
class CallPrimitive(Primitive):
multiple_results = True
call_primitive = True
def bind(self, fun, *args, **params):
return call_bind(self, fun, *args, **params)
def process(self, trace, fun, tracers, params):
return trace.process_call(self, fun, tracers, params)
def post_process(self, trace, out_tracers, params):
return trace.post_process_call(self, out_tracers, params)
def call_impl(f: lu.WrappedFun, *args, **params):
del params # params parameterize the call primitive, not the function
with new_sublevel():
return f.call_wrapped(*args)
call_p = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)
named_call_p = CallPrimitive('named_call')
named_call_p.def_impl(call_impl)
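# Hedged, hypothetical sketch, unused elsewhere: a new call primitive is just
# a CallPrimitive instance plus an impl rule; call_impl runs the wrapped
# function in a fresh sublevel.
_demo_call_p = CallPrimitive('demo_call')
_demo_call_p.def_impl(call_impl)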
# ------------------- Map -------------------
def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
handler, _ = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis, aval)
else:
raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")
def unmapped_aval(size: int, axis_name, axis: int, aval: AbstractValue) -> AbstractValue:
_, handler = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis_name, axis, aval)
else:
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
def _map_unit(*_) -> AbstractUnit:
return abstract_unit
def _map_shaped_array(size: int, axis: int, aval: ShapedArray) -> ShapedArray:
assert aval.shape[axis] == size
# TODO: Extend the named shape
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
named_shape=aval.named_shape)
def _unmap_shaped_array(size: int, axis_name, axis: int, aval: ShapedArray) -> ShapedArray:
named_shape = dict(aval.named_shape)
# TODO: Make this mandatory
named_shape.pop(axis_name, None)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
named_shape=named_shape)
AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
AbstractUnit: (_map_unit, _map_unit),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
}
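# Hedged illustration, unused elsewhere: mapping removes the mapped axis from
# the positional shape, and unmapping reinserts it. The axis name 'i' is made
# up for the example.
def _mapped_aval_example():
  aval = ShapedArray((8, 3), np.dtype('float32'))
  assert mapped_aval(8, 0, aval).shape == (3,)
  assert unmapped_aval(8, 'i', 0, mapped_aval(8, 0, aval)).shape == (8, 3)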
class MapPrimitive(Primitive):
multiple_results = True
map_primitive = True
def bind(self, fun, *args, **params):
assert len(params['in_axes']) == len(args)
return call_bind(self, fun, *args, **params)
def process(self, trace, fun, tracers, params):
return trace.process_map(self, fun, tracers, params)
def post_process(self, trace, out_tracers, params):
return trace.post_process_map(self, out_tracers, params)
@contextmanager
def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
frame = AxisEnvFrame(axis_name, size, tag)
thread_local_state.trace_state.axis_env.append(frame)
try:
yield
finally:
thread_local_state.trace_state.axis_env.pop()
@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
thread_local_state.trace_state.axis_env.extend(frames)
try:
yield
finally:
for _ in frames:
thread_local_state.trace_state.axis_env.pop()
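# Hedged illustration, unused elsewhere: a frame pushed by extend_axis_env is
# visible to axis_frame only inside the context. The name 'i' and tag None
# are made up for the example.
def _extend_axis_env_example():
  with extend_axis_env('i', 8, None):
    assert axis_frame('i').size == 8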
# When a mapped function is given no axis name, we generate a name object based
# on the id of the function object. Collisions aren't important because this
# name can't be used in collectives, as user code never gets a ref to this
# object. We don't want to use the function object itself because that might
# persist references to the function object.
# TODO(mattjj): revisit this unique axis name strategy
@total_ordering
class _TempAxisName:
def __init__(self, obj):
self.id = id(obj)
def __repr__(self):
return f'<axis {hex(self.id)}>'
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
return type(other) is _TempAxisName and self.id == other.id
def __lt__(self, other):
return type(other) is _TempAxisName and self.id < other.id
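
# Example (editor's sketch): temporary names hash and compare by the id of the
# object they were derived from, so names built from the same function object
# are interchangeable:
#
#   f = lambda x: x
#   assert _TempAxisName(f) == _TempAxisName(f)
#   repr(_TempAxisName(f))  # e.g. '<axis 0x7f9c...>'
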
def axis_frame(axis_name):
frames = thread_local_state.trace_state.axis_env
for frame in reversed(frames):
if frame.name == axis_name:
return frame
named_axes = [frame.name for frame in reversed(frames)
if not isinstance(frame.name, _TempAxisName)]
raise NameError(
f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
f'by pmap) are available to collective operations: {named_axes}')
ParamDict = Dict[str, Any]
AxisSubst = Callable[[AxisName], Tuple[AxisName, ...]]
class NameGatheringSubst:
def __init__(self):
self.axis_names = set()
def __call__(self, axis_name):
self.axis_names.add(axis_name)
return (axis_name,)
def used_axis_names(primitive: Primitive, params: ParamDict) -> Set[AxisName]:
subst = NameGatheringSubst()
subst_axis_names(primitive, params, subst)
return subst.axis_names
def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict:
if primitive in axis_substitution_rules:
return axis_substitution_rules[primitive](params, subst, traverse)
if not traverse:
return params
# Default implementation: substitute names in all jaxpr parameters
if isinstance(primitive, MapPrimitive):
def shadowed_subst(name):
return (name,) if name == params['axis_name'] else subst(name)
else:
shadowed_subst = subst
jaxpr_params = [(n, v) for n, v in params.items() if isinstance(v, (Jaxpr, ClosedJaxpr))]
if not jaxpr_params:
return params
new_params = dict(params)
for name, jaxpr in jaxpr_params:
new_params[name] = subst_axis_names_jaxpr(jaxpr, shadowed_subst)
return new_params
class DuplicateAxisNameError(Exception):
def __init__(self, var):
self.var = var
self.eqn = None
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> Var:
# Var identity is load-bearing, so we can't have duplicates!
if v is unitvar: return v
if v is dropvar: return v
assert v not in var_map
if not hasattr(v.aval, 'named_shape'):
var_map[v] = v
return v
names = tuple(it.chain.from_iterable(subst(name) for name in v.aval.named_shape))
named_shape = {name: axis_frame(name).size for name in names}
if len(named_shape) != len(names):
raise DuplicateAxisNameError(v)
new_v = Var(v.count, v.suffix, v.aval.update(named_shape=named_shape))
var_map[v] = new_v
return new_v
def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: Dict[Var, Var]) -> JaxprEqn:
invars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
try:
outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars]
except DuplicateAxisNameError as e:
e.eqn = eqn
raise
params = subst_axis_names(eqn.primitive, eqn.params, subst)
return new_jaxpr_eqn(invars, outvars, eqn.primitive, params, eqn.source_info)
def do_subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
consts = None
if isinstance(jaxpr, ClosedJaxpr):
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
var_map: Dict[Var, Var] = {unitvar: unitvar}
invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars]
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]
outvars: List[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns)
if consts is not None:
return ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr
@cache()
def used_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr]):
subst = NameGatheringSubst()
do_subst_axis_names_jaxpr(jaxpr, subst)
return frozenset(subst.axis_names)
def subst_axis_names_jaxpr(jaxpr: Union[Jaxpr, ClosedJaxpr], subst: AxisSubst):
if isinstance(subst, NameGatheringSubst): # This is a common case, so we optimize it!
subst.axis_names |= used_axis_names_jaxpr(jaxpr)
return jaxpr
return do_subst_axis_names_jaxpr(jaxpr, subst)
axis_substitution_rules: Dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {}
# ------------------- AxisPrimitive -------------------
# Primitives that store axis names in params and want those axis names to
# participate in dispatch should subclass AxisPrimitive.
class AxisPrimitive(Primitive):
_dispatch_on_params = True
# ------------------- Jaxpr checking -------------------
def typecheck(aval: AbstractValue, x) -> bool:
return typecompat(aval, get_aval(x))
def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
"""Determine whether `aval` conforms to `aval_ref`.
Ignores weak_type and named_shape, other than to check that an axis name isn't
used with different sizes.
"""
try:
return typematch(aval_ref, lattice_join(aval_ref, aval))
except TypeError:
return False
def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
"""Determine whether `aval1` and `aval2` are equivalent.
Ignores weak_type and named_shape, other than to check that an axis name isn't
used with different sizes.
"""
if aval1 == aval2: return True
# unequal avals may still represent the same type, because type is represented
# by avals at the shaped level, and because weak type tags and (for now) named
# shape components aren't considered part of the type
if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray):
# a bonus check for whether any named axes have inconsistent sizes
join_named_shapes(aval1.named_shape, aval2.named_shape)
return (raise_to_shaped(aval1, weak_type=False).strip_named_shape() ==
raise_to_shaped(aval2, weak_type=False).strip_named_shape())
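
# Example (editor's sketch): weak types and named shapes don't affect the
# answer, but shapes and dtypes do:
#
#   a = ShapedArray((2,), np.float32)
#   b = ShapedArray((2,), np.float32, weak_type=True)
#   assert typematch(a, b)
#   assert not typematch(a, ShapedArray((3,), np.float32))
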
class JaxprTypeError(TypeError): pass
def typecheck_assert(pred, msg):
if not pred:
raise JaxprTypeError(msg)
custom_typechecks: Dict[Primitive, Callable] = {}
def check_jaxpr(jaxpr: Jaxpr):
"""Checks well-formedness of a jaxpr.
Specifically, check that:
- variables that are read are bound beforehand
- variables are typed equally throughout a jaxpr
- variable type annotations are compatible with their binding expression
Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
otherwise.
"""
try:
_check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
except JaxprTypeError as e:
if len(e.args) == 2:
msg, eqn_idx = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqn_idx - 10, eqn_idx + 10))
else:
msg, = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20))
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
raise JaxprTypeError(msg) from None
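
# Example (editor's sketch, assuming the public `jax.make_jaxpr` helper):
#
#   import jax
#   jaxpr = jax.make_jaxpr(lambda x: x + 1.)(0.).jaxpr
#   check_jaxpr(jaxpr)  # returns None; raises JaxprTypeError if malformed
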
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
def read(v: Atom) -> AbstractValue:
if isinstance(v, Literal):
return raise_to_shaped(get_aval(v.val))
else:
typecheck_assert(v in env, f"Variable '{v}' not defined")
return env[v]
def write(v: Var, a: AbstractValue) -> None:
typecheck_assert(v not in env, f"Variable '{v}' already bound")
if v is not dropvar:
typecheck_assert(typecompat(v.aval, a),
f"Variable '{v}' inconsistently typed as {a}, "
f"bound as {v.aval}")
env[v] = a
  env: Dict[Var, AbstractValue] = {}
write(unitvar, abstract_unit)
map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
map(write, jaxpr.invars, in_avals)
for eqn_idx, eqn in enumerate(jaxpr.eqns):
prim = eqn.primitive
try:
in_avals = map(read, eqn.invars)
typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
"Equation given ConcreteArray type inputs")
if prim in custom_typechecks:
out_avals = custom_typechecks[prim](*in_avals, **eqn.params)
if out_avals is None:
out_avals = [v.aval for v in eqn.outvars]
elif prim.call_primitive:
out_avals = check_call(prim, in_avals, eqn.params)
elif prim.map_primitive:
out_avals = check_map(prim, in_avals, eqn.params)
else:
out_avals = check_eqn(prim, in_avals, eqn.params)
map(write, eqn.outvars, out_avals)
except JaxprTypeError as e:
msg, = e.args
src = source_info_util.summarize(eqn.source_info)
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn))),
f"from source: {src}"])
raise JaxprTypeError(msg, eqn_idx) from None
map(read, jaxpr.outvars)
def check_eqn(prim, in_avals, params):
for jaxpr in jaxprs_in_params(params):
check_jaxpr(jaxpr)
out_avals = prim.abstract_eval(*in_avals, **params)
if not prim.multiple_results:
out_avals = [out_avals]
return out_avals
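
# Example (editor's sketch): for an ordinary primitive the check reduces to its
# abstract_eval rule. Assuming jax.lax's single-result `add_p`:
#
#   from jax import lax
#   aval = ShapedArray((2,), np.float32)
#   check_eqn(lax.add_p, [aval, aval], {})  # -> [ShapedArray((2,), float32)]
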
def check_call(prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Call primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
# These checks also happen in recursive call, but give better errors here.
typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
f"Call primitive {prim} with {len(call_jaxpr.invars)} "
f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
f"inputs")
binder_avals = [v.aval for v in call_jaxpr.invars]
for binder_aval, in_aval in zip(binder_avals, in_avals):
typecheck_assert(typecompat(binder_aval, in_aval),
f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")
_check_jaxpr(call_jaxpr, in_avals)
out_avals = [v.aval for v in call_jaxpr.outvars]
return out_avals
def check_map(prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Map primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
typecheck_assert("axis_size" in params,
f"Map primitive {prim} missing 'axis_size' parameter")
axis_size = params["axis_size"]
typecheck_assert("axis_name" in params,
f"Map primitive {prim} missing 'axis_name' parameter")
axis_name = params["axis_name"]
typecheck_assert("in_axes" in params,
f"Map primitive {prim} missing 'in_axes' parameter")
in_axes = params["in_axes"]
typecheck_assert("out_axes" in params,
f"Map primitive {prim} missing 'out_axes' parameter")
out_axes = params["out_axes"]
binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
if in_axis is not None else v.aval
for v, in_axis in zip(call_jaxpr.invars, in_axes)]
for binder_aval, in_aval in zip(binder_avals, in_avals):
typecheck_assert(typecompat(binder_aval, in_aval),
f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")
mapped_avals = [mapped_aval(axis_size, in_axis, aval)
if in_axis is not None else aval
for aval, in_axis in zip(in_avals, in_axes)]
with extend_axis_env(params['axis_name'], axis_size, None):
_check_jaxpr(call_jaxpr, mapped_avals)
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval
for aval, out_axis in zip(mapped_out_avals, out_axes)]
return out_avals
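
# Editor's note (sketch): per the checks above, a map primitive's params must
# carry at least
#
#   {'call_jaxpr': ..., 'axis_size': ..., 'axis_name': ...,
#    'in_axes': ..., 'out_axes': ...}
#
# where each in_axes/out_axes entry is either an int axis position or None for
# an unmapped argument/result.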
# ------------------- Jaxpr printed representation -------------------
def pp_vars(vs: Sequence[Any], *, print_shapes: bool = False) -> pp.Doc:
if print_shapes:
return pp.nest(2, pp.group(
pp.join(pp.brk(), [
pp.text(str(v)) +
pp.dim(pp.text(":" + v.aval.str_short(short_dtypes=True)))
for v in vs
])
))
else:
return pp.nest(2, pp.group(
pp.join(pp.brk(), [pp.text(str(v)) for v in vs])
))
def pp_kv_pair(k: str, v: Any) -> pp.Doc:
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
pp_v = pp_jaxprs(v)
elif isinstance(v, Jaxpr):
pp_v = pp_jaxpr(v)
elif isinstance(v, ClosedJaxpr):
pp_v = pp_jaxpr(v.jaxpr)
else:
pp_v = pp.text(str(v))
return pp.text(f'{k}=') + pp_v
def pp_kv_pairs(kv_pairs) -> pp.Doc:
if not kv_pairs:
return pp.nil()
return pp.group(
pp.nest(2, pp.concat([
pp.text("["), pp.brk(""),
pp.join(pp.brk(), [pp_kv_pair(k, v) for k, v in kv_pairs])
]))
+ pp.brk("") + pp.text("]")
)
def pp_eqn(eqn, *, print_shapes=True, source_info=False) -> pp.Doc:
lhs = pp_vars(eqn.outvars, print_shapes=print_shapes)
annotation = (source_info_util.summarize(eqn.source_info)
if source_info else None)
return pp.concat([
lhs, pp.text(" = ", annotation=annotation), pp.text(eqn.primitive.name),
pp_kv_pairs(sorted(eqn.params.items())),
pp.text(" ") + pp_vars(eqn.invars)
])
def pp_eqns(eqns, *, print_shapes=True, source_info=False) -> pp.Doc:
return pp.join(
pp.brk("; "),
map(partial(pp_eqn, print_shapes=print_shapes, source_info=source_info),
eqns))
def pp_eqn_compact(primitive_name: str, params: Dict) -> pp.Doc:
filtered_params = {k: v for k, v in params.items()
if (k != 'branches' and
not isinstance(v, (Jaxpr, ClosedJaxpr)))}
return pp.text(primitive_name) + pp_kv_pairs(sorted(filtered_params.items()))
def pp_jaxpr_skeleton(jaxpr, eqns_pp, *, print_shapes=True) -> pp.Doc:
str_outvars = str(tuple(jaxpr.outvars))
return pp.group(pp.nest(2, pp.concat([
pp.text("{ "), pp.bright(pp.text("lambda ")),
pp_vars(jaxpr.constvars, print_shapes=print_shapes),
pp.text("; "), pp_vars(jaxpr.invars, print_shapes=print_shapes),
pp.text(". "), pp.bright(pp.text("let")),
pp.nest(2, pp.brk() + eqns_pp), pp.brk(),
pp.bright(pp.text("in")),
pp.text(f" {str_outvars}")
])) + pp.text(" }"))
def pp_jaxpr(jaxpr, *, print_shapes=True, source_info=False) -> pp.Doc:
pps = pp_eqns(jaxpr.eqns, print_shapes=print_shapes, source_info=source_info)
return pp_jaxpr_skeleton(jaxpr, pps, print_shapes=print_shapes)
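
# Example (editor's sketch, assuming `jax.make_jaxpr`):
#
#   import jax
#   jaxpr = jax.make_jaxpr(lambda x: x + 1.)(0.).jaxpr
#   print(pp_jaxpr(jaxpr).format())
#   # prints something like: { lambda ; a:f32[]. let b:f32[] = add a 1.0
#   #                          in (b,) }
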
def pp_jaxprs(jaxprs) -> pp.Doc:
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
return pp.group(pp.nest(2, pp.concat([
pp.text('('), pp.brk(""), pp.join(pp.brk(), map(pp_jaxpr, jaxprs))]))
+ pp.brk("") + pp.text(')')
)
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, print_shapes=True,
source_info: bool = False) -> pp.Doc:
lo = max(lo, 0)
hi = max(lo, min(hi, len(jaxpr.eqns)))
eqns = jaxpr.eqns[lo:hi]
pps = []
if len(eqns) == 0 and len(jaxpr.eqns) != 0:
pps.append(pp.text('...'))
else:
if lo != 0:
pps.append(pp.text('...'))
pps.extend(map(partial(pp_eqn, print_shapes=print_shapes,
source_info=source_info), eqns))
if hi != len(jaxpr.eqns):
pps.append(pp.text('...'))
return pp_jaxpr_skeleton(jaxpr, pp.join(pp.brk("; "), pps),
print_shapes=print_shapes)
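
# Example (editor's sketch): printing a two-equation window of a longer jaxpr
# elides the surrounding equations with '...':
#
#   print(pp_jaxpr_eqn_range(jaxpr, 1, 3).format())
#   # shows eqns 1 and 2, with '...' before and after when more exist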