rocm_jax/jax/core.py

1665 lines
54 KiB
Python
Raw Normal View History

2018-11-17 18:03:33 -08:00
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2020-03-09 09:14:23 +00:00
import operator
2018-11-17 18:03:33 -08:00
from operator import attrgetter
from contextlib import contextmanager, suppress
from collections import namedtuple
from functools import total_ordering
import itertools as it
2018-11-17 18:03:33 -08:00
from weakref import ref
import threading
2018-11-17 18:03:33 -08:00
import types
2020-06-02 10:26:43 -04:00
from typing import (Any, Callable, ClassVar, Dict, Generator,
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
2020-09-15 08:06:46 -07:00
Type, Union, cast)
2018-11-17 18:03:33 -08:00
import numpy as np
2020-03-09 09:14:23 +00:00
from . import dtypes
from .config import FLAGS, config
2018-11-17 18:03:33 -08:00
from . import linear_util as lu
from jax._src import source_info_util
2020-03-09 09:14:23 +00:00
from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
2020-06-06 10:51:34 -07:00
from .pprint_util import pp, vcat, PrettyPrint
2018-11-17 18:03:33 -08:00
from ._src import traceback_util
traceback_util.register_exclusion(__file__)
# TODO(dougalm): compilation cache breaks the leak detector. Consider solving.
2018-11-17 18:03:33 -08:00
# NOTE(review): not read anywhere in this excerpt; presumably switches on the
# leak detector mentioned in the TODO above. Off by default.
check_leaks = False
# When True, internal invariant checks are skipped (e.g. the argument-type
# assertion in ``Primitive.bind``). Defaults to the negation of the
# ``jax_enable_checks`` flag; ``skipping_checks()`` disables checks temporarily.
skip_checks = not FLAGS.jax_enable_checks # not __debug__ # google doesn't use -O
@contextmanager
def skipping_checks():
  """Temporarily turn off internal invariant checks.

  Sets the module-level ``skip_checks`` flag to True for the duration of the
  ``with`` body. The previous value is restored on exit, including when the
  body raises.
  """
  global skip_checks
  saved = skip_checks
  skip_checks = True
  try:
    yield
  finally:
    skip_checks = saved
2018-11-17 18:03:33 -08:00
# Shadow the builtins with the .util versions for the rest of this module.
# NOTE(review): eval_jaxpr calls ``map(write, ...)`` only for its side effects
# and returns ``map(read, ...)`` directly, so ``safe_map`` is presumably eager
# (list-returning); ``safe_zip`` presumably checks that lengths match. Confirm
# in .util.
zip = safe_zip
map = safe_map
# -------------------- jaxprs --------------------
class Jaxpr:
  """The variables and equations that make up a jaxpr.

  Every field is stored as a fresh list, so later mutation of the sequences
  handed to the constructor does not affect the Jaxpr.
  """
  constvars: List['Var']
  invars: List['Var']
  outvars: List['Atom']
  eqns: List['JaxprEqn']

  def __init__(self, constvars: Sequence['Var'], invars: Sequence['Var'],
               outvars: Sequence['Atom'], eqns: Sequence['JaxprEqn']):
    """
    Args:
      constvars: variables introduced for constants. Array constants are
        replaced with such variables while scalar constants are kept inline.
      invars: input variables. Together, ``constvars`` and ``invars`` are the
        inputs to the Jaxpr.
      outvars: output atoms (variables or literals).
      eqns: the equations, in evaluation order.
    """
    self.constvars, self.invars, self.outvars, self.eqns = (
        list(constvars), list(invars), list(outvars), list(eqns))

  def __str__(self):
    return str(pp_jaxpr(self))

  # repr and str share one implementation (the pretty-printed form).
  __repr__ = __str__
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yield every jaxpr stored among the values of a params dict.

  Each value is inspected directly, or element by element when it is a tuple.
  A bare ``Jaxpr`` is yielded as-is and a ``ClosedJaxpr`` contributes its inner
  ``.jaxpr``. All other values are ignored.
  """
  for param in params.values():
    candidates = param if isinstance(param, tuple) else (param,)
    for candidate in candidates:
      if isinstance(candidate, Jaxpr):
        yield candidate
      elif isinstance(candidate, ClosedJaxpr):
        yield candidate.jaxpr
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Lazily yield the jaxprs found in the params of each of ``jaxpr.eqns``.

  Only one level is searched: the yielded sub-jaxprs are not themselves
  descended into.
  """
  for eqn in jaxpr.eqns:
    for sub in jaxprs_in_params(eqn.params):
      yield sub
class ClosedJaxpr:
  """A ``Jaxpr`` bundled with the values of its ``constvars``."""
  jaxpr: Jaxpr
  consts: List[Any]

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    # Exactly one value is required per constvar.
    assert len(consts) == len(jaxpr.constvars)
    self.jaxpr = jaxpr
    self.consts = list(consts)

  @property
  def in_avals(self):
    """Abstract values of the wrapped jaxpr's input variables."""
    return [var.aval for var in self.jaxpr.invars]

  @property
  def out_avals(self):
    """Abstract values of the wrapped jaxpr's output atoms."""
    return [atom.aval for atom in self.jaxpr.outvars]

  @property
  def literals(self):
    """Backwards-compatible alias for ``consts``."""
    return self.consts

  def map_jaxpr(self, f):
    """Return a new ClosedJaxpr holding ``f(self.jaxpr)`` and the same consts."""
    return ClosedJaxpr(f(self.jaxpr), self.consts)

  def __str__(self):
    return str(self.jaxpr)

  def __repr__(self):
    return repr(self.jaxpr)
2018-11-17 18:03:33 -08:00
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  """Evaluate ``closed_jaxpr`` on ``args`` with its own consts.

  NOTE(review): the ``curry`` decorator (from .util) presumably lets callers
  write ``jaxpr_as_fun(closed_jaxpr)(*args)``; confirm against ``.util.curry``.
  """
  jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
  return eval_jaxpr(jaxpr, consts, *args)
class JaxprEqn(NamedTuple):
  """One equation of a jaxpr: ``outvars = primitive[params] invars``."""
  # Operands: variables or inline literals.
  invars: List['Atom']
  # Variables bound by this equation.
  outvars: List['Var']
  primitive: 'Primitive'
  # Passed (minus any call_jaxpr) as keyword arguments to primitive.bind.
  params: Dict[str, Any]
  # Handed to source_info_util.user_context when the equation is evaluated.
  source_info: Optional[source_info_util.Traceback]

  def __repr__(self):
    printed = pp_eqn(self)
    return str(printed).rstrip()
def new_jaxpr_eqn(invars, outvars, primitive, params, source_info=None):
  """Build a ``JaxprEqn``; ``source_info`` defaults to None."""
  return JaxprEqn(invars=invars, outvars=outvars, primitive=primitive,
                  params=params, source_info=source_info)
2020-01-06 13:29:21 +00:00
@total_ordering
class Var:
  """A jaxpr variable.

  Equality and hashing fall back to object identity. ``count`` and ``suffix``
  only control how the variable prints and how variables sort.
  """
  # TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison is
  # by object id, but pretty printing might collide.
  count: int
  suffix: str
  aval: 'AbstractValue'

  def __init__(self, count: int, suffix: str, aval: 'AbstractValue'):
    self.count = count
    self.suffix = suffix
    # The stored aval is always passed through raise_to_shaped.
    self.aval = raise_to_shaped(aval)

  def __lt__(self, other):
    # Sort by (count, suffix); total_ordering derives <=, > and >=.
    if isinstance(other, Var):
      return (self.count, self.suffix) < (other.count, other.suffix)
    return NotImplemented

  def __repr__(self):
    # Write ``count`` in base 26 using 'a'..'z' as digits, most significant
    # first, then append the suffix: 0 -> 'a', 25 -> 'z', 26 -> 'ba'.
    letters = []
    remaining = self.count
    while True:
      remaining, digit = divmod(remaining, 26)
      letters.append(chr(ord('a') + digit))
      if remaining == 0:
        break
    return ''.join(reversed(letters)) + self.suffix
def _jaxpr_vars(jaxpr):
return it.chain(
jaxpr.invars, jaxpr.constvars,
(v for eqn in jaxpr.eqns for v in eqn.outvars))
def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
           suffix: str = '') -> Callable[['AbstractValue'], Var]:
  """Return a factory that makes fresh ``Var``s printed with ``suffix``.

  Args:
    jaxprs: optional jaxprs whose variables the new ones must not clash with.
      When given, numbering starts one past the largest ``count`` of any
      variable bound in them; otherwise numbering starts at 0.
    suffix: suffix attached to every produced variable.

  Returns:
    A function mapping an aval to a new ``Var`` with the next count.
  """
  if jaxprs is not None:
    existing = it.chain.from_iterable(_jaxpr_vars(j) for j in jaxprs)
    first = max((v.count for v in existing), default=-1) + 1
  else:
    first = 0
  counter = it.count(start=first)

  def fresh(aval):
    return Var(next(counter), suffix, aval)
  return fresh
2020-06-08 16:13:30 -07:00
# In a jaxpr, ``dropvar`` can stand in for a bound variable to say that the
# assignment is dropped, i.e. the equation's output value is never read. It is
# not really a variable, but treating it as a special kind of ``Var`` is
# convenient. Its ``aval`` is likewise only approximate.
class DropVar(Var):
  """Placeholder binder for discarded equation outputs; prints as ``_``."""
  count = -1
  suffix = ''

  def __init__(self):
    # Var.__init__ is deliberately not called: it assigns ``self.aval``, which
    # the read-only property below would reject, and count/suffix are already
    # provided as class attributes.
    pass

  @property
  def aval(self):
    return abstract_unit

  def __repr__(self):
    return '_'


