# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools as it
from collections import namedtuple
import contextlib
import functools
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
                    List, Union, cast, Type, no_type_check)
from weakref import ref

import numpy as np

from .. import core
from .. import dtypes
from .. import linear_util as lu
from ..ad_util import Zero
from .._src.util import (unzip2, safe_zip, safe_map, toposort, partial,
                         split_list, cache, as_hashable_function)
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
                    unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
                    dropvar, ConcreteArray, raise_to_shaped)
from jax._src import source_info_util
from ..config import config

map = safe_map
zip = safe_zip
def identity(x): return x


class PartialVal(tuple):
  """Partial value: either a known value or an unknown (abstract) value.

  Represented as a pair `(aval_opt, const)` of one of two kinds:
  * `(None, <Constant>)` indicates a known value, either a regular Python
    value or a Tracer.
  * `(<AbstractValue>, *)` indicates an unknown value characterized by an
    abstract value.
  """

  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    pv, const = xs
    if not core.skip_checks:
      # type checks
      assert isinstance(pv, (AbstractValue, type(None))), xs
      assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
      # invariant checks
      if isinstance(pv, AbstractValue):
        assert get_aval(const) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    return PartialVal((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    return PartialVal((aval, core.unit))

  def is_known(self) -> bool:
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Get the known value, if known, else None."""
    return self[1] if self[0] is None else None

  def get_aval(self) -> AbstractValue:
    """Get AbstractValue directly (if unknown) or from the constant (known)."""
    known = self.get_known()
    if known is not None:
      return get_aval(known)
    else:
      return self[0]

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Either the stored known value, or the given 'val'."""
    known = self.get_known()
    return known if known is not None else val
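
# Illustrative usage sketch (not executed here; the constant and aval below are
# hypothetical examples): a PartialVal is "known" when its first slot is None
# and "unknown" when it carries an AbstractValue.
#
#   pv = PartialVal.known(3.0)
#   pv.is_known()      # True
#   pv.get_known()     # 3.0
#   pv.get_aval()      # concrete aval derived from the constant via get_aval
#
#   pv = PartialVal.unknown(core.ShapedArray((), np.float32))
#   pv.is_known()      # False
#   pv.get_known()     # None
#   pv.get_aval()      # ShapedArray(float32[])
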
class JaxprTrace(Trace):
  def pure(self, val) -> 'JaxprTracer':
    return self.new_const(val)

  def lift(self, val) -> 'JaxprTracer':
    return self.new_const(val)

  def sublift(self, val) -> 'JaxprTracer':
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val) -> 'JaxprTracer':
    if isinstance(val, Tracer) and val._trace.level == self.level:
      raise Exception
    return JaxprTracer(self, PartialVal.known(val), unit)

  def new_instantiated_literal(self, val) -> 'JaxprTracer':
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), Literal(val))

  def new_instantiated_const(self, val) -> 'JaxprTracer':
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))

  def new_arg(self, pval: PartialVal) -> 'JaxprTracer':
    const = pval.get_known()
    if const is None:
      return JaxprTracer(self, pval, LambdaBinding())
    else:
      return self.new_const(const)

  def instantiate_const(self, tracer) -> Tracer:
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      if type(const) in core.literalable_types and np.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)
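
  # Sketch of the distinction above (hypothetical tracer names): for a known
  # scalar tracer `t_scalar` carrying 1.0, instantiate_const returns a tracer
  # whose recipe is Literal(1.0); for a known array tracer `t_array` carrying
  # np.ones(3), it returns one whose recipe is ConstVar(np.ones(3)).
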
  def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))

  def process_primitive(self, primitive, tracers, params):
    if primitive in custom_partial_eval_rules:
      return custom_partial_eval_rules[primitive](self, *tracers, **params)
    else:
      return self.default_process_primitive(primitive, tracers, params)

  def default_process_primitive(self, primitive, tracers, params):
    """By default, if all the input tracers are known, then execute the primitive
    and all the outputs are known. Otherwise, all the outputs are unknown."""
    consts = [t.pval.get_known() for t in tracers]
    if all(c is not None for c in consts):
      return primitive.bind(*consts, **params)
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval = primitive.abstract_eval(*avals, **params)
    source = source_info_util.current()
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in out_aval]
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    else:
      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
                                         params, source)
      return out_tracer
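
  # Illustrative sketch of the rule above (hypothetical names: `trace` is a
  # JaxprTrace instance, `aval` some abstract value, and jax.lax.add_p a
  # primitive): all-known inputs constant-fold eagerly via bind, while any
  # unknown input stages the primitive out as an equation on unknown tracers.
  #
  #   x = trace.new_const(1.0)                                 # known
  #   y = trace.new_arg(PartialVal.unknown(aval))              # unknown
  #   trace.default_process_primitive(lax.add_p, (x, x), {})   # folds to 2.0
  #   trace.default_process_primitive(lax.add_p, (x, y), {})   # staged-out eqn
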
  # We use process_call to handle both call and map primitives.
  def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
    if not config.omnistaging_enabled:
      if (self.main.trace_type is StagingJaxprTrace  # type: ignore
          and primitive in staged_out_calls):        # type: ignore
        tracers = map(self.instantiate_const_abstracted, tracers)

    if primitive in call_partial_eval_rules:
      return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

    in_pvals = [t.pval for t in tracers]
    if primitive.map_primitive:
      mapped_aval = partial(core.mapped_aval, params['axis_size'])
      in_pvals = [pval if pval.is_known() or in_axis is None
                  else PartialVal.unknown(mapped_aval(in_axis, pval[0]))
                  for pval, in_axis in zip(in_pvals, params['in_axes'])]
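      # Sketch of the shape bookkeeping above: for a hypothetical map with
      # axis_size=4, an unknown input pval with aval f32[4,3] and in_axis=0 is
      # rewritten to an unknown pval with the per-replica aval f32[3]; known
      # inputs and inputs with in_axis=None are left untouched.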

      def app(f, *args):
        f, num_outputs = count_outputs(f)
        out_axes_thunk = params['out_axes_thunk']
        @as_hashable_function(closure=out_axes_thunk)
        def new_out_axes_thunk():
          out_axes = out_axes_thunk()
          return out_axes + (0,) * (num_outputs() - len(out_axes))
        pe_params = dict(params, out_axes_thunk=new_out_axes_thunk)
        return primitive.bind(f, *args, **pe_params)
    else:
      app = partial(primitive.bind, **params)
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
        f, in_pvals, app, instantiate=False)
    if primitive.map_primitive:
      unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
      out_axes = params['out_axes_thunk']()
      out_pvals = [pval if pval.is_known() else
                   PartialVal.unknown(unmapped_aval(out_axis, pval[0])) if out_axis is not None else
                   PartialVal.unknown(pval[0])
                   for pval, out_axis in zip(out_pvals, out_axes)]

    # Skip known invars and outvars, and lift constants as regular invars
    in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
    out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
    jaxpr = _drop_invars(jaxpr, in_knowns)
    jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)

    # Known tracers get propagated as if they were constants
    known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
                         if pval.is_known()]

    # Unknown tracers need to have the jaxpr set up as their recipe
    unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
                           if not pval.is_known()]
    unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)

    # Set up new params
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    if primitive.map_primitive:
      in_axes = params['in_axes']
      # NOTE: const_tracers are added as map outputs, and we always map them
      #       along axis 0 (see `new_out_axes_thunk` above).
      new_in_axes = ((0,) * len(const_tracers) +
                     (None,) * len(env_tracers) +
                     tuple(axis for axis, t in zip(in_axes, tracers)
                           if not t.pval.is_known()))
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of the `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses them
assumes that the output pytree depends deterministically on the identity
of the wrapped function, which I think is in line with general JAX
assumptions. Otherwise the cache would depend on the identity of the
thunk, which changes with every function invocation. Relaxing the global
constraint on the cache (e.g. allowing each `pmap(f)` instance to have a
separate cache) would make this easier too.
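For illustration, a hedged sketch of the kind of helper this refers to (not
the actual implementation): a wrapper whose `hash`/`==` ignore the thunk's
own identity and depend only on an explicit key, for example the user
function that is assumed to determine the output pytree.
```py
class HashableThunk:
  """Sketch: a callable that hashes/compares by an explicit key, not identity."""
  def __init__(self, thunk, key):
    self._thunk = thunk
    self._key = key  # e.g. the wrapped user function (assumed to determine the output pytree)

  def __call__(self):
    return self._thunk()

  def __hash__(self):
    return hash(self._key)

  def __eq__(self, other):
    return isinstance(other, HashableThunk) and self._key == other._key
```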
* Why final style? *
Now, making the primitives initial-style would remove the need for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure what
it is. I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
|
|
|
|
if not pval.is_known())
|
|
|
|
new_params = dict(new_params, in_axes=new_in_axes, out_axes=new_out_axes)
|
|
|
|
del new_params['out_axes_thunk']
|
2020-06-23 09:39:45 -07:00
|
|
|
update_params = call_param_updaters.get(primitive)
|
|
|
|
if update_params:
|
|
|
|
new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])
|
|
|
|
|
|
|
|
eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
|
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.
Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
...: z = jax.numpy.cos(x)
...: z = z * jax.numpy.tanh(y)
...: return z + 2
...:
In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
in (f, j) }
In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15
ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
%constant.3 = pred[] constant(false)
%parameter.1 = f32[] parameter(0)
%cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%parameter.2 = f32[] parameter(1)
%tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 19:35:36 -04:00
|
|
|
source_info_util.current())
|
2020-06-23 09:39:45 -07:00
|
|
|
for t in unknown_tracers_out: t.recipe = eqn
|
2020-06-12 15:03:26 +02:00
|
|
|
return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
process_map = process_call
|
|
|
|
|
|
|
|
# We use post_process_call to handle both call and map primitives.
|
|
|
|
def post_process_call(self, primitive, out_tracers, params):
|
2020-04-21 18:12:02 -07:00
|
|
|
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
|
|
|
|
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
|
|
|
|
out = out_pv_consts + consts
|
2020-11-09 17:23:16 +00:00
|
|
|
nconsts = len(consts)
|
2020-04-21 18:12:02 -07:00
|
|
|
del consts, out_pv_consts
|
2020-08-30 12:38:14 +03:00
|
|
|
main = self.main
|
2020-04-21 18:12:02 -07:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
if primitive.map_primitive:
|
2020-11-09 17:23:16 +00:00
|
|
|
out_axes = params['out_axes_thunk']()
|
2020-06-23 09:39:45 -07:00
|
|
|
sz = params['axis_size']
|
2020-11-09 17:23:16 +00:00
|
|
|
out_pvs = [None if pv is None else core.unmapped_aval(sz, ax, pv)
|
|
|
|
for pv, ax in zip(out_pvs, out_axes)]
|
2019-07-27 15:46:14 -07:00
|
|
|
|
2019-09-20 07:01:01 -07:00
|
|
|
def todo(x):
|
|
|
|
n = len(jaxpr.outvars)
|
|
|
|
out_pv_consts, consts = x[:n], x[n:]
|
2020-08-30 12:38:14 +03:00
|
|
|
trace = JaxprTrace(main, core.cur_sublevel())
|
2019-09-20 07:01:01 -07:00
|
|
|
const_tracers = map(trace.new_instantiated_const, consts)
|
|
|
|
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
|
|
|
|
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
|
2020-06-23 09:39:45 -07:00
|
|
|
in_tracers = (*const_tracers, *map(trace.full_raise, env))
|
|
|
|
|
|
|
|
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
|
|
|
if primitive.map_primitive:
|
2020-11-09 17:23:16 +00:00
|
|
|
# NOTE: We've assigned axis 0 to const tracers below, in out_axes_transform.
|
2020-11-05 11:54:05 +00:00
|
|
|
new_in_axes = (0,) * len(const_tracers) + (None,) * len(env)
|
2020-11-09 17:23:16 +00:00
|
|
|
new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes)
|
|
|
|
del new_params['out_axes_thunk']
|
2020-06-23 09:39:45 -07:00
|
|
|
update_params = call_param_updaters.get(primitive)
|
|
|
|
if update_params:
|
|
|
|
new_params = update_params(new_params, [])
|
|
|
|
|
|
|
|
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
|
2020-06-17 19:35:36 -04:00
|
|
|
source_info_util.current())
|
2019-09-20 07:01:01 -07:00
|
|
|
for t in out_tracers:
|
|
|
|
t.recipe = eqn
|
|
|
|
return out_tracers
|
2020-11-09 17:23:16 +00:00
|
|
|
|
|
|
|
if primitive.map_primitive:
|
|
|
|
def out_axes_transform(out_axes):
|
|
|
|
return out_axes + (0,) * nconsts
|
|
|
|
todo = (todo, out_axes_transform)
|
|
|
|
|
2019-09-20 07:01:01 -07:00
|
|
|
return out, todo
|
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
post_process_map = post_process_call
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
|
2020-10-16 00:21:04 -07:00
|
|
|
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
|
|
|
|
instantiate: bool):
|
2020-07-30 12:59:36 -07:00
|
|
|
"""Partially evaluate f on a sequence of PartialVals."""
|
|
|
|
in_avals, in_consts = unzip2(pvals)
|
2020-10-16 00:21:04 -07:00
|
|
|
f = trace_to_subjaxpr(f, self.main, instantiate)
|
2020-07-30 12:59:36 -07:00
|
|
|
f, aux = partial_eval_wrapper(f, tuple(in_avals))
|
|
|
|
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
|
|
|
|
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
|
|
|
out_pvs = map(PartialVal, zip(out_avals, out_consts))
|
|
|
|
env_tracers = map(self.full_raise, env)
|
|
|
|
return jaxpr, out_pvs, consts, env_tracers
|
|
|
|
|
2020-10-16 00:21:04 -07:00
|
|
|
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
|
|
|
tracers = map(self.instantiate_const_abstracted, tracers)
|
|
|
|
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
|
|
|
|
fun = trace_to_subjaxpr(fun, self.main, True)
|
|
|
|
fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
|
|
|
|
out_flat = prim.bind(fun, jvp, *in_consts)
|
|
|
|
out_avals, jaxpr, env = aux()
|
|
|
|
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
|
|
|
out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units
|
|
|
|
env_tracers = map(self.full_raise, env)
|
|
|
|
out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
|
|
|
|
const_tracers = map(self.new_instantiated_const, consts)
|
|
|
|
in_tracers = (*const_tracers, *env_tracers, *tracers)
|
|
|
|
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
|
|
|
|
|
|
|
|
@_memoize
|
|
|
|
def jvp_jaxpr_thunk():
|
|
|
|
jvp_ = trace_to_subjaxpr(jvp, self.main, True)
|
|
|
|
jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
|
|
|
|
out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units
|
|
|
|
out_avals, jaxpr, env = aux()
|
|
|
|
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
|
|
|
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
|
|
|
|
return converted_jaxpr, (*consts, *env)
|
|
|
|
|
|
|
|
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
|
|
|
|
dict(fun_jaxpr=closed_jaxpr,
|
|
|
|
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
|
|
|
|
num_consts=len(consts) + len(env)),
|
|
|
|
source_info_util.current())
|
|
|
|
for t in out_tracers: t.recipe = eqn
|
|
|
|
return out_tracers
|
|
|
|
|
|
|
|
def post_process_custom_jvp_call(self, out_tracers, params):
|
|
|
|
# This path should only be reachable if we expose a partial eval API
|
|
|
|
# unrelated to autodiff, since we raise an error when differentiation with
|
|
|
|
# respect to values over which a custom_jvp function closes is detected.
|
|
|
|
raise NotImplementedError # TODO(mattjj)
|
|
|
|
|
|
|
|
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
|
|
|
tracers = map(self.instantiate_const_abstracted, tracers)
|
|
|
|
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
|
|
|
|
fun = trace_to_subjaxpr(fun, self.main, True)
|
|
|
|
fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
|
|
|
|
out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
|
|
|
|
out_avals, jaxpr, env = aux()
|
|
|
|
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
|
|
|
out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units
|
|
|
|
env_tracers = map(self.full_raise, env)
|
|
|
|
out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
|
|
|
|
const_tracers = map(self.new_instantiated_const, consts)
|
|
|
|
in_tracers = (*const_tracers, *env_tracers, *tracers)
|
|
|
|
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
|
|
|
|
|
|
|
|
@_memoize
|
|
|
|
def fwd_jaxpr_thunk():
|
|
|
|
fwd_ = trace_to_subjaxpr(fwd, self.main, True)
|
|
|
|
fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
|
|
|
|
out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units
|
|
|
|
out_avals, jaxpr, env = aux()
|
|
|
|
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
|
|
|
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
|
|
|
|
return converted_jaxpr, (*consts, *env)
|
|
|
|
|
|
|
|
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
|
|
|
|
dict(fun_jaxpr=closed_jaxpr,
|
|
|
|
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
|
|
|
|
num_consts=len(consts) + len(env),
|
|
|
|
bwd=bwd, out_trees=out_trees),
|
|
|
|
source_info_util.current())
|
|
|
|
for t in out_tracers: t.recipe = eqn
|
|
|
|
return out_tracers
|
|
|
|
|
|
|
|
def post_process_custom_vjp_call(self, out_tracers, params):
|
|
|
|
# This path should only be reachable if we expose a partial eval API
|
|
|
|
# unrelated to autodiff, since we raise an error when differentiation with
|
|
|
|
# respect to values over which a custom_vjp function closes is detected.
|
|
|
|
raise NotImplementedError # TODO(mattjj)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
|
2020-06-23 09:39:45 -07:00
|
|
|
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
|
|
|
|
py_args = map(PartialVal, zip(pvs, consts))
|
|
|
|
jaxpr, (out_pvals, consts, env) = yield (py_args,), {}
|
2019-07-26 16:48:17 -04:00
|
|
|
out_pvs, out_consts = unzip2(out_pvals)
|
2020-06-23 09:39:45 -07:00
|
|
|
out = tuple(out_consts) + tuple(consts)
|
2019-07-26 16:48:17 -04:00
|
|
|
yield out, (out_pvs, jaxpr, env)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-11-09 17:23:16 +00:00
|
|
|
@lu.transformation_with_aux
|
|
|
|
def count_outputs(*args, **kwargs):
|
|
|
|
ans = yield args, kwargs
|
|
|
|
yield ans, len(ans)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-05-28 17:39:13 +02:00
|
|
|
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
|
|
|
|
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
|
2020-06-23 09:39:45 -07:00
|
|
|
call_param_updaters: Dict[core.Primitive, Callable] = {}
|
2020-05-28 17:39:13 +02:00
|
|
|
|
|
|
|
|
2019-02-13 14:28:30 -08:00
|
|
|
def abstract_eval_fun(fun, *avals, **params):
|
2020-07-30 12:59:36 -07:00
|
|
|
if config.omnistaging_enabled:
|
2020-08-25 05:38:41 -07:00
|
|
|
_, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
|
2020-07-30 12:59:36 -07:00
|
|
|
else:
|
2020-08-25 05:38:41 -07:00
|
|
|
pvals_in = [PartialVal.unknown(a) for a in avals]
|
2020-07-30 12:59:36 -07:00
|
|
|
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
|
2020-09-15 08:06:46 -07:00
|
|
|
instantiate=True, stage_out=True) # type: ignore
|
2020-08-25 05:38:41 -07:00
|
|
|
avals_out, _ = unzip2(pvals_out)
|
2019-07-26 16:48:17 -04:00
|
|
|
for aval_out in avals_out:
|
|
|
|
assert isinstance(aval_out, AbstractValue) # instantiate=True
|
|
|
|
return avals_out
|
2019-02-13 14:28:30 -08:00
|
|
|
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar',
|
|
|
|
'ConstVar', Literal, core.Unit]
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTracer(Tracer):
|
2019-01-16 16:51:54 +00:00
|
|
|
__slots__ = ['pval', 'recipe']
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def __init__(self, trace: JaxprTrace, pval: PartialVal,
|
|
|
|
recipe: Optional[JaxprTracerRecipe]):
|
2018-11-17 18:03:33 -08:00
|
|
|
assert isinstance(pval, PartialVal)
|
|
|
|
pv, const = pval
|
2020-02-15 06:35:49 +01:00
|
|
|
if isinstance(const, Tracer) and const._trace.level >= trace.level:
|
2020-03-28 14:55:58 -07:00
|
|
|
raise core.escaped_tracer_error(
|
|
|
|
"Tracer from a higher level: {} in trace {}".format(const, trace))
|
2020-01-29 16:23:27 -05:00
|
|
|
self._trace = trace
|
2018-11-17 18:03:33 -08:00
|
|
|
self.pval = pval
|
|
|
|
self.recipe = recipe
|
|
|
|
|
|
|
|
def __repr__(self):
|
2020-01-29 16:23:27 -05:00
|
|
|
return 'Traced<{}:{}>'.format(self.aval, self._trace)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@property
|
2020-06-01 21:45:36 -04:00
|
|
|
def aval(self) -> AbstractValue:
|
2020-03-18 07:11:44 +01:00
|
|
|
return self.pval.get_aval()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@property
|
2020-06-01 21:45:36 -04:00
|
|
|
def parents(self) -> Sequence['JaxprTracer']:
|
2019-11-19 12:26:30 -08:00
|
|
|
if isinstance(self.recipe, JaxprEqnRecipe):
|
2020-02-03 20:58:56 +01:00
|
|
|
return self.recipe.invars
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
|
|
|
return []
|
|
|
|
|
|
|
|
def full_lower(self):
|
2020-03-18 07:11:44 +01:00
|
|
|
known = self.pval.get_known()
|
|
|
|
if known is not None:
|
|
|
|
return core.full_lower(known)
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
|
|
|
return self
|
|
|
|
|
2020-06-12 15:03:26 +02:00
|
|
|
def is_known(self):
|
|
|
|
return self.pval.is_known()
|
|
|
|
|
2020-09-18 10:07:13 -07:00
|
|
|
# TODO(necula): this could return a ClosedJaxpr with out_pvals
|
2020-03-21 13:54:30 +01:00
|
|
|
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
2020-09-15 08:06:46 -07:00
|
|
|
instantiate: Union[bool, Sequence[bool]] = False,
|
|
|
|
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
2020-03-18 07:11:44 +01:00
|
|
|
"""Traces a function into a Jaxpr, given PartialVals for inputs.
|
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
|
|
|
|
computation that depends on unknown inputs. The `out_pvals` are the PartialVals
|
|
|
|
for the outputs. The intermediate values that depend only on known inputs and
|
|
|
|
are needed to compute the output of `jaxpr` are in `consts` and are passed in
|
|
|
|
as the constvars of the `jaxpr`. The handling of the known outputs depends on
|
|
|
|
`instantiate`.
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
For example, given `fun` defined as follows::
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def fun(ki, ui): # ki will be a known input in this example
|
|
|
|
ka = ki + 2
|
|
|
|
kb = ka + 3
|
|
|
|
return (kb, ui + ka)
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
|
|
|
|
computation that depends on unknown inputs is `ui + ka` and will be the only
|
2020-06-23 09:39:45 -07:00
|
|
|
computation in the body of the `jaxpr`. This computation depends on the known
|
|
|
|
intermediate value `ka`, which will be computed statically. Currently, such
|
|
|
|
constants are either embedded in the Jaxpr if they are scalars, or passed as a
|
|
|
|
constvar to `jaxpr`, and then the value of the actual constant will be in
|
|
|
|
`consts`:
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
When `instantiate=False` we get::
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr =
|
2020-03-18 07:11:44 +01:00
|
|
|
{ lambda ka ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
in (*, c) } # known outputs are `*`
|
2020-07-30 12:59:36 -07:00
|
|
|
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3] # the constant for `ka`
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
When `instantiate=True` we get::
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr =
|
2020-03-18 07:11:44 +01:00
|
|
|
{ lambda ka kb ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
in (kb, c) } # known output are explicit
|
2020-07-30 12:59:36 -07:00
|
|
|
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3, 6] # values for `ka` and `kb` constvars
|
2020-03-18 07:11:44 +01:00
|
|
|
"""
|
2020-09-15 08:06:46 -07:00
|
|
|
with core.new_main(JaxprTrace) as main:
|
2020-08-30 12:38:14 +03:00
|
|
|
fun = trace_to_subjaxpr(fun, main, instantiate)
|
2019-07-26 16:48:17 -04:00
|
|
|
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
2018-11-17 18:03:33 -08:00
|
|
|
assert not env
|
2020-08-30 12:38:14 +03:00
|
|
|
del main
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-26 16:48:17 -04:00
|
|
|
return jaxpr, out_pvals, consts
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
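# A minimal sketch (not part of the library, never called) of driving
# trace_to_jaxpr directly, following the docstring example above. The scalar
# inputs and the use of raise_to_shaped/get_aval to build the unknown aval are
# assumptions for illustration; it relies on the jax package being fully
# imported so that tracer arithmetic is set up.
def _example_trace_to_jaxpr():
  def fun(ki, ui):  # ki is known, ui is unknown
    ka = ki + 2
    kb = ka + 3
    return (kb, ui + ka)

  in_pvals = [PartialVal.known(1.),
              PartialVal.unknown(raise_to_shaped(get_aval(1.)))]
  jaxpr, out_pvals, consts = trace_to_jaxpr(lu.wrap_init(fun), in_pvals)
  # With the default instantiate=False, the first output is known and only the
  # `ui + ka` computation remains in `jaxpr`, with `ka`'s value in `consts`.
  return jaxpr, out_pvals, consts

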
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
                      pvals: Sequence[PartialVal]):
  assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
  trace = JaxprTrace(main, core.cur_sublevel())
  in_tracers = map(trace.new_arg, pvals)
  ans = yield in_tracers, {}
  instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  del trace, in_tracers, out_tracers
  yield jaxpr, (out_pvals, consts, env)


def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
  if instantiate:
    return trace.instantiate_const(trace.full_raise(tracer))
  else:
    return tracer


FreeVar = namedtuple('FreeVar', ['val'])
ConstVar = namedtuple('ConstVar', ['val'])
LambdaBinding = namedtuple('LambdaBinding', [])


class JaxprEqnRecipe(NamedTuple):
  eqn_id: object
  invars: Sequence[JaxprTracer]
  outvars: 'Sequence[ref[JaxprTracer]]'
  primitive: core.Primitive
  params: Dict[str, Any]
  source_info: Optional[source_info_util.Traceback]


def new_eqn_recipe(invars: Sequence[JaxprTracer],
                   outvars: Sequence[JaxprTracer],
                   primitive: core.Primitive,
                   params: Dict[str, Any],
                   source_info: Optional[source_info_util.Traceback]
                  ) -> JaxprEqnRecipe:
  """Constructs a new JaxprEqnRecipe.

  Params:
    invars: the tracers for the primitive inputs.
    outvars: the tracers for the primitive outputs.
    primitive: the primitive.
    params: the primitive params.
  """
  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
  if primitive.map_primitive:
    assert ("in_axes" in params and
            len(params["in_axes"]) == len(params["call_jaxpr"].invars))
    assert ("donated_invars" in params and
            len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
  return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
                        params, source_info)


def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  _, in_tracers, out_tracer_refs, primitive, params, source_info = recipe
  out_tracers = [t_ref() for t_ref in out_tracer_refs]
  invars  = [getvar(t) for t in in_tracers]
  outvars = [core.dropvar if t is None else cast(core.Var, getvar(t))
             for t in out_tracers]
  return new_jaxpr_eqn(invars, outvars, primitive, params, source_info)


def tracers_to_jaxpr(
    in_tracers: List[JaxprTracer],
    out_tracers: List[JaxprTracer]
  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
  """Constructs Jaxpr given tracers for inputs and outputs.

  Params:
    in_tracers: the tracers that were created for the function inputs.
    out_tracers: the tracers that were output by the function.

  Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
    the `constvars` in the returned Jaxpr, and a list of environment values.
    The vars for the environment values have been prepended to the Jaxpr's
    `invars`.
  """
  newvar = core.gensym()
  t_to_var: Dict[int, core.Atom] = {}
  def getvar(t: JaxprTracer) -> core.Atom:
    var = t_to_var.get(id(t))
    if var is None:
      aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  sorted_tracers = toposort(out_tracers)
  invars = map(getvar, in_tracers)
  eqns: List[core.JaxprEqn] = []
  env: Dict[core.Var, Any] = {}
  consts: Dict[core.Var, Any] = {}
  const_to_var: Dict[int, core.Var] = {}
  def getconstvar(c):
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise core.escaped_tracer_error(
            "Tracer not among input tracers {}".format(t))
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[cast(core.Var, getvar(t))] = recipe.val
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals


@cache()
def convert_constvars_jaxpr(jaxpr: Jaxpr):
  """Moves the constvars to the start of invars."""
  core.skip_checks or core.check_jaxpr(jaxpr)
  lifted_jaxpr = Jaxpr(constvars=(),
                       invars=jaxpr.constvars + jaxpr.invars,
                       outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  core.skip_checks or core.check_jaxpr(lifted_jaxpr)
  return lifted_jaxpr


def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
  core.skip_checks or core.check_jaxpr(jaxpr)
  env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
  converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
                          invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  core.skip_checks or core.check_jaxpr(converted_jaxpr)
  return converted_jaxpr


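# Sketch only (not called by the library): lifting a jaxpr's constvars into
# ordinary invars with convert_constvars_jaxpr above. Assumes the public `jax`
# package is importable when run and that the closed-over NumPy array becomes
# a constvar rather than an inlined literal.
def _example_convert_constvars_jaxpr():
  import jax
  c = np.arange(3.0)
  closed = jax.make_jaxpr(lambda x: x + c)(np.zeros(3))
  lifted = convert_constvars_jaxpr(closed.jaxpr)
  # `lifted` has no constvars; the constant is now the leading invar and must
  # be supplied explicitly when the jaxpr is evaluated.
  return lifted

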
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
  return (abstract_unit, aval) if unknown else (aval, abstract_unit)


def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
  """Specializes a Jaxpr given an indication of which inputs are known.

  Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).

  `out_unknowns` specifies which outputs are unknown (depend on some unknown inputs).

  `jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs,
  and performs *all* the computation in `jaxpr` that depends only on the known inputs.
  Outputs correspond to those of `jaxpr`, with the unknown ones replaced with `*`,
  appended with the known residuals (the intermediate computations in `jaxpr`
  that depend only on known inputs and that are needed to compute the unknown outputs).

  `jaxpr_unknown` takes the same inputs as `jaxpr` along with the known residuals
  computed by `jaxpr_known` and returns the same outputs as `jaxpr` with the known
  outputs replaced by `*`.

  Roughly, `jaxpr(ki, ui)` is decomposed, assuming `ki` and `ui` are the known and
  unknown inputs respectively, into:

    jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(ki, *)
                    let _, uout = jaxpr_unknown(ki, ui, kresidual)
                    in (kout, uout)

  For example, if `jaxpr` is lambda ki, ui: let ka = ki + 2
                                            in (ki + 3, ui + ka)
  then
    `jaxpr_known` = lambda ki, ui: let ka = ki + 2
                                   in (ki + 3, *, ka)
    `jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)

  Note that if instantiate is True for a given output, then jaxpr_known always returns a
  unit in its place. So when instantiate is True, the expectation is that one does not
  run `jaxpr_known` for its outputs, but only to generate the residuals that make it
  possible to obtain the full outputs once `jaxpr_unknown` is run. Outputs known ahead
  of time will simply get passed as residual constants and returned immediately.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
             for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
  # known inputs.
  in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
  out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
  # out_avals_1 and in_avals_2 need the residuals added
  res_avals = out_avals[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res
  out_avals_1 = [*out_avals_1, *res_avals]
  in_avals_2 = [*in_avals_2, *res_avals]

  return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out


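# A minimal sketch (not called by the library) of exercising partial_eval_jaxpr,
# mirroring its docstring example. Assumes the public `jax` package is
# importable when run, that omnistaging is enabled (the default here), and
# that `jax.make_jaxpr` returns a ClosedJaxpr, as it does in this version.
def _example_partial_eval_jaxpr():
  import jax

  def f(ki, ui):
    ka = ki + 2
    return ki + 3, ui + ka

  closed_jaxpr = jax.make_jaxpr(f)(1., 1.)
  # Treat the first input as known and the second as unknown.
  jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
      closed_jaxpr, unknowns=[False, True], instantiate=False)
  # out_unknowns should be [False, True]: only the second output depends on ui.
  return jaxpr_known, jaxpr_unknown, out_unknowns

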
remat_call_p = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)


def _remat_partial_eval(trace, _, f, tracers, params):
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvals = [t.pval for t in instantiated_tracers]
  if config.omnistaging_enabled:
    jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
        f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
  else:
    with core.initial_style_staging():  # type: ignore
      jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
          f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)

  # Convert consts to inputs, since they may contain Tracer instances.
  jaxpr = convert_constvars_jaxpr(jaxpr)
  const_tracers = map(trace.new_instantiated_const, consts)

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])
  if config.omnistaging_enabled:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False)  # type: ignore
  else:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)  # type: ignore
  out_knowns = [not b for b in out_unknowns]
  out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_known,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of reverse-mode ad in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  to_compute = [type(pval[0]) is not ConcreteArray
                for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
  num_outputs = len(jaxpr_unknown.out_avals)
  num_res = len(jaxpr_known.out_avals) - num_outputs
  jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
  jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
  _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
  reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
  out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

  # Known outputs should keep propagating as constants
  assert all(pv.is_known() for pv in out_known_pvals)
  known_output_tracers = [trace.new_const(pval.get_known())
                          for pval in out_known_pvals]
  # Unknown outputs get wrapped in tracers with the appropriate recipe
  unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
                            for out_pval in out_unknown_pvals]

  # dce jaxpr outputs
  new_jaxpr = _dce_jaxpr(closed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
  new_params = dict(params, call_jaxpr=new_jaxpr)

  # set up eqn for unknown outputs
  in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
  eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p, new_params,
                       source_info_util.current())
  for t in unknown_output_tracers: t.recipe = eqn
  return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval


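# Small sketch (not called by the library) of the user-level path that reaches
# _remat_partial_eval: differentiating a jax.remat-wrapped function recomputes
# its primal values on the backward pass instead of storing them. Assumes
# `jax` and `jax.numpy` are importable when the sketch runs.
def _example_remat_grad():
  import jax
  import jax.numpy as jnp

  @jax.remat
  def g(x):
    return jnp.sin(jnp.sin(x))

  # The rule registered above partially evaluates the staged-out remat call,
  # separating its known outputs from the ones needed for the gradient.
  return jax.grad(g)(1.0)

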
def _partition_knowns(pvals, unknowns: Sequence[bool]):
  return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
          [e for e, unknown in zip(pvals, unknowns) if unknown])

def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
  known_iter, unknown_iter = iter(known_list), iter(unknown_list)
  return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]


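# Tiny illustration (not called by the library) of how the two helpers above
# compose: _partition_knowns splits a list by an unknowns mask, and _zip_knowns
# interleaves the two halves back in their original order.
def _example_partition_zip_knowns():
  vals = ['a', 'b', 'c', 'd']
  unknowns = [False, True, False, True]
  known, unknown = _partition_knowns(vals, unknowns)   # (['a', 'c'], ['b', 'd'])
  assert _zip_knowns(known, unknown, unknowns) == vals

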
def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJaxpr:
  new_jaxpr = _dce_open_jaxpr(closed_jaxpr.jaxpr, tuple(outputs), drop_outputs)
  return core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)


@cache()
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
  # This dead-code elimination is pretty rudimentary, and in particular doesn't
  # nontrivially DCE through scan, call, or other higher-order primitives.
  # TODO(mattjj): better DCE
  if drop_outputs:
    new_outvars = [var for var, output in zip(jaxpr.outvars, outputs) if output]
  else:
    new_outvars = [var if output else unitvar
                   for var, output in zip(jaxpr.outvars, outputs)]

  needed_vars = {v for v in new_outvars if type(v) is not Literal}
  new_eqns = []
  for eqn in jaxpr.eqns[::-1]:
    if set(eqn.outvars) & needed_vars:
      new_eqns.append(eqn)
      needed_vars.update(v for v in eqn.invars if type(v) is not Literal)
  new_eqns = new_eqns[::-1]
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars,
                    new_outvars, new_eqns)


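# A minimal sketch (not called by the library) of the DCE helpers above:
# dropping one output lets the rudimentary dead-code elimination discard the
# equation that only fed it. Assumes `jax` and `jax.numpy` are importable.
def _example_dce_jaxpr():
  import jax
  import jax.numpy as jnp

  def f(x):
    return jnp.sin(x), jnp.cos(x)

  closed = jax.make_jaxpr(f)(1.0)
  pruned = _dce_jaxpr(closed, [True, False], drop_outputs=True)
  # `pruned.jaxpr` keeps only the `sin` equation and a single output.
  return pruned

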
@cache()
def _drop_invars(jaxpr: Jaxpr, drop: Tuple[bool, ...]):
  return core.Jaxpr(jaxpr.constvars, [v for v, d in zip(jaxpr.invars, drop) if not d],
                    jaxpr.outvars, jaxpr.eqns)


def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
  pv1, _ = pval1
  if pval1.is_known():
    return pval1
  else:
    if type(pv1) is ConcreteArray:
      return PartialVal.known(pv1.val)  # pytype: disable=attribute-error
    else:
      return PartialVal.known(const2)


def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
  """Reorder the `invars` to move to front the ones for which `to_move` is True."""
  assert not closed_jaxpr.jaxpr.constvars
  assert len(closed_jaxpr.in_avals) == len(to_move)
  new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
  new_jaxpr = core.Jaxpr((), new_invars, closed_jaxpr.jaxpr.outvars,
                         closed_jaxpr.jaxpr.eqns)
  new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
  return new_closed_jaxpr


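# Sketch only (not called by the library): reordering the binders of a
# ClosedJaxpr with move_binders_to_front. Assumes the public `jax` package is
# importable when run.
def _example_move_binders_to_front():
  import jax
  closed = jax.make_jaxpr(lambda x, y: x - y)(1., 2.)
  # Move the second binder (y) to the front, so the invars become [y, x].
  return move_binders_to_front(closed, [False, True])

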
def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
  return ([elt for elt, move in zip(lst, to_move) if move] +
          [elt for elt, move in zip(lst, to_move) if not move])


class DynamicJaxprTracer(core.Tracer):
  __slots__ = ['aval', 'line_info']

  def __init__(self, trace, aval, line_info=None):
    self._trace = trace
    self.aval = aval
    self.line_info = line_info

  def full_lower(self):
    return self

  def _contents(self):
    return ()

  def _origin_msg(self):
    invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
    if invar_pos:
      origin = (f"While tracing the function {self._trace.main.source_info}, "
                "this concrete value was not available in Python because it "
                "depends on the value of the arguments to "
                f"{self._trace.main.source_info} at flattened positions {invar_pos}, "
                "and the computation of these values is being staged out "
                "(that is, delayed rather than executed eagerly).\n\n"
                "You can use transformation parameters such as `static_argnums` "
                "for `jit` to avoid tracing particular arguments of transformed "
                "functions, though at the cost of more recompiles.")
    elif progenitor_eqns:
      msts = [f"  operation {core.pp_eqn(eqn, print_shapes=True)}\n"
              f"    from line {source_info_util.summarize(eqn.source_info)}"
              for eqn in progenitor_eqns]
      origin = (f"While tracing the function {self._trace.main.source_info}, "
                "this value became a tracer due to JAX operations on these lines:"
                "\n\n" + "\n\n".join(msts))
    else:
      origin = ("The error occurred while tracing the function "
                f"{self._trace.main.source_info}.")
    return origin

  def _assert_live(self) -> None:
    if not self._trace.main.jaxpr_stack:  # type: ignore
      msg = f"tracer created on line {source_info_util.summarize(self.line_info)}"
      raise core.escaped_tracer_error(msg)


class JaxprStackFrame:
|
|
|
|
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
|
Improve a tracer error message
Previously, given this function:
```python
@jax.jit
def f(x,y):
if x > y:
return x
else:
return y
```
we'd get an error message like this (after #4038, improved to help with
omnistaging debugging):
```
...
While tracing the function f at tim.py:3, this value became a tracer due to JAX operations on these lines:
operation c:bool[] = gt a:int32[] b:int32[]
from line tim.py:5 (f)
...
```
But this message is buggy! In this case, the value is a tracer because
it has a data dependence on arguments to a jitted function.
After this change, we instead produce this error message:
```
...
While tracing the function f at tim.py:3, this concrete value was not available in Python because it depends on the value of the arguments to f at tim.py:3 at positions [0, 1], and the computation of these values is being staged out.
...
```
I'm eager to iterate with further improvements, but for now I want to
fix this buggy message.
2020-09-18 10:38:37 -07:00
|
|
|
'tracers', 'eqns', 'invars']
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
self.newvar = core.gensym()
|
|
|
|
self.tracer_to_var = {}
|
|
|
|
self.constid_to_var = {}
|
|
|
|
self.constvar_to_val = {}
|
2020-08-30 12:38:14 +03:00
|
|
|
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
|
|
|
|
self.eqns = [] # cleared when we pop frame from main
|
2020-09-18 10:38:37 -07:00
|
|
|
self.invars = []
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def to_jaxpr(self, in_tracers, out_tracers):
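    # Assemble the equations recorded on this frame into a Jaxpr: look up the
    # variables for the given in/out tracers, attach the collected constants,
    # and inline scalar constants as literals.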
|
|
|
|
invars = [self.tracer_to_var[id(t)] for t in in_tracers]
|
|
|
|
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
|
|
|
|
constvars, constvals = unzip2(self.constvar_to_val.items())
|
|
|
|
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
|
|
|
|
jaxpr, constvals = _inline_literals(jaxpr, constvals)
|
|
|
|
out_avals = [t.aval for t in out_tracers]
|
|
|
|
return jaxpr, out_avals, constvals
|
|
|
|
|
|
|
|
def find_progenitors(self, tracer):
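    # Walk the recorded equations backwards from this tracer's variable to find
    # (a) which frame input positions it depends on and (b) the equations that
    # consumed constants along the way.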
|
2020-09-15 08:06:46 -07:00
|
|
|
var = self.tracer_to_var.get(id(tracer))
|
|
|
|
if not var:
|
2020-10-20 16:10:56 -07:00
|
|
|
return None, None
|
2020-09-15 08:06:46 -07:00
|
|
|
active_vars = {var}
|
2020-07-30 12:59:36 -07:00
|
|
|
for eqn in self.eqns[::-1]:
|
|
|
|
produced = set(eqn.outvars) & active_vars
|
|
|
|
if produced:
|
|
|
|
active_vars.difference_update(produced)
|
|
|
|
active_vars.update(eqn.invars)
|
2020-09-18 10:38:37 -07:00
|
|
|
invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars]
|
|
|
|
constvars = active_vars & set(self.constvar_to_val)
|
|
|
|
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
|
|
|
|
return invar_positions, const_eqns
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def _inline_literals(jaxpr, constvals):
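  # Rewrite the jaxpr with fresh variables, turning scalar constants into
  # inline Literals and replacing unused equation outputs with dropvar.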
|
|
|
|
consts = dict(zip(jaxpr.constvars, constvals))
|
|
|
|
newvar = core.gensym()
|
|
|
|
class var(dict):
|
|
|
|
def __missing__(self, v):
|
|
|
|
new_v = self[v] = newvar(v.aval)
|
|
|
|
return new_v
|
|
|
|
var = var()
|
|
|
|
|
|
|
|
def lit(var: core.Var) -> Optional[Any]:
|
|
|
|
val = consts.get(var)
|
|
|
|
if type(val) in core.literalable_types and not np.shape(val):
|
|
|
|
return Literal(val)
|
|
|
|
else:
|
|
|
|
return None
|
|
|
|
|
|
|
|
used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
|
|
|
|
new_constvars = [var[v] for v in jaxpr.constvars if not lit(v)]
|
|
|
|
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
|
|
|
|
new_invars = [var[v] for v in jaxpr.invars]
|
|
|
|
new_eqns = [new_jaxpr_eqn([lit(v) or var[v] for v in eqn.invars],
|
|
|
|
[var[v] if v in used else dropvar for v in eqn.outvars],
|
|
|
|
eqn.primitive, eqn.params, eqn.source_info)
|
|
|
|
for eqn in jaxpr.eqns]
|
|
|
|
new_outvars = [lit(v) or var[v] for v in jaxpr.outvars]
|
|
|
|
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
|
|
|
|
return new_jaxpr, new_constvals
|
|
|
|
|
|
|
|
class DynamicJaxprTrace(core.Trace):
|
|
|
|
__slots__ = [] # type: ignore
|
|
|
|
|
|
|
|
@property
|
2020-09-16 15:59:50 -07:00
|
|
|
def frame(self):
|
|
|
|
return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def new_arg(self, aval):
|
2020-09-16 15:59:50 -07:00
|
|
|
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
|
2020-07-30 12:59:36 -07:00
|
|
|
self.frame.tracers.append(tracer)
|
2020-09-18 10:38:37 -07:00
|
|
|
self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
|
|
|
|
self.frame.invars.append(var)
|
2020-07-30 12:59:36 -07:00
|
|
|
return tracer
|
|
|
|
|
|
|
|
def new_const(self, val):
|
2020-10-29 15:11:37 -07:00
|
|
|
aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
|
2020-09-16 15:59:50 -07:00
|
|
|
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
|
2020-07-30 12:59:36 -07:00
|
|
|
self.frame.tracers.append(tracer)
|
|
|
|
var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
|
|
|
|
self.frame.constvar_to_val[var] = val
|
|
|
|
return tracer
|
|
|
|
|
|
|
|
pure = lift = sublift = new_const
|
|
|
|
|
|
|
|
def getvar(self, tracer):
|
|
|
|
var = self.frame.tracer_to_var.get(id(tracer))
|
|
|
|
if var is None:
|
2020-10-20 16:10:56 -07:00
|
|
|
if tracer.line_info is not None:
|
|
|
|
detail = f"tracer created on line {source_info_util.summarize(tracer.line_info)}"
|
|
|
|
else:
|
|
|
|
detail = None
|
|
|
|
raise core.escaped_tracer_error(detail)
|
|
|
|
return var
|
|
|
|
|
|
|
|
def makevar(self, tracer):
|
|
|
|
var = self.frame.tracer_to_var.get(id(tracer))
|
|
|
|
assert var is None, "a jaxpr variable must be created only once per tracer"
|
|
|
|
self.frame.tracers.append(tracer)
|
|
|
|
var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
|
2020-07-30 12:59:36 -07:00
|
|
|
return var
|
|
|
|
|
|
|
|
def getconstvar(self, c):
|
|
|
|
var = self.frame.constid_to_var.get(id(c))
|
|
|
|
if var is None:
|
|
|
|
var = self.frame.constid_to_var[id(c)] = self.frame.newvar(get_aval(c))
|
|
|
|
return var
|
|
|
|
|
|
|
|
def instantiate_const(self, val):
|
2020-08-30 01:16:51 -07:00
|
|
|
if (isinstance(val, Tracer) and val._trace.main is self.main
|
2020-07-30 12:59:36 -07:00
|
|
|
and val._trace.sublevel == self.sublevel):
|
|
|
|
return val
|
|
|
|
else:
|
|
|
|
return self.new_const(val)
|
|
|
|
|
|
|
|
def process_primitive(self, primitive, tracers, params):
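    # Staging step: abstractly evaluate the primitive to get output avals,
    # allocate result tracers and variables, and record an equation on the
    # current frame instead of executing anything.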
|
|
|
|
avals = [t.aval for t in tracers]
|
|
|
|
out_avals = primitive.abstract_eval(*avals, **params)
|
|
|
|
out_avals = [out_avals] if not primitive.multiple_results else out_avals
|
2020-09-16 15:59:50 -07:00
|
|
|
source_info = source_info_util.current()
|
|
|
|
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
2020-07-30 12:59:36 -07:00
|
|
|
invars = map(self.getvar, tracers)
|
2020-10-20 16:10:56 -07:00
|
|
|
outvars = map(self.makevar, out_tracers)
|
2020-09-16 15:59:50 -07:00
|
|
|
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
|
2020-07-30 12:59:36 -07:00
|
|
|
self.frame.eqns.append(eqn)
|
|
|
|
return out_tracers if primitive.multiple_results else out_tracers.pop()
|
|
|
|
|
|
|
|
def process_call(self, call_primitive, f, tracers, params):
|
|
|
|
in_avals = [t.aval for t in tracers]
|
2020-08-30 01:16:51 -07:00
|
|
|
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
|
2020-07-30 12:59:36 -07:00
|
|
|
if not jaxpr.eqns:
|
|
|
|
return core.eval_jaxpr(jaxpr, consts, *tracers)
|
2020-09-16 15:59:50 -07:00
|
|
|
source_info = source_info_util.current()
|
|
|
|
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
2020-07-30 12:59:36 -07:00
|
|
|
invars = map(self.getvar, tracers)
|
|
|
|
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
2020-10-20 16:10:56 -07:00
|
|
|
outvars = map(self.makevar, out_tracers)
|
2020-07-30 12:59:36 -07:00
|
|
|
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
|
|
|
update_params = call_param_updaters.get(call_primitive)
|
|
|
|
if update_params:
|
|
|
|
new_params = update_params(new_params, [True] * len(tracers))
|
2020-09-16 15:59:50 -07:00
|
|
|
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
|
|
|
|
new_params, source_info)
|
2020-07-30 12:59:36 -07:00
|
|
|
self.frame.eqns.append(eqn)
|
|
|
|
return out_tracers
|
|
|
|
|
|
|
|
def post_process_call(self, call_primitive, out_tracers, params):
|
|
|
|
assert False # unreachable
|
|
|
|
|
|
|
|
def process_map(self, map_primitive, f, tracers, params):
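    # Map staging: trace the body with per-axis-reduced input avals inside an
    # extended axis env, rebuild full output avals from the out_axes thunk,
    # and record a single map equation on the current frame.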
|
|
|
|
in_avals = [t.aval for t in tracers]
|
|
|
|
axis_name, axis_size = params['axis_name'], params['axis_size']
|
2020-11-25 15:23:00 -08:00
|
|
|
reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
|
|
|
|
if in_axis is not None else a
|
2020-11-05 11:54:05 +00:00
|
|
|
for a, in_axis in zip(in_avals, params['in_axes'])]
|
2020-08-14 18:22:04 +02:00
|
|
|
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
2020-08-30 01:16:51 -07:00
|
|
|
f, self.main, reduced_in_avals)
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
    return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
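As a rough illustration of the hashability requirement (schematic only, not
the actual helper added by this patch), the thunk can delegate `hash`/`==` to
the identity of the wrapped function, so the global compilation cache keys
stay stable across invocations:
```py
class HashableOutAxesThunk:
  def __init__(self, fn, thunk):
    self.fn = fn        # wrapped function; its identity keys the cache
    self.thunk = thunk  # the actual out_axes thunk
  def __call__(self):
    return self.thunk()
  def __hash__(self):
    return hash(self.fn)
  def __eq__(self, other):
    return isinstance(other, HashableOutAxesThunk) and self.fn is other.fn
```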
2020-11-09 17:23:16 +00:00
|
|
|
out_axes = params['out_axes_thunk']()
|
2020-11-25 15:23:00 -08:00
|
|
|
out_avals = [core.unmapped_aval(params['axis_size'], out_axis, a)
|
|
|
|
if out_axis is not None else a
|
2020-11-09 17:23:16 +00:00
|
|
|
for a, out_axis in zip(reduced_out_avals, out_axes)]
|
2020-09-16 15:59:50 -07:00
|
|
|
source_info = source_info_util.current()
|
|
|
|
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
2020-07-30 12:59:36 -07:00
|
|
|
invars = map(self.getvar, tracers)
|
|
|
|
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
2020-10-20 16:10:56 -07:00
|
|
|
outvars = map(self.makevar, out_tracers)
|
2020-11-05 11:54:05 +00:00
|
|
|
new_in_axes = (None,) * len(consts) + params['in_axes']
|
2020-11-09 17:23:16 +00:00
|
|
|
new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
|
2020-07-30 12:59:36 -07:00
|
|
|
call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
2020-11-09 17:23:16 +00:00
|
|
|
del new_params['out_axes_thunk']
|
2020-07-30 12:59:36 -07:00
|
|
|
update_params = call_param_updaters.get(map_primitive)
|
|
|
|
if update_params:
|
|
|
|
new_params = update_params(new_params, [True] * len(tracers))
|
2020-09-16 15:59:50 -07:00
|
|
|
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
|
|
|
|
new_params, source_info)
|
2020-07-30 12:59:36 -07:00
|
|
|
self.frame.eqns.append(eqn)
|
|
|
|
return out_tracers
|
|
|
|
|
|
|
|
def post_process_map(self, map_primitive, out_tracers, params):
|
|
|
|
assert False # unreachable
|
|
|
|
|
2020-10-16 00:21:04 -07:00
|
|
|
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
|
|
|
in_avals = [t.aval for t in tracers]
|
|
|
|
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
|
|
|
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
|
|
|
jvp_jaxpr_thunk = _memoize(
|
|
|
|
lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
|
|
|
|
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
|
|
|
invars = map(self.getvar, tracers)
|
|
|
|
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
2020-10-20 16:10:56 -07:00
|
|
|
outvars = map(self.makevar, out_tracers)
|
2020-10-16 00:21:04 -07:00
|
|
|
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
|
|
|
|
dict(fun_jaxpr=closed_fun_jaxpr,
|
|
|
|
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
|
|
|
|
num_consts=len(consts)),
|
|
|
|
source_info_util.current())
|
|
|
|
self.frame.eqns.append(eqn)
|
|
|
|
return out_tracers
|
|
|
|
|
|
|
|
def post_process_custom_jvp_call(self, out_tracers, params):
|
|
|
|
assert False # unreachable
|
|
|
|
|
|
|
|
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
|
|
|
in_avals = [t.aval for t in tracers]
|
|
|
|
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
|
|
|
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
|
|
|
fwd_jaxpr_thunk = _memoize(
|
|
|
|
lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
|
|
|
|
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
|
|
|
invars = map(self.getvar, tracers)
|
|
|
|
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
2020-10-20 16:10:56 -07:00
|
|
|
outvars = map(self.makevar, out_tracers)
|
2020-10-16 00:21:04 -07:00
|
|
|
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
|
|
|
|
dict(fun_jaxpr=closed_fun_jaxpr,
|
|
|
|
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
|
|
|
|
num_consts=len(consts),
|
|
|
|
bwd=bwd, out_trees=out_trees),
|
|
|
|
source_info_util.current())
|
|
|
|
self.frame.eqns.append(eqn)
|
|
|
|
return out_tracers
|
|
|
|
|
|
|
|
def post_process_custom_vjp_call(self, out_tracers, params):
|
|
|
|
assert False # unreachable
|
|
|
|
|
|
|
|
def _memoize(thunk):
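  # Call the thunk at most once, temporarily restoring the trace state captured
  # here so the delayed trace behaves as if it had run when the thunk was
  # staged out.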
|
|
|
|
cell = []
|
|
|
|
saved_state = core.thread_local_state.trace_state.copy()
|
|
|
|
def memoized():
|
|
|
|
if not cell:
|
|
|
|
prev_state = core.thread_local_state.trace_state
|
|
|
|
core.thread_local_state.trace_state = saved_state
|
|
|
|
try:
|
|
|
|
cell.append(thunk())
|
|
|
|
finally:
|
|
|
|
core.thread_local_state.trace_state = prev_state
|
|
|
|
return cell[0]
|
|
|
|
return memoized
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
|
|
|
assert config.omnistaging_enabled
|
2020-08-30 01:16:51 -07:00
|
|
|
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
|
|
|
|
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
|
|
|
main.jaxpr_stack = () # type: ignore
|
|
|
|
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
|
|
|
del main
|
2020-07-30 12:59:36 -07:00
|
|
|
return jaxpr, out_avals, consts
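
# A minimal usage sketch (illustrative only; `_example_fn` and the input aval
# are assumptions, not part of this module's API): staging a Python function
# out to a jaxpr via the entry point above.
def _example_trace_to_jaxpr_dynamic():
  import jax.numpy as jnp
  def _example_fn(x):
    return jnp.sin(x) * 2.
  in_avals = [core.ShapedArray((3,), jnp.float32)]
  return trace_to_jaxpr_dynamic(lu.wrap_init(_example_fn), in_avals)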
|
|
|
|
|
2020-08-30 01:16:51 -07:00
|
|
|
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
|
|
|
|
in_avals: Sequence[AbstractValue]):
|
2020-07-30 12:59:36 -07:00
|
|
|
frame = JaxprStackFrame()
|
2020-08-30 01:16:51 -07:00
|
|
|
with extend_jaxpr_stack(main, frame):
|
|
|
|
trace = DynamicJaxprTrace(main, core.cur_sublevel())
|
2020-07-30 12:59:36 -07:00
|
|
|
in_tracers = map(trace.new_arg, in_avals)
|
|
|
|
ans = fun.call_wrapped(*in_tracers)
|
|
|
|
out_tracers = map(trace.full_raise, ans)
|
|
|
|
jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
|
|
|
|
return jaxpr, out_avals, consts
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
2020-08-30 01:16:51 -07:00
|
|
|
def extend_jaxpr_stack(main, frame):
|
|
|
|
main.jaxpr_stack = main.jaxpr_stack + (frame,)
|
2020-07-30 12:59:36 -07:00
|
|
|
try:
|
|
|
|
yield
|
|
|
|
finally:
|
2020-08-30 01:16:51 -07:00
|
|
|
assert frame is main.jaxpr_stack[-1]
|
|
|
|
main.jaxpr_stack = main.jaxpr_stack[:-1]
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
|
|
|
assert config.omnistaging_enabled
|
2020-08-30 01:16:51 -07:00
|
|
|
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
|
|
|
|
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
|
|
|
main.jaxpr_stack = () # type: ignore
|
|
|
|
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
|
|
|
del main
|
2020-07-30 12:59:36 -07:00
|
|
|
return jaxpr, out_avals, consts
|
|
|
|
|
2020-08-10 07:20:09 -07:00
|
|
|
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
|
|
|
|
# This function provides a partial evaluation behavior used by Flax. We can't
|
|
|
|
  # use trace_to_jaxpr directly because of an interaction with the current
|
|
|
|
# custom_derivatives.py, which we work around by adding the EvalTrace.
|
|
|
|
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
|
|
|
|
assert config.omnistaging_enabled
|
2020-08-30 01:16:51 -07:00
|
|
|
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
|
2020-08-10 07:20:09 -07:00
|
|
|
return trace_to_jaxpr(fun, in_pvals)
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def fun_sourceinfo(fun):
|
|
|
|
if isinstance(fun, functools.partial):
|
|
|
|
fun = fun.func
|
|
|
|
try:
|
|
|
|
filename = fun.__code__.co_filename
|
|
|
|
lineno = fun.__code__.co_firstlineno
|
|
|
|
return f"{fun.__name__} at {filename}:{lineno}"
|
|
|
|
except AttributeError:
|
|
|
|
return "<unknown>"
|
|
|
|
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
@config.register_omnistaging_disabler
|
|
|
|
@no_type_check
|
|
|
|
def omnistaging_disabler() -> None:
|
|
|
|
global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
2020-09-15 08:06:46 -07:00
|
|
|
instantiate: Union[bool, Sequence[bool]] = False,
|
|
|
|
stage_out=False, bottom=False,
|
|
|
|
trace_type: Optional[Type[Trace]] = None,
|
|
|
|
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
|
|
|
"""Traces a function into a Jaxpr, given PartialVals for inputs.
|
|
|
|
|
|
|
|
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
|
|
|
|
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
|
|
|
|
for the outputs. The intermediate values that depend only on known inputs and
|
|
|
|
are needed to compute the output of `jaxpr` are in `consts` and are passed in
|
|
|
|
as the constvars of the `jaxpr`. The handling of the known outputs depends on
|
|
|
|
`instantiate`.
|
|
|
|
|
|
|
|
For example, given `fun` defined as follows::
|
|
|
|
|
|
|
|
def fun(ki, ui): # ki will be a known input in this example
|
|
|
|
ka = ki + 2
|
|
|
|
kb = ka + 3
|
|
|
|
return (kb, ui + ka)
|
|
|
|
|
|
|
|
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
|
|
|
|
computation that depends on unknown inputs is `ui + ka` and will be the only
|
|
|
|
computation in the body of the `jaxpr`. This computation depends on the known
|
|
|
|
intermediate value `ka`, which will be computed statically. Currently, such
|
|
|
|
constants are either embedded in the Jaxpr if they are scalars, or passed as a
|
|
|
|
constvar to `jaxpr`, and then the value of the actual constant will be in
|
|
|
|
`consts`:
|
|
|
|
|
|
|
|
When `instantiate=False` we get::
|
|
|
|
|
|
|
|
jaxpr =
|
|
|
|
{ lambda ka ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
in (*, c) } # known outputs are `*`
|
|
|
|
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3] # the constant for `ka`
|
|
|
|
|
|
|
|
When `instantiate=True` we get::
|
|
|
|
|
|
|
|
jaxpr =
|
|
|
|
{ lambda ka kb ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
        in (kb, c) }  # known outputs are explicit
|
|
|
|
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3, 6] # values for `ka` and `kb` constvars
|
|
|
|
"""
|
|
|
|
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
|
|
|
|
with core.new_main(trace_type, bottom=bottom) as main:
|
2020-08-30 12:38:14 +03:00
|
|
|
fun = trace_to_subjaxpr(fun, main, instantiate)
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
|
|
|
assert not env
|
2020-08-30 12:38:14 +03:00
|
|
|
del main
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
return jaxpr, out_pvals, consts
|
|
|
|
|
2020-09-18 10:07:13 -07:00
|
|
|
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
|
2020-07-30 12:59:36 -07:00
|
|
|
instantiate: Union[bool, Sequence[bool]],
|
2020-09-15 08:06:46 -07:00
|
|
|
trace_type: Optional[Type[core.Trace]]
|
2020-09-18 10:07:13 -07:00
|
|
|
) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
|
2020-07-30 12:59:36 -07:00
|
|
|
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
|
|
|
|
|
|
|
cell = []
|
|
|
|
def fun(*vals):
|
|
|
|
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
|
|
|
|
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
|
2020-09-15 08:06:46 -07:00
|
|
|
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
|
|
|
|
trace_type=trace_type)
|
2020-07-30 12:59:36 -07:00
|
|
|
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
|
|
|
|
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
|
|
|
|
return out_consts_2 + consts_2
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
# The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
|
|
|
|
# the avals, and it will never actually reach any primitives, because the `fun` above will
|
|
|
|
# execute the jaxpr with the right avals (it reconstructs `pvals` inside).
|
|
|
|
pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
|
|
|
|
for aval, uk in zip(jaxpr.in_avals, unknowns)]
|
|
|
|
jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
|
2020-07-30 12:59:36 -07:00
|
|
|
(out_pvs_2, jaxpr_2, num_res), = cell
|
|
|
|
assert len(jaxpr_2.constvars) == num_res
|
|
|
|
|
|
|
|
# jaxpr :: a -> b
|
|
|
|
# jaxpr_1 :: a1 -> [b1, res]
|
|
|
|
# jaxpr_2 :: res | a2 -> b2
|
|
|
|
# jaxpr_2 :: [a2, res] -> b2
|
|
|
|
jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
|
|
|
|
jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
|
|
|
|
for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
|
|
|
|
if not unknown:
|
|
|
|
var.aval = abstract_unit
|
|
|
|
|
|
|
|
uk_out = [pv is not None for pv in out_pvs_2]
|
|
|
|
|
2020-09-18 10:07:13 -07:00
|
|
|
return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
|
2020-07-30 12:59:36 -07:00
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
|
|
|
# See comment at top of `JaxprTrace`. This method should be reachable
|
|
|
|
# only when we stage out, and in that case we drop the custom differentiation
|
|
|
|
# rules, because we do not need them.
|
|
|
|
if not config.omnistaging_enabled:
|
|
|
|
assert self.main.trace_type is StagingJaxprTrace
|
|
|
|
return fun.call_wrapped(*tracers)
|
|
|
|
JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
|
|
|
|
|
|
|
|
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
|
|
|
# See comment in the above process_custom_jvp_call method.
|
|
|
|
if not config.omnistaging_enabled:
|
|
|
|
assert self.main.trace_type is StagingJaxprTrace
|
|
|
|
return fun.call_wrapped(*tracers)
|
|
|
|
JaxprTrace.process_custom_vjp_call = process_custom_vjp_call
|
|
|
|
|
|
|
|
staged_out_calls = set()
|
|
|
|
|
|
|
|
class StagingJaxprTrace(JaxprTrace): pass
|