# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import namedtuple
import contextlib
from dataclasses import dataclass
import functools
from functools import partial
import inspect
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
                    List, Union, Hashable, cast)
from weakref import ref

import numpy as np

from jax import core
from jax._src import dtypes
from jax import linear_util as lu
from jax._src import profiler
from jax._src.ad_util import Zero
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
                                tree_leaves)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
                           merge_lists, partition_list, OrderedSet,
                           as_hashable_function, weakref_lru_cache)
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
                      ClosedJaxpr, new_jaxpr_eqn, ConcreteArray,
                      raise_to_shaped, Var, DropVar, Atom, JaxprEqn, Primitive,
                      ShapedArray, DShapedArray, AbstractBInt, mapped_aval,
                      unmapped_aval)
from jax._src import source_info_util
from jax.config import config

# Shadow the builtins with length-checking variants; the raw builtins remain
# available as unsafe_map/unsafe_zip.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

def identity(x): return x

def _update_annotation(
    f: lu.WrappedFun,
    orig_type: Optional[Tuple[Tuple[AbstractValue, bool], ...]],
    in_knowns: List[bool]) -> lu.WrappedFun:
  if orig_type is None:
    return f
  return lu.annotate(f, tuple([ty for k, ty in zip(in_knowns, orig_type) if k]))
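
# A small usage sketch for _update_annotation (our illustration, not part of
# the original module): with a three-entry input type and
# in_knowns = [True, False, True], only the entries for the known inputs
# survive, matching the call that is evaluated now on just the known args:
#
#   orig_type = ((aval0, False), (aval1, False), (aval2, False))
#   in_knowns = [True, False, True]
#   f_known = _update_annotation(f, orig_type, in_knowns)
#   # f_known is annotated with the aval0 and aval2 entries only.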

class PartialVal(tuple):
  """Partial value: either a known value or an unknown (abstract) value.

  Represented as a pair `(aval_opt, const)` of one of two kinds:
  * `(None, <Constant>)` indicates a known value, either a regular Python
    value or a Tracer.
  * `(<AbstractValue>, *)` indicates an unknown value characterized by an
    abstract value.
  """
  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    pv, const = xs
    if config.jax_enable_checks:
      # type checks
      assert isinstance(pv, (AbstractValue, type(None))), xs
      assert (isinstance(const, core.Tracer) or type(const) is Zero
              or core.valid_jaxtype(const)), xs
      # invariant checks
      if isinstance(pv, AbstractValue):
        assert get_aval(const) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> PartialVal:
    return PartialVal((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> PartialVal:
    return PartialVal((aval, core.unit))

  def is_known(self) -> bool:
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Get the known value, if known, else None."""
    return self[1] if self[0] is None else None

  def get_aval(self) -> AbstractValue:
    """Get AbstractValue directly (if unknown) or from the constant (known)."""
    known = self.get_known()
    if known is not None:
      return get_aval(known)
    else:
      return self[0]
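
# Illustration of the two PartialVal kinds (a sketch, not executed at import):
#
#   pv_known = PartialVal.known(3.)            # -> (None, 3.)
#   pv_unknown = PartialVal.unknown(
#       ShapedArray((), np.dtype('float32')))  # -> (aval, core.unit)
#   pv_known.get_known()    # 3.
#   pv_unknown.get_known()  # None
#   pv_known.get_aval()     # concrete aval recovered from the constant 3.
#   pv_unknown.get_aval()   # the ShapedArray itself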


class JaxprTrace(Trace):

  def __init__(self, *args, name_stack: source_info_util.NameStack):
    super().__init__(*args)
    self.name_stack = name_stack

  def pure(self, val) -> JaxprTracer:
    return self.new_const(val)

  def lift(self, val) -> JaxprTracer:
    return self.new_const(val)

  def sublift(self, val) -> JaxprTracer:
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val) -> JaxprTracer:
    if isinstance(val, Tracer) and val._trace.level == self.level:
      raise Exception
    return JaxprTracer(self, PartialVal.known(val), core.unit)

  def new_instantiated_literal(self, val) -> JaxprTracer:
    aval = get_aval(val)
    return JaxprTracer(self, PartialVal.unknown(aval),
                       Literal(val, raise_to_shaped(aval)))

  def new_instantiated_const(self, val) -> JaxprTracer:
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))

  def new_arg(self, pval: PartialVal) -> JaxprTracer:
    const = pval.get_known()
    # XXX: Think twice before changing this constant argument pruning!
    # This has really important consequences for partial_eval_jaxpr.
    # Most importantly, this guarantees that the unknown jaxpr never uses
    # known inputs (if it needs them, then they get passed through residuals).
    if const is None:
      return JaxprTracer(self, pval, LambdaBinding())
    else:
      return self.new_const(const)
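
  # Consequence of the pruning in new_arg, sketched (our illustration): when
  # tracing lambda x, y: x + y with x known and y unknown, x never becomes a
  # lambda binder of the staged-out jaxpr. Any known value the unknown side
  # needs is routed through residual outputs of the known computation instead,
  # so the unknown jaxpr's inputs are residuals plus unknown args only.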

  def instantiate_const(self, tracer) -> Tracer:
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      if type(const) in core.literalable_types and np.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)

  def instantiate_const_abstracted(self, tracer) -> JaxprTracer:
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))
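
  # Staging choice made by instantiate_const, illustrated (our sketch): a
  # known rank-0 value like 1.0 is inlined as a Literal in the staged
  # equation, while a known array such as np.arange(3.) is routed through a
  # ConstVar, since only scalar literalable types may be embedded directly.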

  def process_primitive(self, primitive, tracers, params):
    if primitive in custom_partial_eval_rules:
      return custom_partial_eval_rules[primitive](self, *tracers, **params)
    else:
      return self.default_process_primitive(primitive, tracers, params)

  def default_process_primitive(self, primitive, tracers, params):
    # By default, if all the input tracers are known, then bind the primitive
    # and consider all outputs known. Otherwise, stage the application into the
    # jaxpr and consider all outputs unknown.
    consts = [t.pval.get_known() for t in tracers]
    if all(c is not None for c in consts):
      return primitive.bind(*consts, **params)
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval, effects = primitive.abstract_eval(*avals, **params)
    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in out_aval]
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects,
                           source)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    else:
      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
                                         params, effects, source)
      return out_tracer
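
  # Sketch of default_process_primitive on `mul` (illustrative only):
  #   both inputs known -> mul.bind(2., 3.) runs now and returns 6.
  #   any input unknown -> known operands are instantiated as constants,
  #                        abstract_eval computes the output aval, and a
  #                        `mul` eqn is staged out with unknown outputs.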

  def process_call(self, primitive, f, tracers, params):
    if primitive in call_partial_eval_rules:
      return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

    update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
    in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers])

    # We want to partially evaluate this call into two calls: one evaluated now
    # taking known values (in_consts) as inputs and producing known values
    # (out_consts) as outputs, and the other staged out as an eqn into the
    # jaxpr being built. The latter takes as input residuals (res) produced as
    # outputs of the first call, shared closed-over values (env), and explicit
    # arguments which were unknown to the first call (corresponding to
    # in_avals).

    # Wrap f to perform the partial evaluation and plumb out aux data.
    f_ = trace_to_subjaxpr_nounits(f, self.main, False)
    f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals))
    # Adjust parameters (e.g. donated_invars) for the call to be evaluated now.
    const_params = update_params(params, in_knowns, 0)

    # Run the call, getting known out vals and aux data used for staged-out call.
    out = primitive.bind(_update_annotation(f_, f.in_type, in_knowns),
                         *in_consts, **const_params)
    out_knowns, out_avals, jaxpr, env = aux()
    # Split apart known outputs from the original call and residuals.
    out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])

    # Create the input tracers for the staged-out (unknown-value) call.
    const_tracers = map(self.new_instantiated_const, res)
    env_tracers = map(self.full_raise, env)
    unknown_arg_tracers = [t for t in tracers if not t.is_known()]
    # Adjust parameters (e.g. donated_invars) for the staged-out call's args.
    num_new_args = len(const_tracers) + len(env_tracers)
    staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
    staged_params = dict(staged_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    # The outputs of the staged-out call are Tracers with the new eqn as recipe.
    out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                   for a in out_avals]
    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
                         out_tracers, primitive, staged_params, jaxpr.effects,
                         source)
    for t in out_tracers: t.recipe = eqn
    return merge_lists(out_knowns, out_tracers, out_consts)
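
  # Worked example of the two-call split in process_call (our sketch; f is
  # hypothetical). With x known and y unknown in
  #
  #   def f(x, y):
  #     z = x * 2.       # known: computed by the call evaluated now
  #     return z + y     # unknown: staged out into the jaxpr
  #
  # the known call returns z as a residual, and the staged call_jaxpr takes
  # (z, y) -- residuals first -- and emits the `add`.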

  def process_map(self, primitive, f: lu.WrappedFun, tracers, params):
    update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
    in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers])

    # This method is like process_call above, except:
    #   1. we delete an axis from mapped-over input avals' shapes, and
    #      analogously add an axis to mapped-over output avals' shapes;
    #   2. we update the in_axes and out_axes/out_axes_thunk parameters to
    #      reflect the inputs and outputs pruned from the unknown/known sides.

    # Map (delete an axis from) unknown inputs' avals as dictated by in_axes.
    unk_in_axes, const_in_axes = partition_list(in_knowns, params['in_axes'])
    in_avals_mapped = [mapped_aval(params['axis_size'], ax, aval)
                       for ax, aval in zip(unk_in_axes, in_avals)]

    # Wrap f to perform partial evaluation and plumb out aux data.
    f = trace_to_subjaxpr_nounits(f, self.main, False)
    f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns),
                                          tuple(in_avals_mapped))

    # Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk).
    const_params = update_params(params, in_knowns, 0)  # handles donated_invars
    out_axes_thunk = params['out_axes_thunk']
    @as_hashable_function(closure=out_axes_thunk)
    def const_out_axes_thunk():
      out_knowns, _, jaxpr, _ = aux()
      _, out_axes = partition_list(out_knowns, out_axes_thunk())
      return tuple(out_axes) + (0,) * len(jaxpr.constvars)  # res mapped axis 0
    const_params = dict(const_params, in_axes=tuple(const_in_axes),
                        out_axes_thunk=const_out_axes_thunk)

    # Run the map, getting known out vals and aux data used for staged-out map.
    out = primitive.bind(f, *in_consts, **const_params)
    out_knowns, out_avals_mapped, jaxpr, env = aux()
    # Split apart known outputs from the original call and residuals.
    out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])

    # We can only check_jaxpr with the dynamic axis environment extended:
    with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
      call_jaxpr = convert_constvars_jaxpr(jaxpr)

    # Compute staged and const out_axes, taking into account residuals.
    out_axes = params['out_axes_thunk']()
    staged_out_axes, _ = partition_list(out_knowns, out_axes)
    staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,)

    # Create the input tracers for the staged-out (unknown-value) call.
    const_tracers = map(self.new_instantiated_const, res)
    env_tracers = map(self.full_raise, env)
    unknown_arg_tracers = [t for t in tracers if not t.is_known()]
    # Adjust params for staged-out call on unknown values.
    num_new_args = len(const_tracers) + len(env_tracers)
    staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
    staged_params = dict(staged_params, in_axes=staged_in_axes,
                         out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)
    # The outputs of the staged-out call are Tracers with the new eqn as recipe.
    out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a)
                 for ax, a in zip(staged_out_axes, out_avals_mapped)]
    out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                   for a in out_avals]
    eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
                         out_tracers, primitive, staged_params, jaxpr.effects,
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn

    return merge_lists(out_knowns, out_tracers, out_consts)
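
  # Axis bookkeeping above, sketched (our illustration): mapping over axis
  # size n with in_axes=(0, 0), x known and y unknown, the unknown side is
  # traced with y's aval shrunk from (n, *rest) to (*rest,); residuals come
  # back stacked along axis 0, so staged_in_axes is 0 per residual, None per
  # closed-over env value, then the original unknown-arg axes; unmapped_aval
  # re-adds the axis when describing staged outputs at the outer level.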

  def post_process_call(self, primitive, out_tracers, params):
    unknown_out_tracers = [t for t in out_tracers if not t.is_known()]
    jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers)
    out_pvals = [t.pval for t in out_tracers]
    out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
    out = [*out_consts, *res]
    main = self.main

    def todo(out):
      trace = main.with_cur_sublevel()
      out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
      const_tracers = map(trace.new_instantiated_const, res)
      in_tracers = (*const_tracers, *map(trace.full_raise, env))
      out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None)
                     for a in out_avals]
      update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
      new_params = update_params(params, [], len(in_tracers))
      new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
      name_stack = self._current_truncated_name_stack()
      source = source_info_util.current().replace(name_stack=name_stack)
      eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
                           jaxpr.effects, source)
      for t in out_tracers: t.recipe = eqn
      return merge_lists(out_knowns, out_tracers, out_consts)

    return out, todo

  def post_process_map(self, primitive, out_tracers, params):
    unknown_out_tracers = [t for t in out_tracers if not t.is_known()]
    jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers)
    out_pvals = [t.pval for t in out_tracers]
    out_knowns, out_avals_mapped, out_consts = partition_pvals(out_pvals)
    out = [*out_consts, *res]
    main = self.main

    with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
      call_jaxpr = convert_constvars_jaxpr(jaxpr)

    def todo(out):
      trace = main.with_cur_sublevel()
      out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
      const_tracers = map(trace.new_instantiated_const, res)
      env_tracers = map(trace.full_raise, env)

      staged_out_axes = tuple(out_axes_unknown)  # set by out_axes_transform
      staged_in_axes = (0,) * len(res) + (None,) * len(env)

      update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
      staged_params = update_params(params, [], len(res) + len(env))
      staged_params = dict(staged_params, in_axes=staged_in_axes,
                           out_axes=tuple(staged_out_axes),
                           call_jaxpr=call_jaxpr)

      out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], d, a)
                   for d, a in zip(staged_out_axes, out_avals_mapped)]
      out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None)
                     for a in out_avals]
      name_stack = self._current_truncated_name_stack()
      source = source_info_util.current().replace(name_stack=name_stack)
      eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
                           primitive, staged_params, jaxpr.effects, source)
      for t in out_tracers: t.recipe = eqn
      return merge_lists(out_knowns, out_tracers, out_consts)

    def out_axes_transform(out_axes):
      nonlocal out_axes_unknown
      out_axes_unknown, out_axes_known = partition_list(out_knowns, out_axes)
      return tuple(out_axes_known) + (0,) * len(jaxpr.constvars)
    out_axes_unknown: Optional[list] = None

    return out, (todo, out_axes_transform)
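
  # Note on the nonlocal dance above (our reading of the code): bind may call
  # out_axes_transform right after the wrapped function runs, which fills in
  # out_axes_unknown; only afterwards does todo() read it as staged_out_axes.
  # That ordering is why the variable is declared alongside the closures that
  # share it rather than passed in explicitly.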

  def _current_truncated_name_stack(self):
    return source_info_util.current_name_stack()[len(self.name_stack):]

  def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
                   app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]],
                                 Tuple[core.Value]],
                   instantiate: bool):
    """Partially evaluate f on a sequence of PartialVals."""
    in_avals, in_consts = unzip2(pvals)
    f = trace_to_subjaxpr(f, self.main, instantiate)
    f, aux = partial_eval_wrapper(f, tuple(in_avals))
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvs = map(PartialVal, zip(out_avals, out_consts))
    env_tracers = map(self.full_raise, env)
    return jaxpr, out_pvs, consts, env_tracers

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, jvp, *in_consts)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def jvp_jaxpr_thunk():
      jvp_ = trace_to_subjaxpr(jvp, self.main, True)
      jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
      with core.new_sublevel():
        out_flat = jvp_.call_wrapped(*(in_consts * 2))  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                              num_consts=len(consts) + len(env)),
                         jaxpr.effects, source)
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, _):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_jvp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)

  def process_custom_transpose(self, prim, call, tracers, **params):
    res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves])
    assert all(t.is_known() for t in res_ts)
    lin_all_known = all(t.is_known() for t in lin_ts)
    if lin_all_known:
      res_cvals = [t.pval[1] for t in res_ts]
      lin_cvals = [t.pval[1] for t in lin_ts]
      return prim.bind(call, *res_cvals, *lin_cvals, **params)
    else:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in params['out_types']]
      in_tracers = map(self.instantiate_const, tracers)
      new_params = dict(params, call=call)
      eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
                           core.no_effects, source_info_util.current())
      for t in out_tracers: t.recipe = eqn
      return out_tracers

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def fwd_jaxpr_thunk():
      fwd_ = trace_to_subjaxpr(fwd, self.main, True)
      fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
      with core.new_sublevel():
        out_flat = fwd_.call_wrapped(*in_consts)  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                              num_consts=len(consts) + len(env),
                              bwd=bwd, out_trees=out_trees),
                         jaxpr.effects, source)
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, _):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_vjp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)


def partition_pvals(
    pvals: List[PartialVal]
  ) -> Tuple[List[bool], List[AbstractValue], List[Any]]:
  knowns = [pval.is_known()  for pval in pvals]
  avals  = [pval.get_aval()  for pval in pvals if not pval.is_known()]
  consts = [pval.get_known() for pval in pvals if     pval.is_known()]
  return knowns, avals, consts
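
# Example of partition_pvals (illustrative): given
#   pvals = [PartialVal.known(1.), PartialVal.unknown(aval), PartialVal.known(2.)]
# it returns knowns=[True, False, True], avals=[aval], consts=[1., 2.], a shape
# that merge_lists(knowns, unknown_things, consts) can later re-interleave.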

@lu.transformation_with_aux
def partial_eval_wrapper_nounits(
    in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
    *in_consts: Any):
  in_avals_, in_consts_ = iter(in_avals), iter(in_consts)
  in_pvals = [PartialVal.known(next(in_consts_)) if known else
              PartialVal.unknown(next(in_avals_)) for known in in_knowns]
  sentinel = object()
  assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
  jaxpr, (out_pvals, res, env) = yield (in_pvals,), {}
  out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
  yield (*out_consts, *res), (out_knowns, out_avals, jaxpr, env)
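
# Calling convention of partial_eval_wrapper_nounits, sketched (ours): with
# in_knowns=(True, False) and in_avals=(aval,), invoking the wrapped fun as
#   wrapped(3.)   # one positional const per known input
# rebuilds in_pvals=[known(3.), unknown(aval)], traces the function, and
# returns (*known_out_consts, *residuals) plus aux data
# (out_knowns, out_avals, jaxpr, env) for the caller's staged-out call.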

custom_partial_eval_rules: Dict[Primitive, Callable] = {}
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}


def abstract_eval_fun(fun, *avals, debug_info=None, **params):
  _, avals_out, _ = trace_to_jaxpr_dynamic(
      lu.wrap_init(fun, params), avals, debug_info)
  assert all(isinstance(aval, AbstractValue) for aval in avals_out)
  return avals_out


JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar',
                          'ConstVar', Literal, core.Unit]


class JaxprTracer(Tracer):
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace: JaxprTrace, pval: PartialVal,
               recipe: Optional[JaxprTracerRecipe]):
    assert isinstance(pval, PartialVal)
    pv, const = pval
    if isinstance(const, Tracer) and const._trace.level >= trace.level:
      raise core.escaped_tracer_error(
          const, f"Tracer from a higher level: {const} in trace {trace}")
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return 'Traced<{}:{}>'.format(self.aval, self._trace)

  @property
  def aval(self) -> AbstractValue:
    return self.pval.get_aval()

  @property
  def parents(self) -> Sequence[JaxprTracer]:
    if isinstance(self.recipe, JaxprEqnRecipe):
      return self.recipe.in_tracers
    else:
      return []

  def full_lower(self):
    known = self.pval.get_known()
    if known is not None:
      return core.full_lower(known)
    else:
      return self

  def is_known(self):
    return self.pval.is_known()
2021-12-06 15:13:01 -08:00
|
|
|
@profiler.annotate_function
|
2022-02-06 17:21:31 -08:00
|
|
|
def trace_to_jaxpr(
|
|
|
|
fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
|
|
|
instantiate: Union[bool, Sequence[bool]] = False,
|
|
|
|
) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]:
|
|
|
|
"""
|
|
|
|
Partially evaluate a function, building a jaxpr for un-evaluated computation.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
fun: lu.WrappedFun representing the function to be partially evaluated. The
|
|
|
|
function must be flattened, in the sense of accepting jaxpr type arguments
|
|
|
|
and returning a flat list of jaxpr type outputs.
|
|
|
|
pvals: sequence of PartialVals of length equal to the number of inputs to
|
|
|
|
`fun` indicating which inputs are known or unknown.
|
|
|
|
instantiate: optional bool or sequence of bools of length equal to the
|
|
|
|
number of outputs of `fun` indicating which outputs should be forced to be
|
|
|
|
treated as unknown and hence instantiated in the jaxpr. If a single bool,
|
|
|
|
the value is applied to all outputs. Default False.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A triple where the first element is a jaxpr representing the computation
|
|
|
|
which depends on unknown inputs; the second element is a list of PartialVals
|
|
|
|
of length equal to the length of the output of `fun` representing which
|
|
|
|
outputs are known and unknown (along with their values and abstract values,
|
|
|
|
respectively); the third element is a list of known residual values. The
|
|
|
|
returned jaxpr takes as inputs the known residual values followed by values
|
|
|
|
of the originally unknown inputs.
|
2020-03-18 07:11:44 +01:00
|
|
|
"""
|
2021-10-28 11:06:58 -07:00
|
|
|
current_name_stack = source_info_util.current_name_stack()
|
|
|
|
with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
|
2020-08-30 12:38:14 +03:00
|
|
|
fun = trace_to_subjaxpr(fun, main, instantiate)
|
2019-07-26 16:48:17 -04:00
|
|
|
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
2018-11-17 18:03:33 -08:00
|
|
|
assert not env
|
2021-01-19 18:38:53 -08:00
|
|
|
del main, fun, env
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-26 16:48:17 -04:00
|
|
|
return jaxpr, out_pvals, consts
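
# A minimal usage sketch (illustrative only; it assumes this module is
# imported as `pe` and that the traced function is already flat):
#
#   import jax.numpy as jnp
#   from jax import core, linear_util as lu
#   from jax.interpreters import partial_eval as pe
#
#   def f(x, y):
#     return x * jnp.sin(y), x + 1.
#
#   pvals = [pe.PartialVal.known(jnp.float32(2.)),  # x is known by value
#            pe.PartialVal.unknown(core.ShapedArray((), jnp.float32))]  # y is not
#   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(lu.wrap_init(f), pvals)
#   # `jaxpr` stages out only the computation depending on the unknown `y`;
#   # `out_pvals` records that the second output, x + 1., is already known.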

@profiler.annotate_function
def trace_to_jaxpr_nounits(
    fun: lu.WrappedFun, pvals: Sequence[PartialVal],
    instantiate: Union[bool, Sequence[bool]] = False,
  ) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]:
  current_name_stack = source_info_util.current_name_stack()
  with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
    fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    assert not env
    del main, fun, env
  return jaxpr, out_pvals, consts

@lu.transformation
def trace_to_subjaxpr_nounits(
    main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
    in_pvals: Sequence[PartialVal]):
  assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
  trace = main.with_cur_sublevel()
  in_knowns  = [pval.is_known()     for pval in in_pvals]
  in_consts  = [pval.get_known()    for pval in in_pvals if pval.is_known()]
  in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
  in_args = merge_lists(in_knowns, in_tracers, in_consts)
  ans = yield in_args, {}
  assert isinstance(ans, (list, tuple)), (
      f"Got unexpected return type when tracing function to jaxpr: {ans}")
  assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), (
      f"Got unexpected return type when tracing function to jaxpr: {ans}")
  if isinstance(instantiate, bool):
    instantiate = [instantiate] * len(ans)
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate,
                    out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  out_tracers_ = [t for t in out_tracers if not t.is_known()]
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers_)
  del trace, in_tracers, out_tracers, out_tracers_
  yield jaxpr, (out_pvals, consts, env)

def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
  if instantiate:
    return trace.instantiate_const(trace.full_raise(tracer))
  else:
    return tracer


FreeVar = namedtuple('FreeVar', ['val'])
ConstVar = namedtuple('ConstVar', ['val'])
LambdaBinding = namedtuple('LambdaBinding', [])

class JaxprEqnRecipe(NamedTuple):
  eqn_id: Any
  in_tracers: Sequence[JaxprTracer]
  out_tracer_refs: Sequence[ref[JaxprTracer]]
  out_avals: Sequence[core.AbstractValue]
  primitive: Primitive
  params: Dict[str, Any]
  effects: core.Effects
  source_info: source_info_util.SourceInfo

def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
                   out_tracers: Sequence[JaxprTracer],
                   primitive: Primitive,
                   params: Dict[str, Any],
                   effects: core.Effects,
                   source_info: source_info_util.SourceInfo
                  ) -> JaxprEqnRecipe:
  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
    # assert len(invars) == len(params["call_jaxpr"].invars)  # TODO constvars?
    assert len(out_tracers) == len(params["call_jaxpr"].outvars)
    assert ("donated_invars" not in params or
            len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
  if primitive.map_primitive:
    assert ("in_axes" in params and
            len(params["in_axes"]) == len(params["call_jaxpr"].invars))
    assert ("donated_invars" in params and
            len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
  out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers]
  return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
                        out_avals, primitive, params, effects, source_info)


def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  (_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src) = recipe
  invars = [getvar(t) for t in in_tracers]
  out_tracers = [t_ref() for t_ref in out_tracer_refs]
  outvars = [DropVar(a) if t is None else getvar(t)  # type: ignore
             for a, t in zip(out_avals, out_tracers)]
  return new_jaxpr_eqn(invars, outvars, prim, params, eff, src)

def tracers_to_jaxpr(
    in_tracers: Sequence[JaxprTracer],
    out_tracers: Sequence[JaxprTracer]
  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
  """Constructs Jaxpr given tracers for inputs and outputs.

  Args:
    in_tracers: the tracers that were created for the function inputs.
    out_tracers: the tracers that were output by the function.

  Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
    the `constvars` in the returned Jaxpr, and a list of environment values.
    The vars for the environment values have been prepended to the Jaxpr's
    `invars`.
  """
  newvar = core.gensym()
  t_to_var: Dict[int, Atom] = {}
  def getvar(t: JaxprTracer) -> Atom:
    var = t_to_var.get(id(t))
    if var is None:
      aval = t.pval.get_aval() if not t.pval.is_known() else core.abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  sorted_tracers = toposort(out_tracers)
  invars = map(getvar, in_tracers)
  eqns: List[core.JaxprEqn] = []
  env: Dict[Var, Any] = {}
  consts: Dict[Var, Any] = {}
  const_to_var: Dict[int, Var] = {}
  def getconstvar(c):
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise core.escaped_tracer_error(
            t, "Tracer not among input tracers {}".format(t))
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[getvar(t)] = recipe.val  # type: ignore
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      t_to_var[id(t)] = recipe
    elif recipe is core.unit:
      t_to_var[id(t)] = core.unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  effects = core.join_effects(*(eqn.effects for eqn in eqns))
  jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers),
                eqns, effects)
  config.jax_enable_checks and core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals


@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
  """Moves the constvars to the start of invars."""
  config.jax_enable_checks and core.check_jaxpr(jaxpr)
  lifted_jaxpr = Jaxpr(constvars=(),
                       invars=jaxpr.constvars + jaxpr.invars,
                       outvars=jaxpr.outvars, eqns=jaxpr.eqns,
                       effects=jaxpr.effects)
  config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
  return lifted_jaxpr
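
# A hedged sketch of the effect (assuming `jax` is available and this module
# is imported as `pe`): a jaxpr that closes over an array gets a constvar,
# which this function turns into a leading input:
#
#   import jax, jax.numpy as jnp
#   c = jnp.arange(3.)
#   closed = jax.make_jaxpr(lambda x: x + c)(jnp.ones(3))
#   lifted = pe.convert_constvars_jaxpr(closed.jaxpr)
#   # `lifted` has no constvars; the captured array is now its first invar.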

def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
  config.jax_enable_checks and core.check_jaxpr(jaxpr)
  env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
  converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
                          invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
                          effects=jaxpr.effects)
  config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
  return converted_jaxpr


def _split_aval(unknown: bool, aval: AbstractValue
                ) -> Tuple[AbstractValue, AbstractValue]:
  return (core.abstract_unit, aval) if unknown else (aval, core.abstract_unit)


def partial_eval_jaxpr_nounits(
    jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
    instantiate: Union[bool, Sequence[bool]],
  ) -> Tuple[ClosedJaxpr, ClosedJaxpr, List[bool], List[AbstractValue]]:
  """Unzip a jaxpr in two by data dependence into 'known' and 'unknown' parts.

  That is, given a jaxpr and a sequence of booleans indicating which jaxpr
  inputs (i.e. invars) are considered unknown, produce two jaxprs, a list of
  booleans representing which of the original jaxpr's outputs are unknown (i.e.
  have a data dependence on an unknown input), and a list of abstract values
  representing residuals (part of the first jaxpr's output and the second
  jaxpr's input). The two jaxprs result from partitioning the original jaxpr's
  first-order primitive applications based on whether all the inputs to the
  application are known (in which case the application is represented in the
  'known' jaxpr and its result is considered known) or whether any inputs to the
  application are unknown (in which case the application is represented in the
  'unknown' jaxpr and its result is considered unknown). Higher-order primitives
  are recursively unzipped in two.

  The `instantiate` argument can be used to ensure some outputs are lifted into
  the 'unknown' jaxpr.

  For example, given an input jaxpr:

    { lambda ; a:f32[] b:f32[]. let
        c:f32[] = cos a
        d:f32[] = sin a
        e:f32[] = neg d
        f:f32[] = mul e b
      in (c, f) }

  then applying this function with `unknowns=[False, True]` and
  `instantiate=False` produces as an output triple:

    # jaxpr_known
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        d:f32[] = neg c
      in (b, d) }

    # jaxpr_unknown
    { lambda ; a:f32[] b:f32[]. let c:f32[] = mul b a in (c,) }

    # out_unknowns
    [False, True]

  Notice in particular that the first output (jaxpr_known) contains all the
  primitive applications which do not have a data dependence on an unknown
  input. Also notice the input and output types: the input type of the first
  jaxpr produced represents the type of the known inputs of the original jaxpr,
  and the output type of the second jaxpr produced represents the type of the
  unknown outputs of the original jaxpr.

  In the above example, the output of jaxpr_known named `d` is a _residual_
  output, and corresponds to the input named `a` in jaxpr_unknown. In general,
  jaxpr_known will produce extra outputs (at the end of its output list)
  corresponding to intermediate values of the original jaxpr which must be
  passed to jaxpr_unknown (as leading inputs).
  """
  instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
  return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate)
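
# A hedged usage sketch mirroring the docstring example (assuming this module
# is imported as `pe` and `jax` is available):
#
#   import jax, jax.numpy as jnp
#
#   def f(a, b):
#     return jnp.cos(a), -jnp.sin(a) * b
#
#   closed_jaxpr = jax.make_jaxpr(f)(1., 2.)
#   jaxpr_known, jaxpr_unknown, out_unknowns, res_avals = \
#       pe.partial_eval_jaxpr_nounits(closed_jaxpr, unknowns=[False, True],
#                                     instantiate=False)
#   # out_unknowns == [False, True]: only the second output depends on `b`;
#   # `res_avals` describes the residual passed from jaxpr_known to
#   # jaxpr_unknown.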

@weakref_lru_cache
def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  cell = []
  def fun(*known_vals_in):
    known_vals_in = iter(known_vals_in)
    unknown_avals = (a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk)
    in_pvals = [PartialVal.unknown(next(unknown_avals)) if uk
                else PartialVal.known(next(known_vals_in)) for uk in in_unknowns]
    assert next(known_vals_in, None) is next(unknown_avals, None) is None
    jaxpr_unknown_, out_pvals, residuals = trace_to_jaxpr_nounits(
        f, in_pvals, instantiate=instantiate)
    jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
    out_unknowns = [not pval.is_known() for pval in out_pvals]
    res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals]
    cell.append((out_unknowns, jaxpr_unknown, res_avals))
    known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()]
    return [*known_vals_out, *residuals]

  known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
  jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
  (out_unknowns, jaxpr_unknown, res_avals), = cell

  # check jaxpr_known and jaxpr_unknown in isolation
  if config.jax_enable_checks:
    core.check_jaxpr(jaxpr_known)
    core.check_jaxpr(jaxpr_unknown)
  # check jaxpr_known has input type corresponding to known inputs of jaxpr
  assert ([v.aval for v in jaxpr_known.invars] ==
          [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk])
  # check jaxpr_known has out type corresponding to known outs of jaxpr plus res
  # TODO(mattjj): enable weak type checking here
  assert ([v.aval.strip_weak_type() for v in jaxpr_known.outvars] ==
          [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns)
           if not uk] + [a.strip_weak_type() for a in res_avals])
  # check jaxpr_unknown has input type corresponding to unknown inputs plus res
  assert ([v.aval for v in jaxpr_unknown.invars] ==
          res_avals + [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk])
  # check jaxpr_unknown has output type corresponding to unknown outputs
  # TODO(mattjj): enable weak type checking here
  assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.outvars] ==
          [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns)
           if uk])

  closed_jaxpr_known = ClosedJaxpr(jaxpr_known, consts_known)
  closed_jaxpr_unknown = ClosedJaxpr(jaxpr_unknown, ())
  return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals


remat_call_p: Primitive = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)

def _remat_partial_eval(trace, _, f, tracers, params):
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  # This way we record all primitives applied to the inputs (all treated as
  # unknown/instantiated) to produce the output. In the context of autodiff,
  # that means we record primal, residual, and tangent computations (e.g. sine,
  # cosine, and multiply).
  in_pvals = [t.pval for t in instantiated_tracers]
  in_knowns, in_avals, () = partition_pvals(in_pvals)  # all are unknown
  assert not any(in_knowns)
  f = trace_to_subjaxpr_nounits(f, trace.main, True)
  f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
  consts = remat_call_p.bind(f, **params)  # no known inputs
  _, out_avals, jaxpr, env = aux()
  env_tracers = map(trace.full_raise, env)
  jaxpr = convert_constvars_jaxpr(jaxpr)
  if jaxpr.effects: raise NotImplementedError
  del in_pvals, in_knowns, in_avals, out_avals, f, aux, env
  # When concrete=True, we could avoid some redundant computation by extracting
  # values from any ConcreteArrays in `out_avals`, but we eschew that
  # optimization.

  # We're done with `f`, and in the steps to follow we work with `jaxpr`. To
  # that end, we want a list of which inputs to `jaxpr` are known/unknown.
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])

  if params['policy']:
    # unzip into jaxpr_known and jaxpr_unknown
    jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
        _partial_eval_jaxpr_custom(jaxpr, in_unknowns, params['policy'])
    jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known,
                                           [True] * len(jaxpr_known.outvars))
    _, used_outs_unknown = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown)

    # compute known outputs and residuals (hoisted out of a remat_call)
    _, in_consts_ = unzip2(t.pval for t in it.chain(env_tracers, tracers)
                           if t.pval.is_known())
    _, known_inputs = partition_list(in_used_known, [*consts, *in_consts_])
    outs = core.eval_jaxpr(jaxpr_known, (), *known_inputs)
    known_outputs, res = split_list(outs, [len(out_unknowns)-sum(out_unknowns)])

    # set up unknown outputs with a recipe to call remat
    res_tracers = map(trace.new_instantiated_const, res)
    const_tracers = map(trace.new_instantiated_const, consts)
    in_jaxpr_tracers = [*res_tracers, *const_tracers, *env_tracers,
                        *instantiated_tracers]
    _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
    unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
    new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True)
    recipe = new_eqn_recipe(in_jaxpr_tracers, unknown_outputs, remat_call_p,
                            new_params, jaxpr_unknown.effects,
                            source_info_util.current())
    for t in unknown_outputs: t.recipe = recipe
    return merge_lists(out_unknowns, known_outputs, unknown_outputs)
  else:
    # TODO(mattjj): this is an old parallel code path, to be deleted once the
    # new path is fully functional

    # Now that we have a `jaxpr` which represents as much of `f` as possible, we
    # want to actually compute known output values. To do that, we first extract
    # a `jaxpr_known`, and compute which outputs of `jaxpr` are known/unknown.
    jaxpr_known_, _, out_unknowns, res_avals = partial_eval_jaxpr_nounits(
        core.ClosedJaxpr(jaxpr, ()), in_unknowns, instantiate=False)  # type: ignore
    jaxpr_known, () = jaxpr_known_.jaxpr, jaxpr_known_.consts
    num_res = len(res_avals)
    # Next, we need values for known outputs. To get them, we need to evaluate
    # jaxpr_known, minus the residual outputs that we don't need. In addition to
    # eliminating residual outputs, we should perform DCE to eliminate the
    # computation of those residuals; for example, if the primal program
    # includes a sine, jaxpr_known includes both the sine and cosine, yet we
    # don't want to compute the cosine here.
    known_inputs = consts + [t for t in it.chain(env_tracers, tracers)
                             if t.pval.is_known()]
    num_known_outputs = len(out_unknowns) - sum(out_unknowns)
    jaxpr_known, kept_inputs = dce_jaxpr(
        jaxpr_known, [True] * num_known_outputs + [False] * num_res)
    known_inputs = [x for x, kept in zip(known_inputs, kept_inputs) if kept]
    known_outputs = core.eval_jaxpr(jaxpr_known, (), *known_inputs)
    del jaxpr_known, res_avals, num_res, num_known_outputs, kept_inputs

    # We compute unknown outputs by using the full `jaxpr`, though we can prune
    # out of it any known outputs and computations and only keep those
    # operations we need to compute the unknown outputs.
    jaxpr, kept_inputs = dce_jaxpr(jaxpr, out_unknowns)
    const_tracers = map(trace.instantiate_const, map(trace.full_raise, consts))
    unknown_inputs = [*const_tracers, *env_tracers, *instantiated_tracers]
    unknown_inputs = [x for x, kept in zip(unknown_inputs, kept_inputs) if kept]
    unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
                       for x in jaxpr.outvars]
    eqn = new_eqn_recipe(unknown_inputs, unknown_outputs, remat_call_p,
                         dict(params, call_jaxpr=jaxpr,
                              differentiated=True),
                         jaxpr.effects, source_info_util.current())
    for t in unknown_outputs: t.recipe = eqn

    return merge_lists(out_unknowns, known_outputs, unknown_outputs)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
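
# A hedged sketch of when the rule above fires (illustrative): differentiating
# through `jax.remat` stages out a remat_call, whose partial evaluation keeps
# primal recomputation on the backward pass:
#
#   import jax, jax.numpy as jnp
#   g = jax.grad(jax.remat(lambda x: jnp.sin(jnp.sin(x))))
#   g(1.0)  # sine values are recomputed rather than stored as residuals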


def _partition_knowns(pvals, unknowns: Sequence[bool]):
  return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
          [e for e, unknown in zip(pvals, unknowns) if unknown])


def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
  assert len(known_list) + len(unknown_list) == len(which_unknown)
  known_iter, unknown_iter = iter(known_list), iter(unknown_list)
  return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
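
# For example (a hedged illustration of the interleaving convention):
#   _zip_knowns(['a', 'b'], ['X'], [False, True, False])  # -> ['a', 'X', 'b']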


def _partial_eval_jaxpr_custom(
    jaxpr: Jaxpr, in_unknowns: Sequence[bool], saveable: Callable[..., bool],
  ) -> Tuple[Jaxpr, Jaxpr, Sequence[bool], Sequence[bool], int]:
  if jaxpr.constvars: raise NotImplementedError  # TODO(mattjj)
  env: Dict[Var, Tuple[bool, bool]] = {}
  residuals: OrderedSet[Var] = OrderedSet()

  def read(x: Atom) -> Tuple[bool, bool]:
    if type(x) is Var:
      return env[x]
    return (False, True)

  def write(unk: bool, inst: bool, v: Var) -> None:
    assert (unk, inst) != (True, False)
    env[v] = (unk, inst)

  def ensure_instantiated(inst: bool, x: Atom) -> Atom:
    if type(x) is Var and not inst:
      residuals.add(x)
    return x

  known_eqns, staged_eqns = [], []
  write(False, True, core.unitvar)
  map(write, in_unknowns, [True] * len(in_unknowns), jaxpr.invars)
  for eqn in jaxpr.eqns:
    unks_in, inst_in = unzip2(map(read, eqn.invars))
    rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, inst_out, res = rule(saveable, unks_in, inst_in, eqn)
      eqn1 and known_eqns.append(eqn1); eqn2 and staged_eqns.append(eqn2)  # type: ignore
      residuals.update(res)
      map(write, unks_out, inst_out, eqn.outvars)
    elif any(unks_in):
      inputs = map(ensure_instantiated, inst_in, eqn.invars)
      staged_eqns.append(eqn.replace(invars=inputs))
      map(partial(write, True, True), eqn.outvars)
    else:
      known_eqns.append(eqn)
      if saveable(eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params):
        map(partial(write, False, False), eqn.outvars)
      else:
        inputs = map(ensure_instantiated, inst_in, eqn.invars)
        staged_eqns.append(eqn.replace(invars=inputs))
        map(partial(write, False, True), eqn.outvars)
  out_unknowns, out_inst = unzip2(map(read, jaxpr.outvars))
  assert all(type(v) is Var for v in residuals), residuals

  ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
  outs_known_, _ = partition_list(out_unknowns, jaxpr.outvars)
  outs_known = [x for x in outs_known_ if x.aval is not core.abstract_unit]
  known_effects = core.join_effects(*(eqn.effects for eqn in known_eqns))
  jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns,
                      known_effects)
  config.jax_enable_checks and core.check_jaxpr(jaxpr_known)

  _, outs_staged = partition_list(out_inst, jaxpr.outvars)
  staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns))
  jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged,
                       staged_eqns, staged_effects)
  config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)

  return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
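
# A hedged sketch of driving this unzipping with a saveability policy
# (illustrative; assumes `jax` is available and this module is `pe`): refuse
# to save `sin` results, forcing them onto the staged-out side:
#
#   import jax
#   from jax import lax
#
#   def f(x, y):
#     s = lax.sin(x)
#     return s, s * y
#
#   closed = jax.make_jaxpr(f)(1., 2.)
#   saveable = lambda prim, *avals, **params: prim is not lax.sin_p
#   known, staged, out_unks, out_inst, num_res = \
#       pe._partial_eval_jaxpr_custom(closed.jaxpr, [False, True], saveable)
#   # the sine is emitted into both jaxprs: it computes the known first output,
#   # and is recomputed on the staged side rather than saved as a residual.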

# A primitive rule for policy-driven partial evaluation returns a 5-tuple
# with the components representing, respectively:
#  * the JaxprEqn for the 'known' side (or None if there is no known component),
#  * the JaxprEqn for the 'unknown' side (or None),
#  * a list of booleans indicating which of the original outputs are unknown,
#  * a list of booleans indicating which of the original outputs are
#    instantiated (i.e. available) in the 'unknown' side,
#  * a list of Var instances representing residuals to be added (i.e. to be
#    plumbed as outputs of the 'known' side jaxpr and added as input binders to
#    the 'unknown' jaxpr).
PartialEvalCustomResult = Tuple[Optional[JaxprEqn], Optional[JaxprEqn],
                                Sequence[bool], Sequence[bool], List[Var]]
PartialEvalCustomRule = Callable[
    [Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn],
    PartialEvalCustomResult]
partial_eval_jaxpr_custom_rules: Dict[Primitive, PartialEvalCustomRule] = {}

def partial_eval_jaxpr_custom_rule_not_implemented(
    name: str, saveable: Callable[..., bool], unks_in: Sequence[bool],
    inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult:
  msg = (f'custom-policy remat rule not implemented for {name}, '
         'open a feature request at https://github.com/google/jax/issues!')
  raise NotImplementedError(msg)


ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
                          int, dict, dict], Tuple[dict, dict]]

def call_partial_eval_custom_rule(
    jaxpr_param_name: str, params_updater: ParamsUpdater,
    saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
    eqn: JaxprEqn
  ) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
  jaxpr = eqn.params[jaxpr_param_name]
  jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
      _partial_eval_jaxpr_custom(jaxpr, unks_in, saveable)
  ins_known, _ = partition_list(unks_in, eqn.invars)
  # by convention, _partial_eval_jaxpr_custom drops units on known outputs
  known_units_out = [v.aval is core.abstract_unit for v in jaxpr.outvars]
  dropped_outs_known = map(op.or_, unks_out, known_units_out)
  kept_outs_known = [not d for d in dropped_outs_known]
  out_binders_known, _ = partition_list(dropped_outs_known, eqn.outvars)
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  kept_outs_staged = inst_out
  newvar = core.gensym([jaxpr_known, jaxpr_staged])
  residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
  params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
  params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
  params_known, params_staged = params_updater(
      unks_in, kept_outs_known, kept_outs_staged, num_res, params_known,
      params_staged)
  eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                            eqn.primitive, params_known, jaxpr_known.effects,
                            eqn.source_info)
  eqn_staged = new_jaxpr_eqn([*residuals, *eqn.invars], out_binders_staged,
                             eqn.primitive, params_staged,
                             jaxpr_staged.effects, eqn.source_info)
  assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is Var and not inst]
  return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
partial_eval_jaxpr_custom_rules[core.call_p] = \
    partial(call_partial_eval_custom_rule, 'call_jaxpr',
            lambda _, __, ___, ____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
    partial(call_partial_eval_custom_rule, 'call_jaxpr',
            lambda _, __, ___, ____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[remat_call_p] = \
    partial(call_partial_eval_custom_rule, 'call_jaxpr',
            lambda _, __, ___, ____, p1, p2: (p1, dict(p2, differentiated=True)))
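
# A hedged usage sketch (illustrative): these rules fire when a jaxpr is
# unzipped under a rematerialization policy, e.g. via the public
# `jax.checkpoint` API (assuming this JAX version exposes
# `jax.checkpoint_policies`):
#
#   import jax, jax.numpy as jnp
#   f = jax.checkpoint(lambda W, x: jnp.sin(W @ x),
#                      policy=jax.checkpoint_policies.checkpoint_dots)
#   jax.grad(lambda W, x: f(W, x).sum())(jnp.ones((3, 3)), jnp.ones(3))
#   # the dot product's result is saved; the sine is recomputed when needed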


def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
  # Compute which inputs are just forwarded to outputs.
  fwds: Dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars))
  for eqn in jaxpr.eqns:
    if eqn.primitive in forwarding_rules:
      eqn = eqn.replace(invars=[fwds.get(v, v) for v in eqn.invars])  # type: ignore
      fwd_vars, _ = forwarding_rules[eqn.primitive](eqn)
      for v_orig, v_new in zip(eqn.outvars, fwd_vars):
        if v_new is not None:
          fwds[v_orig] = v_new
  idxs: Dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)}
  return [None if type(v) is Literal else idxs.get(fwds.get(v))  # type: ignore
          for v in jaxpr.outvars]
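
# A hedged illustration (assuming `jax` is available and this module is `pe`):
# an output that is literally one of the inputs is reported by its index:
#
#   import jax
#   closed = jax.make_jaxpr(lambda x, y: (x, x + y))(1., 2.)
#   pe._jaxpr_forwarding(closed.jaxpr)  # -> [0, None]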


def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
              ) -> Tuple[Jaxpr, List[bool]]:
  return _dce_jaxpr(jaxpr, tuple(used_outputs))

@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
               ) -> Tuple[Jaxpr, List[bool]]:
  env: Dict[Var, bool] = {}

  def read(v: Var) -> bool:
    return env.get(v, False)

  def write(x: Atom, b: bool) -> None:
    if type(x) is Var:
      env[x] = read(x) or b

  new_eqns = []
  map(write, jaxpr.outvars, used_outputs)
  for eqn in jaxpr.eqns[::-1]:
    used_outs = map(read, eqn.outvars)
    # If any outputs are used, then we need to keep a version of the eqn and
    # potentially mark some inputs as used. Otherwise mark all inputs as unused.
    if any(used_outs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params):
      # If there's a rule for modifying the eqn and computing used inputs, apply
      # it. Otherwise, keep the eqn unmodified and mark all inputs as used.
      rule = dce_rules.get(eqn.primitive)
      if rule:
        used_ins, new_eqn = rule(used_outs, eqn)
      else:
        used_ins = [True] * len(eqn.invars)
        new_eqn = eqn
      new_eqns.append(new_eqn)
    else:
      used_ins = [False] * len(eqn.invars)
    map(write, eqn.invars, used_ins)
  used_inputs = map(read, jaxpr.invars)

  new_jaxpr = Jaxpr(jaxpr.constvars,
                    [v for v, b in zip(jaxpr.invars, used_inputs) if b],
                    [v for v, b in zip(jaxpr.outvars, used_outputs) if b],
                    new_eqns[::-1], jaxpr.effects)
  config.jax_enable_checks and core.check_jaxpr(new_jaxpr)

  return new_jaxpr, used_inputs
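
# A hedged usage sketch (assuming `jax` is available and this module is `pe`):
#
#   import jax, jax.numpy as jnp
#   closed = jax.make_jaxpr(lambda x: (jnp.sin(x), jnp.cos(x)))(1.)
#   pruned, used_inputs = pe.dce_jaxpr(closed.jaxpr, [True, False])
#   # `pruned` computes only sin(x); the cosine equation is dead code, and
#   # used_inputs == [True] since the surviving output still needs `x`.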


DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], JaxprEqn]]
dce_rules: Dict[Primitive, DCERule] = {}


def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
                        ) -> Tuple[List[bool], JaxprEqn]:
  new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
  new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
  update_params = call_param_updaters.get(eqn.primitive)
  if update_params:
    new_params = update_params(new_params, used_inputs, 0)
  new_eqn = new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
                          [v for v, used in zip(eqn.outvars, used_outputs) if used],
                          eqn.primitive, new_params, new_jaxpr.effects,
                          eqn.source_info)
  return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule


def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
                          ) -> ClosedJaxpr:
  """Reorder `invars` by moving those indicated in `to_move` to the front."""
  return _move_binders_to_front(closed_jaxpr, tuple(to_move))

@weakref_lru_cache
def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Tuple[bool, ...]
                           ) -> ClosedJaxpr:
  assert len(closed_jaxpr.in_avals) == len(to_move)
  new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
  new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars,
                    closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns,
                    closed_jaxpr.jaxpr.effects)
  new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
  return new_closed_jaxpr

def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
  return ([elt for elt, move in zip(lst, to_move) if move] +
          [elt for elt, move in zip(lst, to_move) if not move])

def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
                         ) -> ClosedJaxpr:
  """Reorder `invars` by moving those indicated in `to_move` to the back."""
  return move_binders_to_front(closed_jaxpr, map(op.not_, to_move))
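
# A hedged illustration (assuming `jax` is available and this module is `pe`):
#
#   import jax
#   closed = jax.make_jaxpr(lambda x, y: x - y)(1., 2.)
#   moved = pe.move_binders_to_front(closed, [False, True])
#   # `moved` now takes its arguments as (y, x); the outputs are unchanged.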
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
class DynamicJaxprTracer(core.Tracer):
|
2021-01-05 14:52:54 -08:00
|
|
|
__slots__ = ['aval']
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def __init__(self, trace, aval, line_info=None):
|
|
|
|
self._trace = trace
|
2021-01-05 14:52:54 -08:00
|
|
|
self._line_info = line_info
|
2020-07-30 12:59:36 -07:00
|
|
|
self.aval = aval
|
|
|
|
|
|
|
|
def full_lower(self):
|
|
|
|
return self
|
|
|
|
|
|
|
|
def _contents(self):
|
|
|
|
return ()
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
  def _origin_msg(self):
    if not self._trace.main.jaxpr_stack:  # type: ignore
      # If this Tracer has been leaked the jaxpr stack may no longer be
      # available. So we can't print as much origin information.
      return ("\nThis Tracer was created on line "
              f"{source_info_util.summarize(self._line_info)}")
    else:
      invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
    dbg = self._trace.main.debug_info
    if dbg is None:
      return ""
    if invar_pos:
      origin = (f"While tracing the function {dbg.func_src_info} "
                f"for {dbg.traced_for}, "
                "this concrete value was not available in Python because it "
                f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
                f"of {dbg.arg_info(invar_pos)}.")
    elif progenitor_eqns:
      msts = ["  operation "
              f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n"
              f"    from line {source_info_util.summarize(eqn.source_info)}"
              for eqn in progenitor_eqns[:5]]  # show at most 5
      origin = (f"While tracing the function {dbg.func_src_info} "
                f"for {dbg.traced_for}, "
                "this value became a tracer due to JAX operations on these lines:"
                "\n\n" + "\n\n".join(msts))
      if len(progenitor_eqns) > 5:
        origin += "\n\n(Additional originating lines are not shown.)"
    else:
      origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
                f"for {dbg.traced_for}.")
    return "\n" + origin

  def _assert_live(self) -> None:
    if not self._trace.main.jaxpr_stack:  # type: ignore
      raise core.escaped_tracer_error(self, None)

TracerId = int
AvalId = int
ConstId = int

class JaxprStackFrame:
  gensym: Callable[[AbstractValue], Var]
  tracer_to_var: Dict[TracerId, Var]
  constid_to_tracer: Dict[ConstId, Tracer]
  constvar_to_val: Dict[Var, Any]
  tracers: List[DynamicJaxprTracer]  # hold onto strong refs for all tracers
  eqns: List[JaxprEqn]
  invars: List[Var]
  effects: core.Effects

  def __init__(self):
    self.gensym = core.gensym()
    self.tracer_to_var = {}
    self.constid_to_tracer = {}
    self.constvar_to_val = {}
    self.tracers = []   # circ refs, frame->tracer->trace->main->frame,
    self.eqns = []      # cleared when we pop frame from main
    self.invars = []
    self.effects = set()

  def add_eqn(self, eqn: core.JaxprEqn):
    self.eqns.append(eqn)
    self.effects |= eqn.effects

  def to_jaxpr(self, out_tracers):
    # It's not necessary, but we keep the tracer-to-var mapping injective:
    assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
    outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
    constvars, constvals = unzip2(self.constvar_to_val.items())
    jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, self.effects)
    jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals

  def newvar(self, aval):
    if isinstance(aval, DShapedArray):
      # this aval may have tracers in it, so we replace those with variables
      new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d
                   for d in aval.shape]
      aval = aval.update(shape=tuple(new_shape))
    return self.gensym(aval)

  def find_progenitors(self, tracer):
    var = self.tracer_to_var.get(id(tracer))
    if not var:
      return None, None
    active_vars = {var}
    for eqn in self.eqns[::-1]:
      produced = set(eqn.outvars) & active_vars
      if produced:
        active_vars.difference_update(produced)
        active_vars.update(eqn.invars)
    invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars]
    constvars = active_vars & set(self.constvar_to_val)
    const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
    return invar_positions, const_eqns

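# Illustrative note (not library code): `find_progenitors` walks the recorded
# equations backwards from the variable bound to `tracer`, collecting every
# variable it transitively depends on. Given eqns `c = sin a; d = mul c b`,
# asking for the progenitors of `d`'s tracer reports the input positions of
# `a` and `b` (where those are invars) plus any equations consuming constants.
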
def _const_folding_and_forwarding(jaxpr, constvals):
  consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
  var_subs: Dict[Var, Var] = {}  # not Dict[Var, Atom] b/c literals not inlined
  new_eqns = []
  for eqn in jaxpr.eqns:
    # always apply invar substitutions
    eqn = eqn.replace(invars=[var_subs.get(v, v) for v in eqn.invars])
    # if any inputs are constants and we have a constant-folding rule, apply it
    if eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars):
      consts_in = [consts.get(v) for v in eqn.invars]
      consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn)
      assert (new_eqn is None) == all(c is not None for c in consts_out)
      for v, c in zip(eqn.outvars, consts_out):
        if c is not None: consts[v] = c
      if new_eqn is None: continue
      else: eqn = new_eqn
    # if the application trivially maps some inputs to outputs, simplify
    if eqn.primitive in forwarding_rules:
      fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn)
      assert (new_eqn is None) == all(v is not None for v in fwd_vars)
      for v_orig, v_new in zip(eqn.outvars, fwd_vars):
        if v_new is not None: var_subs[v_orig] = v_new
      if new_eqn is None: continue
      else: eqn = new_eqn
    new_eqns.append(eqn)
  new_constvars, new_constvals = unzip2(consts.items())
  new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars]
  new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, jaxpr.effects)
  return new_jaxpr, new_constvals

ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
                         Tuple[List[Optional[Any]], Optional[JaxprEqn]]]
const_fold_rules: Dict[Primitive, ConstFoldRule] = {}

ForwardingRule = Callable[[JaxprEqn],
                          Tuple[List[Optional[Var]], Optional[JaxprEqn]]]
forwarding_rules: Dict[Primitive, ForwardingRule] = {}

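# Illustrative sketches (hypothetical primitives, not registered by this
# module). A ConstFoldRule receives the per-input constants (None for unknown
# inputs) and returns per-output constants plus an optional replacement
# equation; returning None for the equation means every output is known:
#
#   def _neg_fold_rule(consts_in, eqn):
#     c, = consts_in
#     if c is None:
#       return [None], eqn     # input unknown: keep the equation
#     return [-c], None        # input known: fold, drop the equation
#   const_fold_rules[neg_p] = _neg_fold_rule
#
# A ForwardingRule reports, per output, a variable the output trivially
# equals (or None), plus an optional replacement equation for the rest:
#
#   def _identity_fwd_rule(eqn):
#     x, = eqn.invars
#     return [x], None         # the sole output just forwards the input
#   forwarding_rules[identity_p] = _identity_fwd_rule
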
def _inline_literals(jaxpr, constvals):
  # This function also ensures variables are labeled in a canonical ordering,
  # prunes unused constants, and inserts `dropvar` symbols.
  lits = {v: Literal(c, v.aval) for v, c in zip(jaxpr.constvars, constvals)
          if type(c) in core.literalable_types and not np.shape(c)}
  lit: Callable[[Var], Optional[Literal]] = lits.get
  newname: Callable[[AbstractValue], Var] = core.gensym()
  newvars: Dict[Var, Var] = {}
  newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
  var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
  dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))

  def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
    if isinstance(aval, DShapedArray):
      return [d for d in aval.shape if isinstance(d, Var)]
    return []

  used = {v for eqn in jaxpr.eqns for invar in eqn.invars
          for v in it.chain([invar], vars_in_shape(invar.aval))}
  used |= {v for outvar in jaxpr.outvars
           for v in it.chain([outvar], vars_in_shape(outvar.aval))}
  new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)]
  new_constvals = [c for v, c in zip(jaxpr.constvars, constvals)
                   if v in used and not lit(v)]
  new_invars = [var(v) for v in jaxpr.invars]
  new_eqns = []
  for eqn in jaxpr.eqns:
    invars = [lit(v) or var(v) for v in eqn.invars]
    outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
    new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
  new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
                    jaxpr.effects)
  return new_jaxpr, new_constvals

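# Illustrative note: after `_inline_literals`, a scalar constant such as
# np.float32(2.0) no longer appears as a constvar; its uses become `Literal`
# atoms inline, so `2.0 * x` traces to an equation like `b = mul 2.0 a` with
# no constvar. Array-valued constants (nonempty shape) stay as constvars,
# since only shapeless values of `core.literalable_types` qualify here.
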
class DynamicJaxprTrace(core.Trace):
  __slots__ = []  # type: ignore

  @property
  def frame(self):
    return self.main.jaxpr_stack[-1]  # pytype: disable=attribute-error

  def new_arg(self, aval):
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
    self.frame.invars.append(var)
    return tracer

  def new_const(self, c):
    # TODO(mattjj): for ints, or hashable consts, don't rely on id
    tracer = self.frame.constid_to_tracer.get(id(c))
    if tracer is None:
      aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
      tracer = self._new_const(aval, c)
    return tracer

  pure = lift = new_const

  def _new_const(self, aval, c):
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
    self.frame.constid_to_tracer[id(c)] = tracer
    self.frame.constvar_to_val[var] = c
    return tracer

  def sublift(self, t):
    # When lifting closed-over tracers corresponding to this same trace, the
    # variable to lift could have tracers (representing axis size variables) in
    # its shape. We must lift those too!
    tracer = self.frame.constid_to_tracer.get(id(t))
    if tracer is None:
      aval = raise_to_shaped(get_aval(t), weak_type=dtypes.is_weakly_typed(t))
      aval = self._lift_tracers_in_aval(aval)
      tracer = self._new_const(aval, t)
    return tracer

  def _lift_tracers_in_aval(self, aval):
    if (not isinstance(aval, DShapedArray) or
        not any(isinstance(d, Tracer) for d in aval.shape)):
      return aval
    shape = [self.full_raise(d) if isinstance(d, Tracer) else d
             for d in aval.shape]
    return aval.update(shape=tuple(shape))

  def getvar(self, tracer):
    var = self.frame.tracer_to_var.get(id(tracer))
    if var is None:
      raise core.escaped_tracer_error(tracer)
    return var

  def makevar(self, tracer):
    var = self.frame.tracer_to_var.get(id(tracer))
    assert var is None, "a jaxpr variable must be created only once per tracer"
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
    return var

  def instantiate_const(self, val):
    if (isinstance(val, Tracer) and val._trace.main is self.main
        and val._trace.sublevel == self.sublevel):
      return val
    else:
      return self.new_const(val)

  def process_primitive(self, primitive, tracers, params):
    if primitive in custom_staging_rules:
      return custom_staging_rules[primitive](self, *tracers, **params)
    return self.default_process_primitive(primitive, tracers, params)

  def default_process_primitive(self, primitive, tracers, params):
    avals = [t.aval for t in tracers]
    out_avals, effects = primitive.abstract_eval(*avals, **params)
    out_avals = [out_avals] if not primitive.multiple_results else out_avals
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
    self.frame.add_eqn(eqn)
    return out_tracers if primitive.multiple_results else out_tracers.pop()

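  # Example (illustrative): absent a custom staging rule, every primitive
  # bind during `jax.make_jaxpr` tracing lands in `default_process_primitive`,
  # which records one equation per bind. Output shown approximately:
  #
  #   >>> import jax, jax.numpy as jnp
  #   >>> print(jax.make_jaxpr(jnp.sin)(1.0))
  #   { lambda ; a:f32[]. let b:f32[] = sin a in (b,) }
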
  def process_call(self, call_primitive, f, explicit_tracers, params):
    if f.in_type is None:
      in_avals = [core.raise_to_shaped(get_aval(x)) for x in explicit_tracers]
      keep_inputs = [True] * len(explicit_tracers)
      im_tracers = []
    else:
      im_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers)
      in_avals, keep_inputs = unzip2(f.in_type)
    with core.new_sublevel():
      jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
          f, self.main, in_avals, keep_inputs=keep_inputs)
    if jaxpr.effects:
      raise NotImplementedError('Effects not supported for call primitives.')
    tracers = [*im_tracers, *explicit_tracers]
    if params.get('inline', False):
      return core.eval_jaxpr(jaxpr, consts, *tracers)
    env = {v: t for v, t in zip(jaxpr.constvars, consts) if isinstance(t, Tracer)}
    env.update(zip(jaxpr.invars, tracers))
    out_avals_ = [_substitute_tracers_in_type(env, a) for a in out_avals]
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals_]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    update_params = call_param_updaters.get(call_primitive)
    if update_params:
      new_params = update_params(new_params, [True] * len(explicit_tracers),
                                 len(consts) + len(im_tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
                        call_primitive, new_params,
                        new_params['call_jaxpr'].effects, source_info)
    self.frame.add_eqn(eqn)
    return out_tracers

  def post_process_call(self, call_primitive, out_tracers, params):
    assert False  # unreachable

  def process_map(self, map_primitive, f, tracers, params):
    in_avals = [t.aval for t in tracers]
    axis_name, axis_size = params['axis_name'], params['axis_size']
    reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
                        if in_axis is not None else a
                        for a, in_axis in zip(in_avals, params['in_axes'])]
    with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
      with core.new_sublevel():
        jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
            f, self.main, reduced_in_avals)
      if jaxpr.effects:
        raise NotImplementedError('Effects not supported for map primitives.')
      out_axes = params['out_axes_thunk']()
      out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a)
                   if out_axis is not None else a
                   for a, out_axis in zip(reduced_out_avals, out_axes)]
      source_info = source_info_util.current()
      out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
      invars = map(self.getvar, tracers)
      constvars = map(self.getvar, map(self.instantiate_const, consts))
      outvars = map(self.makevar, out_tracers)
      new_in_axes = (None,) * len(consts) + params['in_axes']
      new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
                        call_jaxpr=convert_constvars_jaxpr(jaxpr))
      del new_params['out_axes_thunk']
      update_params = call_param_updaters.get(map_primitive)
      if update_params:
        new_params = update_params(new_params, [True] * len(tracers), len(consts))
      eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
                          new_params, new_params['call_jaxpr'].effects, source_info)
      self.frame.add_eqn(eqn)
    return out_tracers

  def post_process_map(self, map_primitive, out_tracers, params):
    assert False  # unreachable

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    in_avals = [t.aval for t in tracers]
    with core.new_sublevel():
      fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    if fun_jaxpr.effects:
      raise NotImplementedError('Effects not supported in `custom_jvp`.')
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    main_ = ref(self.main)
    jvp_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(jvp, main_(), 2 * in_avals)[::2])
    out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                             num_consts=len(consts)),
                        fun_jaxpr.effects,
                        source_info_util.current())
    self.frame.add_eqn(eqn)
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, _):
    assert False  # unreachable

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    in_avals = [t.aval for t in tracers]
    with core.new_sublevel():
      fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    if fun_jaxpr.effects:
      raise NotImplementedError('Effects not supported in `custom_vjp`.')
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    main_ = ref(self.main)
    fwd_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2])
    out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                             num_consts=len(consts),
                             bwd=bwd, out_trees=out_trees),
                        fun_jaxpr.effects,
                        source_info_util.current())
    self.frame.add_eqn(eqn)
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, _):
    assert False  # unreachable

  def process_custom_transpose(self, prim, call, tracers,
                               transpose, out_types,
                               lin_tree, res_tree, out_tree):
    tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves])

    in_avals_p = [t.aval for t in tracers]
    in_avals_t = [*[t.aval for t in tracers_res], *out_types]

    with core.new_sublevel():
      call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
          call, self.main, in_avals_p)
    closed_call_jaxpr = core.ClosedJaxpr(
        convert_constvars_jaxpr(call_jaxpr), ())

    transpose_flat, in_tree2 = flatten_fun_nokwargs(
        lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))

    main_ = ref(self.main)
    # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
    transpose_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(
            transpose_flat, main_(), in_avals_t)[::2])

    out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, call_consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
                        dict(call_jaxpr=closed_call_jaxpr,
                             transpose_jaxpr_thunk=transpose_jaxpr_thunk,
                             out_types=out_types, res_tree=res_tree,
                             lin_tree=lin_tree, out_tree=out_tree),
                        closed_call_jaxpr.effects,
                        source_info_util.current())
    self.frame.add_eqn(eqn)
    return out_tracers

custom_staging_rules: Dict[Primitive, Callable] = {}

def _memoize(thunk):
  if config.jax_check_tracer_leaks:
    return thunk

  cell = []
  saved_state = core.thread_local_state.trace_state.copy()
  def memoized():
    if not cell:
      prev_state = core.thread_local_state.trace_state
      core.thread_local_state.trace_state = saved_state
      try:
        cell.append(thunk())
      finally:
        core.thread_local_state.trace_state = prev_state
    return cell[0]
  return memoized

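# Example (illustrative): `_memoize` runs its thunk at most once, and runs it
# under the trace state captured at memoization time, so the deferred jaxpr
# thunks above (jvp/fwd/transpose) trace against the right main trace even if
# they fire much later. Assuming a hypothetical `expensive_trace` thunk:
#
#   thunk = _memoize(lambda: expensive_trace())
#   a = thunk(); b = thunk()   # traces once; a is b
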
class DebugInfo(NamedTuple):
  func_src_info: str
  traced_for: str
  arg_info: Callable[[int], str]


def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
  in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
  return debug_info(fn.f, in_tree, has_kwargs, traced_for)

def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
               traced_for: str) -> DebugInfo:
  func_src_info = fun_sourceinfo(fn)
  if in_tree is not None:
    arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
  else:
    arg_info = arg_info_flattened  # type: ignore
  return DebugInfo(func_src_info, traced_for, arg_info)

def fun_sourceinfo(fun: Callable):
  while isinstance(fun, functools.partial):
    fun = fun.func
  fun = inspect.unwrap(fun)
  try:
    filename = fun.__code__.co_filename
    lineno = fun.__code__.co_firstlineno
    line_info = f"{fun.__name__} at {filename}:{lineno}"
    return line_info
  except AttributeError:
    return "<unknown>"

def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
                    flat_pos: List[int]) -> str:
  dummy_args = [False] * in_tree.num_leaves
  for i in flat_pos: dummy_args[i] = True
  if has_kwargs:
    args, kwargs = tree_unflatten(in_tree, dummy_args)
  else:
    args, kwargs = tree_unflatten(in_tree, dummy_args), {}
  try:
    ba = inspect.signature(fn).bind(*args, **kwargs)
  except (TypeError, ValueError):
    return arg_info_flattened(flat_pos)
  arg_names = [f"'{name}'" for name, x in ba.arguments.items()
               if any(tree_leaves(x))]
  if len(arg_names) == 1:
    return f"the argument {arg_names[0]}"
  elif len(arg_names) == 2:
    return f"the arguments {arg_names[0]} and {arg_names[1]}"
  else:
    *rest, last = arg_names
    return f"the arguments {', '.join(rest)}, and {last}"

def arg_info_flattened(flat_pos: List[int]) -> str:
  if len(flat_pos) > 1:
    return f"the arguments passed at flattened positions {flat_pos}"
  else:
    return f"the argument passed at flattened position {flat_pos[0]}"

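# Example (illustrative):
#
#   >>> arg_info_flattened([3])
#   'the argument passed at flattened position 3'
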
@profiler.annotate_function
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
                           in_avals: Sequence[AbstractValue],
                           debug_info: Optional[DebugInfo] = None,
                           *,
                           keep_inputs: Optional[List[bool]] = None):
  with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
    main.debug_info = debug_info  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
        fun, main, in_avals, keep_inputs=keep_inputs)
    del main, fun
  return jaxpr, out_avals, consts

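# Example (illustrative sketch): tracing a flat, wrapped function to a jaxpr.
# The callable must take and return flat sequences of values:
#
#   from jax import core, linear_util as lu
#   import jax.numpy as jnp
#   f = lu.wrap_init(lambda x: [jnp.sin(x)])
#   aval = core.ShapedArray((), jnp.dtype('float32'))
#   jaxpr, out_avals, consts = trace_to_jaxpr_dynamic(f, [aval])
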
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
                              in_avals: Sequence[AbstractValue], *,
                              keep_inputs: Optional[Sequence[bool]] = None):
  # In general, the Tracers passed to the Python callable underlying `fun` may
  # correspond to a subset of `in_avals` (i.e. a subset of the input binders in
  # the jaxpr). For example:
  #
  #   n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
  #   a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
  #   b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
  #
  #   @lu.wrap_init
  #   def f(x, y):
  #     return x, y
  #
  #   jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
  #                                           keep_inputs=[False, True, True])
  #   print(jaxpr)
  #   # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
  #
  # The abstract values passed to trace_to_jaxpr_dynamic are in direct
  # correspondence to the input binders (invars) of the jaxpr it returns. But
  # in general the Tracers passed to the function f correspond only to a subset
  # of those abstract values. That's because axis size variables may not be
  # explicit arguments to f, while we make everything explicit in the jaxpr.
  keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs

  frame = JaxprStackFrame()
  with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
    trace = DynamicJaxprTrace(main, core.cur_sublevel())
    in_tracers = _input_type_to_tracers(trace, in_avals)
    in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
    ans = fun.call_wrapped(*in_tracers_)
    out_tracers = map(trace.full_raise, ans)
    jaxpr, consts = frame.to_jaxpr(out_tracers)
    del fun, main, trace, frame, in_tracers, out_tracers, ans
  if not config.jax_dynamic_shapes:
    # TODO(frostig,mattjj): check_jaxpr is incomplete under dynamic
    # shapes; remove this guard when it is
    config.jax_enable_checks and core.check_jaxpr(jaxpr)
  return jaxpr, [v.aval for v in jaxpr.outvars], consts

@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
  main.jaxpr_stack = main.jaxpr_stack + (frame,)
  try:
    yield
  finally:
    assert frame is main.jaxpr_stack[-1]
    main.jaxpr_stack = main.jaxpr_stack[:-1]

@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
                         in_avals: Sequence[AbstractValue],
                         debug_info: Optional[DebugInfo] = None,
                         keep_inputs: Optional[Sequence[bool]] = None):
  with core.new_base_main(DynamicJaxprTrace) as main:  # type: ignore
    main.debug_info = debug_info  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    with core.new_sublevel():
      jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
          fun, main, in_avals, keep_inputs=keep_inputs)
    del fun, main
  return jaxpr, out_avals, consts

def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
  # This function provides a partial evaluation behavior used by Flax. We can't
  # use trace_to_jaxpr directly because of an interaction with the current
  # custom_derivatives.py, which we work around by adding the EvalTrace.
  # TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
  with core.new_main(core.EvalTrace, dynamic=True) as _:  # type: ignore
    return trace_to_jaxpr(fun, in_pvals)

AbstractedAxisName = Hashable
AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], Tuple[AbstractedAxisName, ...]]

class DBIdx(NamedTuple):
  val: int

@dataclass(frozen=True)
class Bound:
  name: AbstractedAxisName
  bound: int

InputType = Tuple[Tuple[AbstractValue, bool], ...]

def infer_lambda_input_type(
    axes_specs: Optional[Sequence[AbstractedAxesSpec]],
    args: Sequence[Any]
  ) -> InputType:
  partial_specs = _canonicalize_specs(map(np.ndim, args), axes_specs)
  specs = _complete_specs(args, partial_specs)
  idxs, implicit_names = _collect_implicit(args, specs)
  implicit_inputs = [(_implicit_arg_type(n), False) for n in implicit_names]
  explicit_inputs = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)]
  return (*implicit_inputs, *explicit_inputs)

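# Example (illustrative, types shown approximately): abstracting the leading
# axis of one argument under a user-chosen name 'n'. The inferred input type
# lists implicit binders first:
#
#   import numpy as np
#   in_type = infer_lambda_input_type(({0: 'n'},), (np.zeros((3, 4)),))
#   # ~ ((i32[] aval for 'n', False), (f32[DBIdx(0), 4] aval, True))
#   # i.e. one implicit size argument, then the explicit array argument.
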
def _canonicalize_specs(
    ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
  ) -> List[Dict[int, AbstractedAxisName]]:
  if specs is None:
    return [{}] * len(ndims)
  else:
    return [{i: d for i, d in enumerate(s) if d is not None} if type(s) is tuple
            else s for n, s in zip(ndims, specs)]

def _complete_specs(
    args: Sequence[Any], partial_specs: List[Dict[int, AbstractedAxisName]]
  ) -> List[Dict[int, AbstractedAxisName]]:
  # Identify each user-supplied name in partial_specs with a size.
  sizes: Dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {}
  for x, spec in zip(args, partial_specs):
    for i, name in spec.items():
      d = sizes.setdefault(name, x.shape[i])
      if d is not x.shape[i] and d != x.shape[i]: raise TypeError
  # Introduce new names as needed for Tracers in shapes.
  named_tracers: Dict[TracerId, AbstractedAxisName] = {
      id(d): name for name, d in sizes.items() if isinstance(d, Tracer)}
  specs: List[Dict[int, AbstractedAxisName]] = []
  for x, spec in zip(args, partial_specs):
    if isinstance(get_aval(x), DShapedArray):
      spec = dict(spec)
      for i, d in enumerate(x.shape):
        if isinstance(d, Tracer):
          spec[i] = named_tracers.get(id(d), TracerAsName(d))
    specs.append(spec)
  assert all(not spec or not any(isinstance(d, Tracer) and i not in spec
                                 for i, d in enumerate(x.shape))
             for x, spec in zip(args, specs))
  return specs

def _collect_implicit(
    args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
  ) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractedAxisName]]:
  idxs: Dict[AbstractedAxisName, DBIdx] = {}
  explicit_tracers: Dict[TracerId, int] = {}
  counter = (DBIdx(i) for i in it.count())
  # Add implicit arguments to idxs.
  for explicit_idx, (x, spec) in enumerate(zip(args, specs)):
    for i, name in spec.items():
      if name not in idxs and id(x.shape[i]) not in explicit_tracers:
        idxs[name] = next(counter)
    if isinstance(x, Tracer):
      explicit_tracers[id(x)] = explicit_idx
  implicit_names: List[AbstractedAxisName] = list(idxs)

  # Now that we know the implicit args, add explicit args to idxs.
  offset = len(implicit_names)
  for x, spec in zip(args, specs):
    for i, name in spec.items():
      if id(x.shape[i]) in explicit_tracers:
        idxs[name] = DBIdx(offset + explicit_tracers[id(x.shape[i])])

  return idxs, implicit_names

def _implicit_arg_type(name: AbstractedAxisName) -> AbstractValue:
  if type(name) is Bound:
    return AbstractBInt(name.bound)
  else:
    return ShapedArray((), dtypes.dtype('int32'))

def _arg_type(
    idxs: Dict[AbstractedAxisName, DBIdx], x: Any,
    spec: Dict[int, AbstractedAxisName]
  ) -> AbstractValue:
  aval = get_aval(x)  # aval.shape could contain Tracers
  if not spec: return core.raise_to_shaped(aval)
  shape: List[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d
                                    for i, d in enumerate(aval.shape)]
  assert not any(isinstance(d, Tracer) for d in shape)
  return DShapedArray(tuple(shape), aval.dtype, False)

class TracerAsName:
  tracer: DynamicJaxprTracer
  def __init__(self, tracer):
    trace = core.thread_local_state.trace_state.trace_stack.dynamic
    self.tracer = trace.with_cur_sublevel().full_raise(tracer)
  def __eq__(self, other):
    return isinstance(other, TracerAsName) and self.tracer is other.tracer
  def __hash__(self):
    return id(self.tracer)

def _extract_implicit_args(
    trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]],
    explicit_tracers: Sequence[DynamicJaxprTracer]
  ) -> Sequence[DynamicJaxprTracer]:
  # First, construct a list to represent the full argument list, leaving the
  # implicit arguments as Nones for now.
  explicit_tracers_ = iter(explicit_tracers)
  tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type]
  assert next(explicit_tracers_, None) is None
  del explicit_tracers_

  # Next, populate the implicit arguments using DBIdxs in in_type.
  for i, (aval, explicit) in enumerate(in_type):
    if not explicit or not isinstance(aval, DShapedArray):
      continue  # can't populate an implicit argument
    tracer = tracers[i]
    assert tracer is not None
    for d1, d2 in zip(aval.shape, tracer.aval.shape):
      if isinstance(d1, DBIdx):
        if tracers[d1.val] is None:
          tracers[d1.val] = trace.instantiate_const(d2)
        assert tracers[d1.val] is trace.instantiate_const(d2)
  assert all(t is not None for t in tracers)
  return [t for t, (_, e) in zip(tracers, in_type) if not e]

def _in_avals_from_tracers(
    tracers: List[DynamicJaxprTracer]
  ) -> List[AbstractValue]:
  # Returned AbstractValues contain DBIdx indices. Uses Tracer obj id as name.
  dbidxs: Dict[TracerId, DBIdx] = {id(t): DBIdx(i) for i, t in enumerate(tracers)}
  in_avals: List[AbstractValue] = []
  for t in tracers:
    a = t.aval
    if isinstance(a, DShapedArray) and any(isinstance(d, Tracer) for d in a.shape):
      shape = [dbidxs[id(d)] if isinstance(d, Tracer) else d for d in a.shape]
      a = a.update(shape=tuple(shape))
    in_avals.append(a)
  return in_avals

def _input_type_to_tracers(
    trace: DynamicJaxprTrace, in_avals: Sequence[AbstractValue]
  ) -> Sequence[Tracer]:
  # Create input Tracers given input AbstractValues, each of which can contain
  # DeBruijn indices which refer to positions in the input argument list. That
  # is, each element `a` of `in_avals` can have DBIdx instances in its shape,
  # which must refer to positions left of `a`'s.
  in_tracers: List[Tracer] = []

  def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue:
    if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape):
      shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape]  # type: ignore
      return a.update(shape=tuple(shape))
    return a

  for a in in_avals:
    in_tracers.append(trace.new_arg(_substitute_tracers_in_aval(a)))
  return in_tracers

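# Example (illustrative): given avals [i32[], f32[DBIdx(0)]], the first
# `new_arg` call creates a size tracer, and the second aval's DBIdx(0) is
# replaced by that tracer before the second argument tracer is created, so
# the second tracer's aval carries the first tracer in its shape.
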
def _substitute_vars_in_type(
    consts: Dict[Var, Literal], env: Dict[Var, Var], a: AbstractValue
  ) -> AbstractValue:
  if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape):
    shape = [consts[d].val if d in consts else env[d]  # type: ignore
             if isinstance(d, Var) else d for d in a.shape]
    return a.update(shape=tuple(shape))
  else:
    return a

def _substitute_tracers_in_type(
    env: Dict[Var, Tracer], a: AbstractValue
  ) -> AbstractValue:
  # Substitutes variables into a given AbstractValue using given environment.
  # That is, the input is an AbstractValue possibly containing Vars, and the
  # output is an aval possibly containing Tracers.
  if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape):
    shape = [env[d] if isinstance(d, Var) else d for d in a.shape]
    return a.update(shape=tuple(shape))
  else:
    return a

Const = Any
Val = Any

def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
              ) -> Tuple[Jaxpr, List[Const]]:
  bounds = {v: v.aval.bound for v in jaxpr.invars
            if type(v.aval) is AbstractBInt}
  idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)}

  def substitute(aval: AbstractValue) -> AbstractValue:
    if isinstance(aval, AbstractBInt):
      return ShapedArray((), np.dtype('int32'))
    elif isinstance(aval, DShapedArray):
      shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape]  # type: ignore
      typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray
      return typ(tuple(shape), aval.dtype, aval.weak_type)
    else:
      return aval

  in_avals = [substitute(v.aval) for v in jaxpr.invars]
  eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
  padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
  return padded_jaxpr, padded_consts

class BoundedAxisSize(NamedTuple):
  val: Union[int, DynamicJaxprTracer]
  bound: int

def _eval_jaxpr_padded(
    jaxpr: Jaxpr, consts: List[Const], *args: DynamicJaxprTracer
  ) -> List[Union[Const, DynamicJaxprTracer]]:
  env: Dict[Var, Val] = {}

  def read(x):
    return x.val if type(x) is Literal else env[x]

  def write(v, val) -> None:
    env[v] = val

  write(core.unitvar, core.unit)
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    rule = padding_rules[eqn.primitive]
    in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
    out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
    outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
    map(write, eqn.outvars, outs)
  return map(read, jaxpr.outvars)

def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
  if isinstance(aval, DShapedArray):
    shp = [BoundedAxisSize(env[d], d.aval.bound) if type(d) is Var and
           type(d.aval) is AbstractBInt else env.get(d, d) for d in aval.shape]
    return DShapedArray(tuple(shp), aval.dtype, aval.weak_type)
  else:
    return aval

padding_rules: Dict[Primitive, Callable] = {}

def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
  if call_jaxpr.constvars: raise NotImplementedError
  padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ())
  if padded_consts: raise NotImplementedError
  new_params = dict(params, call_jaxpr=padded_jaxpr)
  subfuns, bind_params = prim.get_bind_params(new_params)
  return prim.bind(*subfuns, *args, **bind_params)

# TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
                      pvals: Sequence[PartialVal]):
  assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
  trace = main.with_cur_sublevel()
  in_tracers = map(trace.new_arg, pvals)
  ans = yield in_tracers, {}
  assert isinstance(ans, (list, tuple)), (
      f"Got unexpected return type when tracing function to jaxpr: {ans}")
  assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), (
      f"Got unexpected return type when tracing function to jaxpr: {ans}")
  instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  del trace, in_tracers, out_tracers
  yield jaxpr, (out_pvals, consts, env)

@lu.transformation_with_aux
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
  py_args = map(PartialVal, zip(pvs, consts))
  jaxpr, (out_pvals, consts, env) = yield (py_args,), {}
  out_pvs, out_consts = unzip2(out_pvals)
  out = tuple(out_consts) + tuple(consts)
  yield out, (out_pvs, jaxpr, env)

def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
  instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
  return _partial_eval_jaxpr(jaxpr, tuple(unknowns), instantiate)

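# Example (illustrative sketch): splitting a two-input closed jaxpr where only
# the second input is unknown:
#
#   jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
#       closed_jaxpr, unknowns=[False, True], instantiate=False)
#
# `jaxpr_known` computes whatever depends only on the known input and emits
# residuals as extra outputs; `jaxpr_unknown` takes [original inputs...,
# residuals] (with known slots typed as unit placeholders in this deprecated
# unit-based version) and produces the outputs flagged in `out_unknowns`.
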
@weakref_lru_cache
def _partial_eval_jaxpr(jaxpr, unknowns, instantiate):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
             for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known
  # PartialVal for the known inputs.
  in_avals = [core.abstract_unit if uk else a
              for a, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = core.abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
  out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
  # out_avals_1 and in_avals_2 need the residuals added
  res_avals = out_avals[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res
  out_avals_1 = [*out_avals_1, *res_avals]
  in_avals_2 = [*in_avals_2, *res_avals]

  return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out

@weakref_lru_cache
def _drop_vars(jaxpr: Jaxpr, drop_ins: Tuple[bool, ...], drop_outs: Tuple[bool, ...]):
  return Jaxpr(jaxpr.constvars,
               [v for v, d in zip(jaxpr.invars, drop_ins) if not d],
               [v for v, d in zip(jaxpr.outvars, drop_outs) if not d],
               jaxpr.eqns, jaxpr.effects)

@weakref_lru_cache
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
  # This dead-code elimination is pretty rudimentary, and in particular doesn't
  # nontrivially DCE through scan, call, or other higher-order primitives.
  # TODO(mattjj): better DCE (i.e. use above dce_jaxpr)
  if drop_outputs:
    new_outvars = [var for var, output in zip(jaxpr.outvars, outputs) if output]
  else:
    new_outvars = [var if output else core.unitvar
                   for var, output in zip(jaxpr.outvars, outputs)]

  needed_vars = {v for v in new_outvars if type(v) is not Literal}
  new_eqns = []
  for eqn in jaxpr.eqns[::-1]:
    if set(eqn.outvars) & needed_vars:
      new_eqns.append(eqn)
      needed_vars.update(v for v in eqn.invars if type(v) is not Literal)
  new_eqns = new_eqns[::-1]
  return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns,
               jaxpr.effects)