# Module-level DropVar instance.
dropvar = DropVar()
class Literal:
  """A constant value appearing inline as an operand of a jaxpr equation.

  The ``hash`` slot caches a hash of ``val``. It holds ``hash(val)`` when
  ``val`` is hashable. For unhashable values whose type is registered in
  ``literalable_types``, it holds a hash of ``(val.item(), val.dtype)``, or
  ``None`` if that cannot be computed. For any other unhashable value the slot
  is left unset, and ``__repr__`` uses that to choose its format.
  """
  __slots__ = ["val", "hash"]

  val: Any
  hash: Optional[int]

  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash(val)
    except TypeError:
      if type(val) not in literalable_types:
        return  # leave the ``hash`` slot unset on purpose
      try:
        self.hash = hash((val.item(), val.dtype))
      except (TypeError, AttributeError, ValueError):
        self.hash = None

  @property
  def aval(self):
    return raise_to_shaped(get_aval(self.val))

  def __hash__(self):
    # Literals are not meant to be hashed (eval_jaxpr never uses them as keys).
    assert False

  def __repr__(self):
    if not hasattr(self, 'hash'):
      return f'Literal(val={self.val})'
    return f'{self.val}'

# Types whose unhashable instances Literal hashes via ``(val.item(), val.dtype)``.
literalable_types: Set[type] = set()
# An equation operand: either a variable (looked up at evaluation time) or an
# inline Literal carrying its value directly.
Atom = Union[Var, Literal]
class Primitive:
  """A named operation that can be bound to arguments.

  Rules are attached after construction through the ``def_*`` methods. Each of
  these stores the given function on the instance and returns it, so the
  methods can be used as decorators. The class-level flags below are set to
  True on primitives that need special handling.
  """
  name: str
  # Set for multi-output primitives.
  multiple_results = False
  # Set for call primitives processed in final style.
  call_primitive = False
  # Set for map primitives processed in final style.
  map_primitive = False

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return f'{self.name}'

  def bind(self, *args, **params):
    """Apply this primitive to ``args``.

    Each argument is raised to the trace chosen by ``find_top_trace``, that
    trace processes the primitive, and the result(s) are lowered again with
    ``full_lower``.
    """
    assert skip_checks or all(isinstance(arg, Tracer) or valid_jaxtype(arg)
                              for arg in args), args
    trace = find_top_trace(args)
    raised = map(trace.full_raise, args)
    out = trace.process_primitive(self, raised, params)
    if self.multiple_results:
      return map(full_lower, out)
    return full_lower(out)

  def def_impl(self, impl):
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    self.abstract_eval = abstract_eval
    return abstract_eval

  def def_custom_bind(self, bind):
    self.bind = bind
    return bind

  def impl(self, *args, **params):
    raise NotImplementedError(
        f"Evaluation rule for '{self.name}' not implemented")

  def abstract_eval(self, *args, **params):
    raise NotImplementedError(
        f"Abstract evaluation for '{self.name}' not implemented")
2018-11-17 18:03:33 -08:00
# -------------------- lifting --------------------
# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
# core.py. Plan to move all these utilities to jaxpr.py.
def extract_call_jaxpr(
    primitive: Primitive,
    params: Dict[str, Any]) -> Tuple[Optional[Jaxpr], Dict[str, Any]]:
  """Split the ``call_jaxpr`` subjaxpr out of a call/map primitive's params.

  Returns:
    ``(call_jaxpr, remaining_params)`` for call and map primitives, where
    ``remaining_params`` is a new dict without the ``"call_jaxpr"`` entry.
    For any other primitive, ``(None, params)`` with ``params`` unchanged.
  """
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
    remaining = dict(params)
    call_jaxpr = remaining.pop("call_jaxpr")
    return (call_jaxpr, remaining)
  return (None, params)
2018-11-17 18:03:33 -08:00
def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
  """Interpret ``jaxpr`` by binding each equation's primitive in order.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: values for ``jaxpr.constvars``, in order.
    *args: values for ``jaxpr.invars``, in order.

  Returns:
    The values of ``jaxpr.outvars``, in order.
  """
  env: Dict[Var, Any] = {}

  def read(atom):
    # Literals carry their value inline; variables are looked up in ``env``.
    return atom.val if type(atom) is Literal else env[atom]

  def write(var, val):
    env[var] = val

  # NOTE: ``map`` is the module's ``safe_map`` alias; these calls are made only
  # for their side effects on ``env``.
  write(unitvar, unit)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)

  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.invars)
    call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
    # Call/map primitives receive their body as a leading function argument.
    subfuns = ([lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
               if call_jaxpr else [])
    bind_args = subfuns + in_vals
    with source_info_util.user_context(eqn.source_info):
      ans = eqn.primitive.bind(*bind_args, **params)
    if eqn.primitive.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)

  return map(read, jaxpr.outvars)
2018-11-17 18:03:33 -08:00
# -------------------- tracing --------------------
class Trace:
  """Base class for traces, each tied to a ``MainTrace`` and a sublevel.

  Subclasses must override ``pure``, ``lift``, ``sublift`` and
  ``process_primitive``. The ``process_call``, ``process_map`` and
  ``process_custom_*_call`` hooks only need overriding by traces that handle
  those kinds of primitives.
  """
  __slots__ = ['main', 'level', 'sublevel']

  main: 'MainTrace'
  level: int
  sublevel: 'Sublevel'

  def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel

  def full_raise(self, val) -> 'Tracer':
    """Return ``val`` as a tracer of this trace.

    Non-tracers are wrapped with ``pure``. A tracer of the same main trace is
    returned unchanged at an equal sublevel and ``sublift``-ed from a lower
    one. A tracer from a lower level is ``lift``-ed unless its sublevel is
    higher than ours.

    Raises:
      UnexpectedTracerError: for any other combination of level and sublevel.
    """
    if not isinstance(val, Tracer):
      return self.pure(val)
    val._assert_live()
    source = val._trace
    level, sublevel = self.level, self.sublevel

    if source.main is self.main:
      if source.sublevel == sublevel:
        return val
      if source.sublevel < sublevel:
        return self.sublift(val)
      raise escaped_tracer_error(
          f"Can't lift sublevels {source.sublevel} to {sublevel}")

    if source.level < level:
      if source.sublevel > sublevel:
        raise escaped_tracer_error(
            f"Incompatible sublevel: {source}, {(level, sublevel)}")
      return self.lift(val)

    if source.level > level:
      raise escaped_tracer_error(f"Can't lift level {val} to {self}")

    # Same level, but a different main trace.
    raise escaped_tracer_error(
        f"Different traces at same level: {val}, {self}")

  def pure(self, val):
    raise NotImplementedError("must override")

  def lift(self, tracer):
    raise NotImplementedError("must override")

  def sublift(self, tracer):
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params):
    raise NotImplementedError("must override")

  def __repr__(self):
    return f'{self.__class__.__name__}(level={self.level}/{self.sublevel})'

  def process_call(self, call_primitive, f, tracers, params):
    raise NotImplementedError(
        f"{type(self)} must override process_call to handle call-like "
        "primitives")

  def process_map(self, call_primitive, f, tracers, params):
    raise NotImplementedError(
        f"{type(self)} must override process_map to handle map-like "
        "primitives")

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_jvp_call "
        "to handle custom_jvp primitives")

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    raise NotImplementedError(
        f"{type(self)} must override process_custom_vjp_call "
        "to handle custom_vjp primitives")
def escaped_tracer_error(detail=None):
  """Build (but do not raise) an ``UnexpectedTracerError``.

  Args:
    detail: optional extra context, appended to the message when truthy.
  """
  parts = ["Encountered an unexpected tracer. Perhaps this tracer escaped "
           "through global state from a previously traced function.\n"
           "The functions being transformed should not save traced values to "
           "global state."]
  if detail:
    parts.append(f" Detail: {detail}.")
  return UnexpectedTracerError("".join(parts))


class UnexpectedTracerError(Exception):
  """Error for a tracer found outside the trace context it belongs to."""
2018-11-17 18:03:33 -08:00
class Tracer:
2018-11-17 18:03:33 -08:00
__array_priority__ = 1000
__slots__ = ['_trace', '__weakref__']
2018-11-17 18:03:33 -08:00
def __array__(self, *args, **kw):
msg = ("The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {self}.\n\n"
"This error can occur when a JAX Tracer object is passed to a raw "
"numpy function, or a method on a numpy.ndarray object. You might "
"want to check that you are using `jnp` together with "
"`import jax.numpy as jnp` rather than using `np` via "
"`import numpy as np`. If this error arises on a line that involves "
"array indexing, like `x[idx]`, it may be that the array being "
"indexed `x` is a raw numpy.ndarray while the indices `idx` are a "
"JAX Tracer instance; in that case, you can instead write "
"`jax.device_put(x)[idx]`.")
raise Exception(msg)
2018-11-17 18:03:33 -08:00
def __init__(self, trace: Trace):
self._trace = trace
2018-11-17 18:03:33 -08:00
def __iter__(self):
return iter(self.aval._iter(self))
def __len__(self):
return self.aval._len(self)
2018-11-17 18:03:33 -08:00
@property
def aval(self):
raise NotImplementedError("must override")
2018-11-17 18:03:33 -08:00
def _assert_live(self) -> None:
pass # Override for liveness checking
Add experimental __array_module__ method (#4076) * Add experimental __array_module__ method xref https://github.com/google/jax/issues/1565 `__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html)) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive. Example usage: ```python import numpy as np def duckarray_stack(arrays): """This "stack" function should work with any array library, including JAX.""" npx = np.get_array_module(*arrays) arrays = [npx.asarray(arr) for arr in arrays] shapes = {arr.shape for arr in arrays} if len(shapes) != 1: raise ValueError('all input arrays must have the same shape') expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays] return npx.concatenate(expanded_arrays, axis=0) ``` Support for this protocol has *not* yet been implemented in NumPy, but it can be tested with https://github.com/seberg/numpy-dispatch. My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that: 1. It's not invasive -- the implementation is small and self-contained. 2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`. 2. Other NumPy developers [want evidence](https://github.com/numpy/numpy/pull/16935#issuecomment-673951287) that this is actually feasible. 3. Scikit-Learn developers like @thomasjpfan are interested in exploring supporting scikit-learn on top of NumPy-like libraries like JAX, and experimental support for this protocol will make that easier. Note: this PR does add `numpy-dispatch` as a optional testing requirement in order to verify that this works. If desired, we could remove this from CI, but installing numpy-dispatch (and its build requirement Cython) appears to only add a few seconds of build time. 
* don't explicitly list cython * remove UnshapedArray from _JAX_ARRAY_TYPES * Remove incorrect note about metaclasses * remove unnecessary numpy_dispatch.ensure_dispatching()
2020-08-18 09:40:57 -07:00
# Python looks up special methods only on classes, not instances. This means
# these methods need to be defined explicitly rather than relying on
# __getattr__.
2018-11-17 18:03:33 -08:00
def __neg__(self): return self.aval._neg(self)
def __pos__(self): return self.aval._pos(self)
2018-11-17 18:03:33 -08:00
def __eq__(self, other): return self.aval._eq(self, other)
def __ne__(self, other): return self.aval._ne(self, other)
def __lt__(self, other): return self.aval._lt(self, other)
def __le__(self, other): return self.aval._le(self, other)
def __gt__(self, other): return self.aval._gt(self, other)
def __ge__(self, other): return self.aval._ge(self, other)
def __abs__(self): return self.aval._abs(self)
def __add__(self, other): return self.aval._add(self, other)
def __radd__(self, other): return self.aval._radd(self, other)
def __sub__(self, other): return self.aval._sub(self, other)
def __rsub__(self, other): return self.aval._rsub(self, other)
def __mul__(self, other): return self.aval._mul(self, other)
def __rmul__(self, other): return self.aval._rmul(self, other)
def __div__(self, other): return self.aval._div(self, other)
def __rdiv__(self, other): return self.aval._rdiv(self, other)
def __truediv__(self, other): return self.aval._truediv(self, other)
def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
2018-11-17 18:03:33 -08:00
def __floordiv__(self, other): return self.aval._floordiv(self, other)
def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
def __divmod__(self, other): return self.aval._divmod(self, other)
def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
def __mod__(self, other): return self.aval._mod(self, other)
def __rmod__(self, other): return self.aval._rmod(self, other)
def __pow__(self, other): return self.aval._pow(self, other)
def __rpow__(self, other): return self.aval._rpow(self, other)
def __matmul__(self, other): return self.aval._matmul(self, other)
def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
def __and__(self, other): return self.aval._and(self, other)
def __rand__(self, other): return self.aval._rand(self, other)
def __or__(self, other): return self.aval._or(self, other)
def __ror__(self, other): return self.aval._ror(self, other)
def __xor__(self, other): return self.aval._xor(self, other)
def __rxor__(self, other): return self.aval._rxor(self, other)
2019-02-15 14:09:06 -08:00
def __invert__(self): return self.aval._invert(self)
2018-11-17 18:03:33 -08:00
def __lshift__(self, other): return self.aval._lshift(self, other)
def __rlshift__(self, other): return self.aval._rlshift(self, other)
2018-11-17 18:03:33 -08:00
def __rshift__(self, other): return self.aval._rshift(self, other)
def __rrshift__(self, other): return self.aval._rrshift(self, other)
2018-11-17 18:03:33 -08:00
# --- indexing and Python-level conversions, all delegated to the aval.
def __getitem__(self, idx):
  return self.aval._getitem(self, idx)

def __nonzero__(self):
  return self.aval._nonzero(self)

def __bool__(self):
  return self.aval._bool(self)

def __int__(self):
  return self.aval._int(self)

def __long__(self):
  return self.aval._long(self)

def __hex__(self):
  return self.aval._hex(self)

def __oct__(self):
  return self.aval._oct(self)

def __float__(self):
  return self.aval._float(self)

def __complex__(self):
  return self.aval._complex(self)

def __setitem__(self, idx, val):
  # Tracers are immutable; always refuse in-place assignment.
  raise TypeError("JAX 'Tracer' objects do not support item assignment")

# NumPy also only looks up special methods on classes.
def __array_module__(self, types):
  return self.aval._array_module(self, types)
2018-11-17 18:03:33 -08:00
def __getattr__(self, name):
  """Forward attribute lookups that miss on the tracer to its aval.

  Attributes found on the aval that are `aval_property` instances are
  evaluated with this tracer as the receiver. `aval_method` instances are
  bound to this tracer. Anything else is returned unchanged.

  Raises:
    AttributeError: if the aval has no attribute `name`; the message names
      the tracer's class.
  """
  # if the aval property raises an AttributeError, gets caught here
  assert skip_checks or name != "aval"
  try:
    attr = getattr(self.aval, name)
  except AttributeError as err:
    # `getattr` reports a missing attribute with AttributeError (the former
    # `except KeyError` could never match), so re-raise naming the tracer.
    raise AttributeError(
        "{} has no attribute {}".format(self.__class__.__name__, name)
    ) from err
  else:
    t = type(attr)
    if t is aval_property:
      return attr.fget(self)
    elif t is aval_method:
      return types.MethodType(attr.fun, self)
    else:
      return attr
def __repr__(self):
  """Render as `Traced<aval>with<trace>`, followed by the `_contents` fields."""
  result = pp('Traced<{}>with<{}>'.format(self.aval, self._trace))
  fields = self._contents()
  if fields:
    field_lines = (pp('{} = '.format(field_name)) >> field_pp
                   for field_name, field_pp in fields)
    result += pp(' with ') >> vcat(field_lines)
  return str(result)
def _contents(self):
  """List `(slot_name, pretty-printed repr)` for every slot in `__slots__`.

  Returns an empty tuple if any slot is unset (raises AttributeError).
  """
  try:
    entries = [(slot, pp(repr(getattr(self, slot))))
               for slot in self.__slots__]
  except AttributeError:
    return ()
  return entries
2018-11-17 18:03:33 -08:00
# Copying a tracer (shallow or deep) yields the very same object.
def __copy__(self):
  same = self
  return same

def __deepcopy__(self, unused_memo):
  same = self
  return same

def _origin_msg(self) -> str:
  """Extra origin information for concretization errors; none by default."""
  return ""
2018-11-17 18:03:33 -08:00
# Wrappers an aval can expose so that `Tracer.__getattr__` forwards them:
# `aval_property.fget` is called with the tracer, and `aval_method.fun` is
# bound to the tracer as a method.
aval_property = namedtuple("aval_property", "fget")
aval_method = namedtuple("aval_method", "fun")
class EvalTrace(Trace):
  """Trace that evaluates primitives immediately via their `impl`.

  It is the trace type of the level-0 MainTrace built by `TraceStack`.
  """
  # See comments in https://github.com/google/jax/pull/3370

  def pure(self, x):
    return x

  lift = sublift = pure

  def process_primitive(self, primitive, tracers, params):
    return primitive.impl(*tracers, **params)

  def process_call(self, primitive, f, tracers, params):
    return primitive.impl(f, *tracers, **params)

  process_map = process_call

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    # The custom rule is irrelevant for plain evaluation.
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    # The custom rules are irrelevant for plain evaluation.
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)
class MainTrace:
  """One entry of the trace stack.

  Holds the stack level, the `Trace` subclass to instantiate, and keyword
  payload that is forwarded to that subclass's constructor.
  """
  level: int
  trace_type: Type[Trace]
  payload: Dict[str, Any]

  def __init__(self, level, trace_type, **payload) -> None:
    self.level, self.trace_type, self.payload = level, trace_type, payload

  def __repr__(self) -> str:
    return "MainTrace({},{})".format(self.level, self.trace_type.__name__)

  def __hash__(self) -> int:
    # The payload is not hashed (presumably because its values need not be
    # hashable). Equal MainTraces still hash equally because __eq__ also
    # compares level and trace_type.
    return hash((self.level, self.trace_type))

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, MainTrace):
      return False
    return (self.level == other.level
            and self.trace_type == other.trace_type
            and self.payload == other.payload)

  def with_cur_sublevel(self):
    """Instantiate this main's Trace at the thread's current sublevel."""
    return self.trace_type(self, cur_sublevel(), **self.payload)
2018-11-17 18:03:33 -08:00
class TraceStack:
  """Per-thread stack of MainTraces, plus the current dynamic main trace.

  `stack[0]` starts out as a level-0 `EvalTrace` main, and `new_main` pushes
  further entries. `find_top_trace` falls back to `dynamic` when no argument
  tracer has a higher level than it.
  """
  # See comments in https://github.com/google/jax/pull/3370
  # These annotations match the attributes actually set in __init__ (they
  # previously named nonexistent `upward`/`downward` fields).
  stack: List[MainTrace]
  dynamic: MainTrace

  def __init__(self):
    eval_trace = MainTrace(0, EvalTrace)
    self.stack = [eval_trace]
    self.dynamic = eval_trace

  def next_level(self) -> int:
    return len(self.stack)

  def push(self, main_trace: MainTrace) -> None:
    self.stack.append(main_trace)

  def pop(self) -> None:
    self.stack.pop()

  def __repr__(self) -> str:
    # Join the per-entry lines. Interpolating a bare `map` object would print
    # `<map object at 0x...>` instead of the stack contents.
    stack_str = ''.join(map(' {}\n'.format, self.stack[::-1]))
    return f'Trace stack\n{stack_str}\n{self.dynamic}'

  def copy(self):
    """Copy with its own stack list; the MainTrace entries are shared."""
    new = self.__new__(TraceStack)
    new.stack = self.stack[:]
    new.dynamic = self.dynamic
    return new
2018-11-17 18:03:33 -08:00
class Sublevel(int):
  """Distinct int subtype that marks a sublevel index."""


# One entry of the axis environment: an axis name, its size, and the
# associated main trace (or tag).
AxisEnvFrame = namedtuple('AxisEnvFrame', 'name size main_trace')
class TraceState:
  """All mutable tracing state for one thread: traces, sublevels, axis env."""
  trace_stack: TraceStack
  substack: List[Sublevel]
  axis_env: List[AxisEnvFrame]

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.axis_env = []

  def copy(self):
    """Return a TraceState whose stacks are shallow copies of this one's."""
    clone = self.__new__(TraceState)
    clone.trace_stack = self.trace_stack.copy()
    clone.substack = list(self.substack)
    clone.axis_env = list(self.axis_env)
    return clone
# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class ThreadLocalState(threading.local):
  """threading.local subclass: each thread gets its own fresh TraceState."""

  def __init__(self):
    self.trace_state = TraceState()


thread_local_state = ThreadLocalState()
def trace_state_clean() -> bool:
  """Return True if this thread's trace state equals a freshly built one."""
  state = thread_local_state.trace_state
  base = MainTrace(0, EvalTrace)
  return (state.substack == [Sublevel(0)]
          and state.axis_env == []
          and state.trace_stack.stack == [base]
          and state.trace_stack.dynamic == base)
def reset_trace_state() -> bool:
  """Reset the global trace state and return True if it was already clean."""
  if trace_state_clean():
    return True
  # Re-run the constructor in place so existing references see the reset.
  thread_local_state.trace_state.__init__()  # type: ignore
  return False
def cur_sublevel() -> Sublevel:
  """Return the innermost (most recently pushed) sublevel of this thread."""
  substack = thread_local_state.trace_state.substack
  return substack[-1]
2018-11-17 18:03:33 -08:00
@contextmanager
def new_main(trace_type: Type[Trace],
             dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  """Push a new MainTrace of `trace_type` for the duration of the context.

  The new main gets the stack's next level (`stack.next_level()`) and keeps
  `payload` for later instantiation via `with_cur_sublevel`. If `dynamic` is
  True, it also replaces the stack's dynamic trace until exit.

  On exit the entry is popped and the previous dynamic trace is restored.
  When `check_leaks` is set, an Exception is raised if the MainTrace is still
  referenced after this function drops its own reference.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  level = stack.next_level()
  main = MainTrace(level, trace_type, **payload)
  stack.push(main)
  if dynamic:
    prev_dynamic, stack.dynamic = stack.dynamic, main

  try:
    yield main
  finally:
    thread_local_state.trace_state.trace_stack.pop()
    if dynamic:
      stack.dynamic = prev_dynamic

  if check_leaks:
    # Drop the local reference; anything still keeping `main` alive leaked it.
    t = ref(main)
    del main
    if t() is not None:
      print(thread_local_state.trace_state.trace_stack)
      raise Exception('Leaked trace {}'.format(t()))
2020-09-15 08:06:46 -07:00
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
  """Temporarily install a level-0 MainTrace of `trace_type`.

  The new main replaces both `stack[0]` and the dynamic trace. Both are
  restored on exit.
  """
  # See comments in https://github.com/google/jax/pull/3370
  trace_stack = thread_local_state.trace_state.trace_stack
  base = MainTrace(0, trace_type)
  saved_dynamic = trace_stack.dynamic
  saved_base = trace_stack.stack[0]
  trace_stack.dynamic = base
  trace_stack.stack[0] = base
  try:
    yield base
  finally:
    trace_stack.dynamic = saved_dynamic
    trace_stack.stack[0] = saved_base
@contextmanager
def eval_context():
  """Run the body with an EvalTrace as the base and dynamic trace."""
  base_ctx = new_base_main(EvalTrace)
  with base_ctx:
    yield
2018-11-17 18:03:33 -08:00
@contextmanager
def new_sublevel() -> Generator[None, None, None]:
  """Push a new Sublevel onto this thread's substack for the context's duration.

  The sublevel's value is the substack depth at entry. When `check_leaks` is
  set, an Exception is raised on exit if the Sublevel object is still
  referenced.
  """
  sublevel = Sublevel(len(thread_local_state.trace_state.substack))
  thread_local_state.trace_state.substack.append(sublevel)
  try:
    yield
  finally:
    thread_local_state.trace_state.substack.pop()

  if check_leaks:
    # NOTE(review): `Sublevel` subclasses `int`, and CPython does not support
    # weak references to int subclasses, so `ref(sublevel)` likely raises
    # TypeError whenever `check_leaks` is enabled -- confirm.
    t = ref(sublevel)
    del sublevel
    if t() is not None:
      raise Exception('Leaked sublevel {}'.format(t()))
2020-09-15 08:06:46 -07:00
def maybe_new_sublevel(trace):
  """Return `new_sublevel()` if `trace` is the dynamic trace, else a no-op ctx."""
  # dynamic traces run the WrappedFun, so we raise the sublevel for them
  if trace.main is thread_local_state.trace_state.trace_stack.dynamic:
    return new_sublevel()
  return suppress()
def full_lower(val):
  """Call `val.full_lower()` for tracers; return other values unchanged."""
  return val.full_lower() if isinstance(val, Tracer) else val
2020-09-15 08:06:46 -07:00
def find_top_trace(xs) -> Trace:
  """Return a Trace, at the current sublevel, for the highest-level main.

  The candidates are the mains of the tracers in `xs` and the stack's
  dynamic main. The dynamic main wins when no tracer is present or when its
  level is strictly higher.
  """
  tracer_mains = [x._trace.main for x in xs if isinstance(x, Tracer)]
  top_main = max(tracer_mains, default=None, key=attrgetter('level'))
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  if top_main is None or dynamic.level > top_main.level:
    top_main = dynamic
  return top_main and top_main.with_cur_sublevel()  # type: ignore
2020-07-31 22:20:58 -07:00
2018-11-17 18:03:33 -08:00
# -------------------- abstract values --------------------
class AbstractValue:
  """Base class of the abstract values (types) carried by tracers."""
  __slots__: List[str] = []
  _num_buffers: int = 1  # number of buffers used to represent the value.

  def at_least_vspace(self):
    return self

  def __repr__(self):
    # Subclasses without __slots__ show their instance __dict__; otherwise
    # (AttributeError) fall back to the bare class name.
    try:
      fields = ','.join(f'{k}={v}' for k, v in self.__dict__.items())
    except AttributeError:
      return self.__class__.__name__
    return f'{self.__class__.__name__}({fields})'

  def strip_weak_type(self) -> 'AbstractValue':
    return self

  def join(self, other):
    raise NotImplementedError("must override")
2018-11-17 18:03:33 -08:00
class Bot(AbstractValue):
  pass

bot = Bot()


class AbstractUnit(AbstractValue):
  """Abstract value of `unit`; it only joins with `abstract_unit`."""
  # TODO(jakevdp): make it possible to set zero buffers
  # _num_buffers = 0

  def join(self, other):
    if not skip_checks:
      assert other is abstract_unit, other
    return self

  def _eq(self, self_traced, other):
    return get_aval(other) is self

  def str_short(self):
    return '*'

abstract_unit = AbstractUnit()
2018-11-17 18:03:33 -08:00
def lattice_join(x: Optional[AbstractValue],
                 y: Optional[AbstractValue]) -> AbstractValue:
  """Join two abstract values; a None operand acts as the identity.

  The join is delegated to the operand whose type the other operand is an
  instance of.

  Raises:
    TypeError: if neither operand is an instance of the other's type.
  """
  if x is None:
    return cast(AbstractValue, y)
  if y is None:
    return cast(AbstractValue, x)
  if isinstance(x, type(y)):
    return y.join(x)
  if isinstance(y, type(x)):
    return x.join(y)
  raise TypeError((x, y))
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any


def valid_jaxtype(x):
  """Return True iff `concrete_aval` accepts `x`."""
  try:
    concrete_aval(x)
  except TypeError:
    return False
  return True


def check_valid_jaxtype(x):
  """Raise TypeError unless `x` is a valid JAX type."""
  if valid_jaxtype(x):
    return
  raise TypeError(f"{x} of type {type(x)} is not a valid JAX type")


def concrete_aval(x):
  """Compute the aval of a non-tracer using `pytype_aval_mappings`.

  The MRO of `type(x)` is searched, so a subclass of a registered type uses
  its nearest registered ancestor's handler.
  """
  for klass in type(x).mro():
    handler = pytype_aval_mappings.get(klass)
    if handler:
      return handler(x)
  raise TypeError(f"{type(x)} is not a valid JAX type")


def get_aval(x):
  """Return a tracer's `.aval`, or `concrete_aval(x)` for other values."""
  return x.aval if isinstance(x, Tracer) else concrete_aval(x)

pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
2018-11-17 18:03:33 -08:00
class Unit:
  """The unit value type; `unit` below is the shared instance."""

  def __repr__(self):
    return '*'

unit = Unit()
literalable_types.add(Unit)


class UnitVar(Var):
  """Variable standing for the unit value; its aval is `abstract_unit`."""
  count = -1
  suffix = ''

  def __init__(self):
    pass

  @property
  def aval(self):
    return abstract_unit

  def __repr__(self):
    return '*'

unitvar = UnitVar()

pytype_aval_mappings[Unit] = lambda _: abstract_unit
2018-11-17 18:03:33 -08:00
class ConcretizationTypeError(TypeError):
  pass


def raise_concretization_error(val: Tracer, context=""):
  """Raise ConcretizationTypeError explaining that `val` is not concrete.

  The message includes `context`, the tracer's `_origin_msg()`, an FAQ link,
  and the tracer itself.
  """
  msg = "Abstract tracer value encountered where concrete value is expected.\n\n"
  msg += context + "\n\n"
  msg += val._origin_msg() + "\n\n"
  msg += ("See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
          f"Encountered tracer value: {val}")
  raise ConcretizationTypeError(msg)
2020-09-15 08:06:46 -07:00
def concretization_function_error(fun, suggest_astype=False):
  """Build an aval method `error(self, arg)` that rejects converting `arg`.

  The returned function ignores `self` and calls
  `raise_concretization_error(arg, ...)` with a message naming `fun`. With
  `suggest_astype`, the message also suggests `x.astype` / `jnp.array` for
  dtype conversion.
  """
  fname = getattr(fun, "__name__", fun)
  fname_context = f"The problem arose with the `{fname}` function. "
  if suggest_astype:
    # Use `fname` rather than `fun.__name__`, so that callables without
    # `__name__` don't raise AttributeError while the helper is built.
    fname_context += ("If trying to convert the data type of a value, "
                      f"try using `x.astype({fname})` "
                      f"or `jnp.array(x, {fname})` instead.")

  def error(self, arg):
    raise_concretization_error(arg, fname_context)
  return error
def concrete_or_error(force: Any, val: Any, context=""):
  """Like force(val), but gives the context in the error message.

  `force=None` means identity. A tracer is accepted only if its aval is a
  ConcreteArray, in which case `force` is applied to the concrete value.
  """
  if force is None:
    force = lambda x: x
  if not isinstance(val, Tracer):
    return force(val)
  if isinstance(val.aval, ConcreteArray):
    return force(val.aval.val)
  raise_concretization_error(val, context)
2020-03-09 09:14:23 +00:00
class UnshapedArray(AbstractValue):
  """Abstract array with a (canonicalized) dtype and weak_type flag only."""
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 2

  def __init__(self, dtype, weak_type=False):
    self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
    self.weak_type = weak_type

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return self.dtype == other.dtype and self.weak_type == other.weak_type

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    # hash(self.dtype) relies on numpy reusing base dtype objects, e.g.
    # `np.zeros(3).dtype is np.zeros(4).dtype`; hash(self.dtype.char) (the
    # unique character code) would be an alternative.
    return hash((self.dtype, self.weak_type))

  def __repr__(self):
    suffix = ", weak_type=True" if self.weak_type else ""
    return '{}({}{})'.format(self.__class__.__name__, self.str_short(), suffix)

  # Python-level conversions need a concrete value, so they all raise.
  _bool = _nonzero = concretization_function_error(bool)
  _float = concretization_function_error(float, True)
  _int = concretization_function_error(int, True)
  _complex = concretization_function_error(complex, True)
  _hex = concretization_function_error(hex)
  _oct = concretization_function_error(oct)

  def at_least_vspace(self) -> AbstractValue:
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return UnshapedArray(tangent_dtype, self.weak_type)

  def join(self, other):
    """Join with an aval of the same dtype; differing weak_type -> False."""
    if self.dtype != other.dtype:
      raise TypeError(self, other)
    if self.weak_type == other.weak_type:
      return self
    return UnshapedArray(self.dtype, weak_type=False)

  def str_short(self) -> str:
    return self.dtype.name

  def strip_weak_type(self) -> 'UnshapedArray':
    """Returns a copy of the aval with weak_type=False."""
    if self.weak_type:
      return UnshapedArray(self.dtype)
    return self

  @property
  def shape(self):
    msg = ("UnshapedArray has no shape. Please open an issue at "
           "https://github.com/google/jax/issues because it's unexpected for "
           "UnshapedArray instances to ever be produced.")
    raise TypeError(msg)
2020-03-09 09:14:23 +00:00
class ShapedArray(UnshapedArray):
  """Abstract array with a dtype, weak_type flag and canonicalized shape."""
  __slots__ = ['shape']
  array_abstraction_level = 1

  def __init__(self, shape, dtype, weak_type=False):
    super().__init__(dtype, weak_type=weak_type)
    self.shape = canonicalize_shape(shape)

  ndim = property(lambda self: len(self.shape))
  size = property(lambda self: prod(self.shape))

  # Placeholders; presumably populated with aval methods elsewhere.
  broadcast: ClassVar[Optional[aval_method]] = None
  transpose: ClassVar[Optional[aval_method]] = None
  reshape: ClassVar[Optional[aval_method]] = None
  _iter: ClassVar[Optional[staticmethod]] = None

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    return (self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    # hash(self.dtype) relies on numpy reusing base dtype objects, e.g.
    # `np.zeros(3).dtype is np.zeros(4).dtype`; hash(self.dtype.char) (the
    # unique character code) would be an alternative.
    return hash((self.shape, self.dtype, self.weak_type))

  def at_least_vspace(self):
    tangent_dtype = primal_dtype_to_tangent_dtype(self.dtype)
    return ShapedArray(self.shape, tangent_dtype, self.weak_type)

  def join(self, other):
    """Same shape+dtype -> ShapedArray; same dtype only -> UnshapedArray."""
    if self.shape == other.shape and self.dtype == other.dtype:
      if self.weak_type == other.weak_type:
        return self
      return ShapedArray(self.shape, self.dtype, weak_type=False)
    if self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    raise TypeError(self, other)

  def str_short(self):
    dims = ','.join(str(d) for d in self.shape)
    return f'{self.dtype.name}[{dims}]'

  def __len__(self):
    try:
      return self.shape[0]
    except IndexError as err:
      raise TypeError("len() of unsized object") from err  # same as numpy error

  def _len(self, ignored_tracer):
    return len(self)

  def strip_weak_type(self):
    if not self.weak_type:
      return self
    return ShapedArray(self.shape, self.dtype)
def _forward_to_value(self, fun, ignored_tracer, *args):
  # Used through `partialmethod` as an aval method of ConcreteArray: applies
  # `fun` to the stored concrete value, ignoring the tracer argument.
  return fun(self.val, *args)

class ConcreteArray(ShapedArray):
  """A ShapedArray that also carries the concrete value `val`."""
  __slots__ = ['val']
  array_abstraction_level = 0

  def __init__(self, val, weak_type=False):
    super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val),
                                        weak_type=weak_type)
    # Note: canonicalized self.dtype doesn't necessarily match self.val
    self.val = val
    assert self.dtype != np.dtype('O'), val

  def __eq__(self, other):
    # Same type/dtype/shape/weak_type, and elementwise-equal values.
    if (type(self) is type(other) and self.dtype == other.dtype
        and self.shape == other.shape and self.weak_type == other.weak_type):
      with eval_context():  # in case self.val is a DeviceArray
        return (self.val == other.val).all()
    else:
      return False

  def __hash__(self):
    # NOTE(review): hashes by identity of `val`, so two ConcreteArrays that
    # compare equal via __eq__ but hold distinct `val` objects hash
    # differently. Confirm that no caller relies on hash/eq consistency.
    return id(self.val)

  def at_least_vspace(self):
    # Tangent space drops the concrete value: the result is a plain ShapedArray.
    return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                       weak_type=self.weak_type)

  def join(self, other) -> UnshapedArray:
    # Equal -> self; same shape+dtype -> ShapedArray; same dtype only ->
    # UnshapedArray. The result is weak only if both operands are weak.
    if self == other:
      return self
    elif self.shape == other.shape and self.dtype == other.dtype:
      return ShapedArray(self.shape, self.dtype,
                         weak_type=self.weak_type and other.weak_type)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype,
                           weak_type=self.weak_type and other.weak_type)
    else:
      raise TypeError(self, other)

  def str_short(self) -> str:
    return str(self.val)

  def strip_weak_type(self) -> 'ConcreteArray':
    return ConcreteArray(self.val) if self.weak_type else self

  # bool/int/hex/oct conversions use the concrete value. float and complex
  # are still concretization errors here; the reason is not visible in this
  # chunk.
  _bool = _nonzero = partialmethod(_forward_to_value, bool)
  _int = partialmethod(_forward_to_value, int)
  _hex = partialmethod(_forward_to_value, hex)
  _oct = partialmethod(_forward_to_value, oct)
  _float = concretization_function_error(float, True)
  _complex = concretization_function_error(complex, True)
2020-03-09 09:14:23 +00:00
def primal_dtype_to_tangent_dtype(primal_dtype):
  """Tangent dtype for `primal_dtype`.

  Inexact (float/complex) dtypes map to themselves; all others map to
  `float0`.
  """
  if dtypes.issubdtype(primal_dtype, np.inexact):
    return primal_dtype
  return dtypes.float0
2020-03-09 09:14:23 +00:00
class AbstractToken(AbstractValue):
  """Abstract value of a token; it only joins with other AbstractTokens."""

  def join(self, other):
    if isinstance(other, AbstractToken):
      return self
    # Raise instead of `assert False`: asserts are stripped under `python -O`,
    # which would make this silently return None. TypeError matches the
    # other `join` implementations and `lattice_join`.
    raise TypeError(f"Cannot join {self} with {other}")

  def str_short(self):
    return 'Tok'

abstract_token = AbstractToken()
def raise_to_shaped(aval: AbstractValue, weak_type=None):
  """Return the shaped-level abstraction of `aval`.

  `weak_type` defaults to the aval's own `weak_type` (False if it has none).
  Handlers come from `raise_to_shaped_mappings`, searched along the MRO of
  `type(aval)`; a ConcreteArray therefore uses the ShapedArray handler.

  Raises:
    TypeError: if no handler is registered for the aval's type.
  """
  if weak_type is None:
    weak_type = getattr(aval, 'weak_type', False)
  for klass in type(aval).mro():
    handler = raise_to_shaped_mappings.get(klass)
    if handler:
      return handler(aval, weak_type)
  raise TypeError(type(aval))

raise_to_shaped_mappings : Dict[type, Callable] = {
    AbstractUnit: lambda aval, _: aval,
    AbstractToken: lambda aval, _: aval,
    ShapedArray: (lambda aval, weak_type:
                  ShapedArray(aval.shape, aval.dtype, weak_type=weak_type)),
}
2020-03-09 09:14:23 +00:00
# Registry for valid dimension types. This is used by masking.Poly.
_DIMENSION_TYPES: Set[type] = {int}


def _canonicalize_dimension(dim):
  """Return `dim` as-is if its exact type is registered, else operator.index(dim)."""
  if type(dim) in _DIMENSION_TYPES:
    return dim
  else:
    return operator.index(dim)


def canonicalize_shape(shape):
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.
  Returns:
    A tuple of integers.
  Raises:
    TypeError: if `shape` is not an iterable of integer-like dimensions. The
      message adds a `jit`/`static_argnums` hint when a dimension is a
      non-concrete tracer.
  """
  try:
    return tuple(map(_canonicalize_dimension, shape))
  except TypeError:
    pass
  msg = ("Shapes must be 1D sequences of concrete values of integer type, "
         "got {}.")
  # `shape` may not be iterable at all (e.g. a bare int). Scanning it for
  # tracers would then raise a bare "object is not iterable" TypeError
  # instead of the descriptive message below.
  try:
    dims = list(shape)
  except TypeError:
    dims = []
  if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
         and not isinstance(get_aval(x), ConcreteArray) for x in dims):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  raise TypeError(msg.format(shape))
2018-11-17 18:03:33 -08:00
# ------------------- Call -------------------
2018-11-17 18:03:33 -08:00
def apply_todos(todos, outs):
  """Apply the callbacks in `todos` to `outs`, last one first.

  Each callback's results are passed through `full_lower`.
  """
  for todo in reversed(list(todos)):
    outs = map(full_lower, todo(outs))
  return outs
2018-11-17 18:03:33 -08:00
@lu.transformation_with_aux
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
                       level: int, params_tuple: tuple, *args):
  """Post-process outputs that belong to traces above `level`.

  After the wrapped function runs on `args`, the loop repeats until no such
  tracer remains. Each round picks the output tracer with the highest trace
  level above `level` (any level when `level` is None). It raises all
  outputs into that trace and calls `primitive.post_process`. The callbacks
  returned by `post_process` are collected, in order, as the aux output.
  """
  # NOTE(review): `level` is annotated int, but `call_bind` can pass None
  # (`top_trace and top_trace.level`), which the `level is None` test handles.
  outs = yield args, {}
  params = dict(params_tuple)
  todo = []
  while True:
    tracers = [x for x in outs if isinstance(x, Tracer)
               and (level is None or x._trace.level > level)]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo = primitive.post_process(trace, outs, params)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable
2018-11-17 18:03:33 -08:00
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
              fun, *args, **params):
  """Shared `bind` implementation for call and map primitives.

  Finds the top trace among `args` (see `find_top_trace`) and raises the args
  into it. That trace then processes the primitive, inside a new sublevel if
  it is the dynamic trace (`maybe_new_sublevel`).

  `fun` is wrapped with `process_env_traces`, so outputs from traces above
  the top trace get post-processed. The callbacks that step collects are
  applied with `apply_todos`, and the results are lowered with `full_lower`.
  """
  params_tuple = tuple(params.items())
  top_trace = find_top_trace(args)
  fun, env_trace_todo = process_env_traces(
      fun, primitive, top_trace and top_trace.level, params_tuple)
  tracers = map(top_trace.full_raise, args)
  with maybe_new_sublevel(top_trace):
    outs = primitive.process(top_trace, fun, tracers, params)
  return map(full_lower, apply_todos(env_trace_todo(), outs))
2018-11-17 18:03:33 -08:00
class CallPrimitive(Primitive):
  """Primitive whose bind takes a function to call plus its arguments."""
  multiple_results = True
  call_primitive = True

  def bind(self, fun, *args, **params):
    """Bind via the shared `call_bind` machinery."""
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    """Dispatch to `trace.process_call`."""
    return trace.process_call(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    """Dispatch to `trace.post_process_call`."""
    return trace.post_process_call(self, out_tracers, params)
2018-11-17 18:03:33 -08:00
def call_impl(f: lu.WrappedFun, *args, **params):
  """Default impl of `call_p`: call `f` on `args` and ignore `params`."""
  del params  # params parameterize the call primitive, not the function
  return f.call_wrapped(*args)


call_p = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)
# ------------------- Map -------------------
class MapPrimitive(Primitive):
  """Call-like primitive that maps a function over some of its arguments."""
  multiple_results = True
  map_primitive = True

  def bind(self, fun, *args, **params):
    """Bind via `call_bind`; needs one `mapped_invars` entry per argument."""
    assert len(params['mapped_invars']) == len(args)
    return call_bind(self, fun, *args, **params)

  def process(self, trace, fun, tracers, params):
    """Dispatch to `trace.process_map`."""
    return trace.process_map(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    """Dispatch to `trace.post_process_map`."""
    return trace.post_process_map(self, out_tracers, params)
2018-11-17 18:03:33 -08:00
@contextmanager
def extend_axis_env(axis_name, size: int, tag: Any):
  """Push an AxisEnvFrame(axis_name, size, tag) for the context's duration."""
  new_frame = AxisEnvFrame(axis_name, size, tag)
  thread_local_state.trace_state.axis_env.append(new_frame)
  try:
    yield
  finally:
    thread_local_state.trace_state.axis_env.pop()
# When a mapped function is given no axis name, we generate a name object based
# on the id of the function object. Collisions aren't important because this
# name can't be used in collectives, as user code never gets a ref to this
# object. We don't want to use the function object itself because that might
# persist references to the function object.
# TODO(mattjj): revisit this unique axis name strategy
class _TempAxisName:
  """Anonymous axis name keyed by the id() of some object."""

  def __init__(self, obj):
    self.id = id(obj)

  def __repr__(self):
    return f'<axis {hex(self.id)}>'

  def __hash__(self):
    return hash(self.id)

  def __eq__(self, other):
    if type(other) is not _TempAxisName:
      return False
    return self.id == other.id
2020-09-15 08:06:46 -07:00
def axis_frame(axis_name):
  """Return the innermost AxisEnvFrame in this thread bound to `axis_name`.

  Raises:
    NameError: if no frame has that name. The message lists the bound names
      that are not `_TempAxisName` placeholders.
  """
  frames = thread_local_state.trace_state.axis_env
  for frame in reversed(frames):
    if frame.name == axis_name:
      return frame
  named_axis = [
      frame.name
      for frame in reversed(frames)
      if not isinstance(frame.name, _TempAxisName)
  ]
  # Note the space after the colon; previously the list was glued to it.
  raise NameError(
      f'unbound axis name: {axis_name}. The following axis names (e.g. defined '
      'by pmap) are available to collective operations: '
      f'{named_axis}')
2020-09-15 08:06:46 -07:00
2020-04-15 11:05:32 -07:00
# ------------------- Jaxpr checking -------------------
def mapped_aval(size: int, aval: AbstractValue) -> AbstractValue:
  """Drop the leading axis (which must have length `size`) from `aval`."""
  if aval is abstract_unit:
    return aval
  if not isinstance(aval, ShapedArray):
    raise TypeError(f"Mapped operand {aval}")
  # might be raising abstraction level from Concrete here
  assert aval.shape[0] == size
  return ShapedArray(aval.shape[1:], aval.dtype)


def unmapped_aval(size: int, aval: AbstractValue) -> AbstractValue:
  """Prepend a leading axis of length `size` to `aval`."""
  if aval is abstract_unit:
    return aval
  if not isinstance(aval, ShapedArray):
    raise TypeError(f"Mapped output {aval}")
  return ShapedArray((size,) + aval.shape, aval.dtype)
def typecheck(aval: AbstractValue, x) -> bool:
  """Check that the value `x` conforms to `aval`."""
  return typecompat(aval, get_aval(x))


def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
  """Determine whether `aval` conforms to `aval_ref`.

  Both sides are compared at the shaped level with weak_type stripped, and
  any TypeError from joining them means "incompatible".
  """
  shaped_ref = raise_to_shaped(aval_ref).strip_weak_type()
  try:
    joined = lattice_join(shaped_ref, aval).strip_weak_type()
    return shaped_ref == joined
  except TypeError:
    return False


def typematch(aval1: UnshapedArray, aval2: UnshapedArray) -> bool:
  """Compare two avals raised to shaped form with weak_type forced False."""
  lhs = raise_to_shaped(aval1, weak_type=False)
  rhs = raise_to_shaped(aval2, weak_type=False)
  return lhs == rhs
2020-06-24 15:31:33 -07:00
class JaxprTypeError(TypeError):
  """Raised when a jaxpr fails type checking."""


def typecheck_assert(pred, msg):
  """Raise JaxprTypeError(msg) when `pred` is falsy."""
  if pred:
    return
  raise JaxprTypeError(msg)
# Extra per-primitive checks used by `_check_jaxpr`: each one is called with
# an equation's input avals and params.
custom_typechecks: Dict[Primitive, Callable] = dict()
def check_jaxpr(jaxpr: Jaxpr):
  """Checks well-formedness of a jaxpr.

  Specifically, check that:
  - variables that are read are bound beforehand
  - variables are typed equally throughout a jaxpr
  - variable type annotations are compatible with their binding expression

  Raises `JaxprTypeError` (a `TypeError` subclass) if `jaxpr` is determined
  invalid; the message includes a printed excerpt of the jaxpr. Returns
  `None` otherwise.
  """
  try:
    _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
  except JaxprTypeError as e:
    # `_check_jaxpr` adds the failing equation index as a second arg when the
    # error comes from an equation; print the equations around that index.
    # NOTE(review): `eqn_idx - 10` can be negative; presumably
    # `pp_jaxpr_eqn_range` clamps it -- confirm.
    if len(e.args) == 2:
      msg, eqn_idx = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqn_idx - 10, eqn_idx + 10))
    else:
      msg, = e.args
      jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20))
    msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
    raise JaxprTypeError(msg) from None
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
  """Type-check `jaxpr` given the avals of its inputs.

  Walks the equations in order, keeping an environment from Var to aval, and
  raises JaxprTypeError at the first violation. Errors raised while checking
  an equation carry the equation index as a second argument.
  """
  # NOTE(review): the `map(write, ...)` / `map(read, ...)` calls below rely on
  # `map` being eager (presumably `safe_map` rebound at module top, which is
  # not visible in this chunk). With the builtin lazy `map` they would never
  # run, and `all(...)` would exhaust `in_avals` before it is used.

  def read(v: Atom) -> AbstractValue:
    # Literals carry their value; raise it to shaped so it types like a var.
    if isinstance(v, Literal):
      return raise_to_shaped(get_aval(v.val))
    else:
      typecheck_assert(v in env, f"Variable '{v}' not defined")
      return env[v]

  def write(v: Var, a: AbstractValue) -> None:
    typecheck_assert(v not in env, f"Variable '{v}' already bound")
    # `dropvar` is neither type-checked nor recorded, so several equations
    # may bind it.
    if v is not dropvar:
      typecheck_assert(typecompat(v.aval, a),
                       f"Variable '{v}' inconsistently typed as {a}, "
                       f"bound as {v.aval}")
      env[v] = a

  env : Dict[Var, AbstractValue] = {}

  write(unitvar, abstract_unit)
  map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
  map(write, jaxpr.invars, in_avals)

  for eqn_idx, eqn in enumerate(jaxpr.eqns):
    prim = eqn.primitive
    try:
      in_avals = map(read, eqn.invars)
      typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
                       "Equation given ConcreteArray type inputs")
      if prim in custom_typechecks:
        custom_typechecks[prim](*in_avals, **eqn.params)
      # Call and map primitives type-check their sub-jaxpr; others use
      # abstract evaluation.
      if prim.call_primitive:
        out_avals = check_call(prim, in_avals, eqn.params)
      elif prim.map_primitive:
        out_avals = check_map(prim, in_avals, eqn.params)
      else:
        out_avals = check_eqn(prim, in_avals, eqn.params)
      map(write, eqn.outvars, out_avals)
    except JaxprTypeError as e:
      # Add the equation and its source location, plus the index for context.
      msg, = e.args
      src = source_info_util.summarize(eqn.source_info)
      msg = "\n\n".join([msg, "in equation:", str(pp_eqn(eqn).indent(2)),
                         f"from source: {src}"])
      raise JaxprTypeError(msg, eqn_idx) from None

  map(read, jaxpr.outvars)
def check_eqn(prim, in_avals, params):
  """Type-check an ordinary equation and return its list of output avals.

  Any jaxprs found in `params` are checked first.
  """
  for sub_jaxpr in jaxprs_in_params(params):
    check_jaxpr(sub_jaxpr)
  result = prim.abstract_eval(*in_avals, **params)
  return result if prim.multiple_results else [result]
2018-11-17 18:03:33 -08:00
def check_call(prim, in_avals, params):
  """Type-check a call primitive's equation and return its output avals.

  Requires a `call_jaxpr` param whose inputs match `in_avals` in number and
  type; the jaxpr is then checked recursively.
  """
  typecheck_assert("call_jaxpr" in params,
                   f"Call primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]

  # These checks also happen in recursive call, but give better errors here.
  # Report the operand count, not the jaxpr's input count, as the first number.
  typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
                   f"Call primitive {prim} with {len(in_avals)} "
                   f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
                   f"inputs")
  binder_avals = [v.aval for v in call_jaxpr.invars]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Call primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  _check_jaxpr(call_jaxpr, in_avals)

  out_avals = [v.aval for v in call_jaxpr.outvars]
  return out_avals
def check_map(prim, in_avals, params):
  """Type-check a map primitive's equation and return its output avals.

  Requires `call_jaxpr`, `axis_size` and `mapped_invars` params. Mapped
  operands must have a leading axis of length `axis_size`, which is removed
  before the jaxpr is checked and re-added to its outputs.
  """
  typecheck_assert("call_jaxpr" in params,
                   f"Map primitive {prim} missing 'call_jaxpr' parameter")
  call_jaxpr = params["call_jaxpr"]
  typecheck_assert("axis_size" in params,
                   f"Map primitive {prim} missing 'axis_size' parameter")
  axis_size = params["axis_size"]
  typecheck_assert("mapped_invars" in params,
                   f"Map primitive {prim} missing 'mapped_invars' parameter")
  mapped_invars = params["mapped_invars"]

  # Check arities up front, so a mismatch becomes a JaxprTypeError instead of
  # being truncated or failing opaquely in the zips below.
  typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
                   f"Map primitive {prim} with {len(in_avals)} operands "
                   f"cannot call jaxpr with {len(call_jaxpr.invars)} inputs")
  typecheck_assert(len(mapped_invars) == len(in_avals),
                   f"Map primitive {prim} has {len(mapped_invars)} "
                   f"'mapped_invars' entries for {len(in_avals)} operands")

  binder_avals = [unmapped_aval(axis_size, v.aval) if mapped else v.aval
                  for v, mapped in zip(call_jaxpr.invars, mapped_invars)]
  for binder_aval, in_aval in zip(binder_avals, in_avals):
    typecheck_assert(typecompat(binder_aval, in_aval),
                     f"Map primitive {prim} passes operand {in_aval} "
                     f"to jaxpr expecting {binder_aval}")

  mapped_avals = [mapped_aval(axis_size, aval) if mapped else aval
                  for aval, mapped in zip(in_avals, mapped_invars)]
  _check_jaxpr(call_jaxpr, mapped_avals)

  mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
  out_avals = [unmapped_aval(axis_size, aval) for aval in mapped_out_avals]
  return out_avals
2020-04-15 11:05:32 -07:00
# ------------------- Jaxpr printed representation -------------------
2018-11-17 18:03:33 -08:00
def pp_vars(vs: Sequence[Any], print_shapes: bool = False) -> str:
  """Join the string forms of ``vs`` with single spaces.

  With ``print_shapes``, each variable is rendered as ``name:aval`` using the
  short string form of its abstract value.
  """
  if not print_shapes:
    return ' '.join(str(var) for var in vs)
  return ' '.join(f'{var}:{var.aval.str_short()}' for var in vs)
2018-11-17 18:03:33 -08:00
def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
  """Pretty-print a primitive name followed by its simple parameters.

  The ``branches`` parameter and any parameter whose value is a ``Jaxpr`` or
  ``ClosedJaxpr`` are left out; the rest are printed sorted by key.
  """
  def _keep(key, value):
    return key != 'branches' and not isinstance(value, (Jaxpr, ClosedJaxpr))

  kept = sorted((key, value) for key, value in params.items()
                if _keep(key, value))
  return pp(primitive_name) >> pp_kv_pairs(kept)
def pp_eqn(eqn: JaxprEqn, print_shapes: bool = False) -> PrettyPrint:
  """Pretty-print one equation as ``outvars = prim[params] invars``.

  When shapes are printed, or the rendered output variables are at most 6
  characters long, the equation is joined horizontally with a space.
  Otherwise the right-hand side is stacked under ``outvars =`` with an
  indent of 2.
  """
  outs_text = pp_vars(eqn.outvars, print_shapes)
  head = pp(f'{outs_text} =')
  rhs = pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
  rhs = rhs >> pp(' ') >> pp(pp_vars(eqn.invars, print_shapes))
  fits_inline = print_shapes or len(outs_text) <= 6
  if fits_inline:
    return head >> pp(' ') >> rhs
  return head + rhs.indent(2)
def pp_eqns(eqns: Sequence[JaxprEqn],
            source_info: bool = False) -> Sequence[PrettyPrint]:
  """Pretty-print each equation in ``eqns``.

  With ``source_info``, every printed equation is annotated with a summary
  of its ``source_info``, aligned at the column just past the widest
  printed line (indent plus text length).
  """
  printed = [pp_eqn(eqn) for eqn in eqns]
  if not source_info:
    return printed
  widths = [indent + len(text) for doc in printed for indent, text in doc.lines]
  if not widths:
    return printed
  column = max(widths)
  return [doc.annotate(column, source_info_util.summarize(eqn.source_info))
          for eqn, doc in zip(eqns, printed)]
def pp_jaxpr(jaxpr: Jaxpr, source_info: bool = False) -> PrettyPrint:
  """Pretty-print a jaxpr as ``{ lambda consts ; ins. let eqns in outs }``."""
  eqn_docs = pp_eqns(jaxpr.eqns, source_info=source_info)
  outvars_text = str(tuple(jaxpr.outvars))
  header = pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                         pp_vars(jaxpr.invars)))
  let_part = pp('let ') >> vcat(eqn_docs)
  in_part = pp('in {} }}'.format(outvars_text))
  return header + (let_part + in_part).indent(2)
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int,
                       source_info: bool = False) -> PrettyPrint:
  """Pretty-print ``jaxpr`` showing only the equations in ``[lo, hi)``.

  ``lo`` is clamped below at 0, and ``hi`` is clamped to the number of
  equations but never below ``lo``. Equations elided before or after the
  window are shown as ``...``. If the window is empty but the jaxpr has
  equations, a single ``...`` is printed.
  """
  num_eqns = len(jaxpr.eqns)
  lo = max(lo, 0)
  hi = max(lo, min(hi, num_eqns))
  window = jaxpr.eqns[lo:hi]
  if not window and num_eqns:
    eqn_docs = [pp('...')]
  else:
    eqn_docs = []
    if lo != 0:
      eqn_docs.append(pp('...'))
    eqn_docs += pp_eqns(window, source_info=source_info)
    if hi != num_eqns:
      eqn_docs.append(pp('...'))
  outvars_text = str(tuple(jaxpr.outvars))
  header = pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                         pp_vars(jaxpr.invars)))
  let_part = pp('let ') >> vcat(eqn_docs)
  in_part = pp('in {} }}'.format(outvars_text))
  return header + (let_part + in_part).indent(2)
def pp_jaxprs(jaxprs) -> PrettyPrint:
  """Pretty-print several jaxprs, open or closed, wrapped in ``( ... )``."""
  opened = []
  for item in jaxprs:
    opened.append(item.jaxpr if isinstance(item, ClosedJaxpr) else item)
  return pp('( ') >> vcat([pp_jaxpr(j) for j in opened]) >> pp(' )')
def pp_kv_pair(k, v):
  """Pretty-print one ``key=value`` parameter.

  A value that is exactly a tuple whose elements are all jaxprs (open or
  closed) is expanded with ``pp_jaxprs``; anything else goes through ``pp``.
  """
  is_jaxpr_tuple = (type(v) is tuple and
                    all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v))
  value_doc = pp_jaxprs(v) if is_jaxpr_tuple else pp(v)
  return pp(f'{k}=') >> value_doc
def pp_kv_pairs(kv_pairs):
  """Pretty-print ``(key, value)`` pairs as ``[ k=v ... ]``.

  An empty input yields ``pp('')`` so nothing is printed.
  """
  if not kv_pairs:
    return pp('')
  entries = vcat([pp_kv_pair(key, value) for key, value in kv_pairs])
  return pp('[ ') >> entries >> pp(' ]')
2020-09-15 08:06:46 -07:00
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Install the pre-omnistaging trace machinery.

  Registered via ``config.register_omnistaging_disabler`` (presumably called
  when omnistaging is turned off; see that registry). When run, it rebinds
  the module-level names listed in the ``global`` statement below, as well
  as ``Primitive.bind``, to the implementations defined here.
  """
  global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
      new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env, \
      eval_context

  class TraceStack:
    # Main traces pushed with bottom=False live in `upward` (levels 0, 1, ...);
    # those pushed with bottom=True live in `downward` (levels -1, -2, ...).
    upward: List[MainTrace]
    downward: List[MainTrace]

    def __init__(self):
      self.upward = []
      self.downward = []

    def next_level(self, bottom: bool) -> int:
      if bottom:
        return - (len(self.downward) + 1)
      else:
        return len(self.upward)

    def push(self, main_trace: MainTrace, bottom: bool) -> None:
      if bottom:
        self.downward.append(main_trace)
      else:
        self.upward.append(main_trace)

    def pop(self, bottom: bool) -> None:
      if bottom:
        self.downward.pop()
      else:
        self.upward.pop()

    def __repr__(self) -> str:
      # Join the per-trace lines: formatting the mapped sequences directly
      # would print a list repr (or a `map` object) rather than the lines.
      return 'Trace stack\n{} ---\n{}'.format(
          ''.join(' {}\n'.format(t) for t in self.upward[::-1]),
          ''.join(' {}\n'.format(t) for t in self.downward))

    def copy(self):
      new = TraceStack()
      new.upward = self.upward[:]
      new.downward = self.downward[:]
      return new

  class TraceState:
    trace_stack: TraceStack
    substack: List[Sublevel]
    initial_style: bool

    def __init__(self) -> None:
      self.trace_stack = TraceStack()  # type: ignore
      self.substack = [Sublevel(0)]
      self.initial_style = False

    def copy(self):
      new = TraceState()
      new.trace_stack = self.trace_stack.copy()
      new.substack = self.substack[:]
      new.initial_style = self.initial_style
      return new

  # Created after the classes above so it picks up the rebound definitions.
  thread_local_state = ThreadLocalState()

  def reset_trace_state() -> bool:
    "Reset the global trace state and return True if it was already clean."
    if (thread_local_state.trace_state.substack != [Sublevel(0)] or
        thread_local_state.trace_state.trace_stack.downward or
        thread_local_state.trace_state.trace_stack.upward):
      thread_local_state.trace_state.__init__()  # type: ignore
      return False
    else:
      return True

  @contextmanager
  def new_main(trace_type: Type[Trace], bottom=False,
               **payload) -> Generator[MainTrace, None, None]:
    level = thread_local_state.trace_state.trace_stack.next_level(bottom)
    main = MainTrace(level, trace_type, **payload)
    thread_local_state.trace_state.trace_stack.push(main, bottom)
    try:
      yield main
    finally:
      thread_local_state.trace_state.trace_stack.pop(bottom)
    if check_leaks:
      # Drop our strong reference; if the trace is still alive, something
      # else is holding on to it.
      t = ref(main)
      del main
      if t() is not None:
        print(thread_local_state.trace_state.trace_stack)
        raise Exception('Leaked trace {}'.format(t()))

  def find_top_trace(xs) -> Optional[Trace]:
    top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
                    key=attrgetter('level'), default=None)
    return top_trace and top_trace.main.with_cur_sublevel()

  @contextmanager
  def eval_context():
    yield  # dummy implementation for forward compatibility

  def bind(self, *args, **kwargs):
    assert skip_checks or all(isinstance(arg, Tracer)
                              or valid_jaxtype(arg) for arg in args), args
    top_trace = find_top_trace(args)
    if top_trace is None:
      return self.impl(*args, **kwargs)
    tracers = map(top_trace.full_raise, args)
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    if self.multiple_results:
      return map(full_lower, out_tracer)
    else:
      return full_lower(out_tracer)

  Primitive.bind = bind  # type: ignore

  def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
                fun: lu.WrappedFun, *args, **params):
    top_trace = find_top_trace(args)
    level = (thread_local_state.trace_state.trace_stack.next_level(True)
             if top_trace is None else top_trace.level)
    params_tuple = tuple(params.items())
    fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
    if top_trace is None:
      with new_sublevel():
        outs = primitive.impl(fun, *args, **params)
    else:
      tracers = map(top_trace.full_raise, args)
      outs = primitive.process(top_trace, fun, tracers, params)
    return apply_todos(env_trace_todo(), map(full_lower, outs))

  @contextmanager
  def extend_axis_env(axis_name, size: int, tag: Any):
    yield

  @contextmanager
  def initial_style_staging():
    # Temporarily set `initial_style`, restoring the previous value on exit.
    trace_state = thread_local_state.trace_state
    prev, trace_state.initial_style = trace_state.initial_style, True
    try:
      yield
    finally:
      trace_state.initial_style = prev
def zeros_like_float0(array, dtype=None):
  """Cast a float0 array to a float-valued zero array of the same shape.

  Args:
    array: any object with a ``shape`` attribute (typically a float0 array).
    dtype: dtype of the result; defaults to Python ``float`` (float64).

  Returns:
    ``np.zeros(array.shape, dtype)``.
  """
  # `is None` rather than truthiness: a non-structured np.dtype has len() == 0
  # and is therefore falsy, which would silently discard e.g. float32.
  # Builtin `float` replaces the removed `np.float` alias (same meaning).
  if dtype is None:
    dtype = float
  return np.zeros(array.shape, dtype)