2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import itertools as it
|
2020-03-09 20:42:08 +01:00
|
|
|
from collections import namedtuple
|
2020-07-30 12:59:36 -07:00
|
|
|
import contextlib
|
|
|
|
import functools
|
|
|
|
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
|
2020-09-15 08:06:46 -07:00
|
|
|
List, Union, cast, Type, no_type_check)
|
2019-11-19 12:26:30 -08:00
|
|
|
from weakref import ref
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
import numpy as np
|
2019-05-28 22:38:06 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from .. import core
|
2020-07-30 12:59:36 -07:00
|
|
|
from .. import dtypes
|
2018-11-17 18:03:33 -08:00
|
|
|
from .. import linear_util as lu
|
2020-06-04 15:27:48 -07:00
|
|
|
from ..abstract_arrays import ConcreteArray, raise_to_shaped
|
2020-05-27 13:57:47 +00:00
|
|
|
from ..ad_util import Zero
|
2020-01-26 23:27:56 -08:00
|
|
|
from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
|
2020-06-23 09:39:45 -07:00
|
|
|
cache)
|
2020-07-30 12:59:36 -07:00
|
|
|
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
|
|
|
|
unit, unitvar, abstract_unit, TypedJaxpr, new_jaxpr_eqn,
|
|
|
|
dropvar)
|
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.
Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
...: z = jax.numpy.cos(x)
...: z = z * jax.numpy.tanh(y)
...: return z + 2
...:
In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
in (f, j) }
In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15
ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
%constant.3 = pred[] constant(false)
%parameter.1 = f32[] parameter(0)
%cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%parameter.2 = f32[] parameter(1)
%tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 19:35:36 -04:00
|
|
|
from .. import source_info_util
|
2020-07-30 12:59:36 -07:00
|
|
|
from ..config import config
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Rebind the builtins `map` and `zip` at module scope to the `safe_*`
# versions imported from ..util; every later `map(...)` / `zip(...)` in this
# module resolves to them (note: callers below take `len()` of `map` results,
# so they rely on it producing a sequence rather than a lazy iterator).
map, zip = safe_map, safe_zip
|
2019-02-15 06:35:54 -08:00
|
|
|
def identity(x):
  """Return the argument `x` itself, unchanged."""
  result = x
  return result
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
class PartialVal(tuple):
  """Partial value: either a known value or an unknown (abstract) value.

  Represented as a pair `(aval_opt, const)` of one of two kinds:

  * `(None, <Constant>)` indicates a known value, either a Python regular
    value, or a Tracer.
  * `(<AbstractValue>, <unit>)` indicates an unknown value characterized by an
    abstract value. The second component's abstract value must be
    `core.abstract_unit` (asserted in `__new__` unless `core.skip_checks`);
    `unknown` fills it with `core.unit`.

  Prefer the `known` / `unknown` constructors over building the pair by hand.
  """

  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    pv, const = xs
    if not core.skip_checks:
      # type checks
      assert isinstance(pv, (AbstractValue, type(None))), xs
      assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
      # invariant checks
      if isinstance(pv, AbstractValue):
        assert get_aval(const) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    """Build a partial value for the known value `const`."""
    # Construct via `cls` rather than a hard-coded `PartialVal` so that a
    # subclass calling this alternate constructor gets its own type back.
    return cls((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    """Build a partial value for an unknown value described by `aval`."""
    # As in `known`, use `cls` so subclasses are preserved.
    return cls((aval, core.unit))

  def is_known(self) -> bool:
    """Return True iff this partial value has no abstract part (is known)."""
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Get the known value, if known, else None."""
    return self[1] if self[0] is None else None

  def get_aval(self) -> AbstractValue:
    """Get AbstractValue directly (if unknown) or from the constant (known)."""
    # `get_aval` below is the module-level function, not this method: class
    # scope is not visible from inside method bodies.
    known = self.get_known()
    if known is not None:
      return get_aval(known)
    else:
      return self[0]

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Either the stored known value, or the given 'val'."""
    known = self.get_known()
    return known if known is not None else val
|
2019-05-09 07:23:39 -07:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTrace(Trace):
|
2020-06-01 21:45:36 -04:00
|
|
|
def pure(self, val) -> 'JaxprTracer':
  """Embed `val` in this trace as a known constant (delegates to new_const)."""
  make_const = self.new_const
  return make_const(val)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def lift(self, val) -> 'JaxprTracer':
  """Lift `val` into this trace as a known constant via `new_const`."""
  lifted = self.new_const(val)
  return lifted
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def sublift(self, val) -> 'JaxprTracer':
  """Re-wrap tracer `val` in this trace, keeping its partial value.

  The new tracer's recipe is a `FreeVar` referring to `val`.
  """
  pval = val.pval
  recipe = FreeVar(val)
  return JaxprTracer(self, pval, recipe)
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def new_const(self, val) -> 'JaxprTracer':
  """Make a tracer of this trace whose partial value is the known `val`.

  Args:
    val: the value to wrap; its recipe is `unit`.

  Returns:
    A `JaxprTracer` with `PartialVal.known(val)`.

  Raises:
    Exception: if `val` is a Tracer whose trace is at this trace's level.
  """
  # Guard against wrapping a tracer that already lives at this level as if it
  # were an external constant.
  if isinstance(val, Tracer) and val._trace.level == self.level:
    raise Exception(
        "JaxprTrace.new_const got a Tracer at this trace's own level "
        f"({self.level}); only values from outside this trace can be "
        "wrapped as known constants here.")
  return JaxprTracer(self, PartialVal.known(val), unit)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def new_instantiated_literal(self, val) -> 'JaxprTracer':
  """Stage `val` out as a `Literal`.

  The resulting tracer is unknown (its partial value holds only `val`'s
  abstract value) and its recipe is `Literal(val)`.
  """
  aval = get_aval(val)
  pval = PartialVal.unknown(aval)
  return JaxprTracer(self, pval, Literal(val))
|
2019-05-13 08:48:13 -07:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def new_instantiated_const(self, val) -> 'JaxprTracer':
  """Stage `val` out as a `ConstVar`.

  The resulting tracer is unknown (its partial value holds only `val`'s
  abstract value) and its recipe is `ConstVar(val)`.
  """
  aval = get_aval(val)
  pval = PartialVal.unknown(aval)
  return JaxprTracer(self, pval, ConstVar(val))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def new_arg(self, pval: PartialVal) -> 'JaxprTracer':
  """Create the tracer for an input argument described by `pval`.

  A known argument becomes a constant via `new_const`; an unknown one becomes
  a tracer carrying `pval` with a `LambdaBinding` recipe.
  """
  known_val = pval.get_known()
  if known_val is not None:
    return self.new_const(known_val)
  return JaxprTracer(self, pval, LambdaBinding())
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-02 10:26:43 -04:00
|
|
|
def instantiate_const(self, tracer) -> Tracer:
  """Turn a known-valued tracer into a staged-out (unknown) one.

  Unknown tracers are returned as-is. A known value whose type is in
  `core.literalable_types` and whose shape is `()` becomes a `Literal`;
  any other known value becomes a `ConstVar`.
  """
  value = tracer.pval.get_known()
  if value is None:
    return tracer
  # `and` short-circuits, so np.shape is only consulted for literalable types.
  scalar_literal = (type(value) in core.literalable_types
                    and np.shape(value) == ())
  if scalar_literal:
    return self.new_instantiated_literal(value)
  return self.new_instantiated_const(value)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
  """Like `instantiate_const`, but always stages a known value as a ConstVar.

  Unknown tracers are returned unchanged. For a known value, the abstract
  value is passed through `raise_to_shaped` (with `np.isscalar(value)` as the
  second argument) and the result is an unknown tracer with a `ConstVar`
  recipe; no `Literal` is ever produced here.
  """
  value = tracer.pval.get_known()
  if value is None:
    return tracer
  shaped = raise_to_shaped(get_aval(value), np.isscalar(value))
  return JaxprTracer(self, PartialVal.unknown(shaped), ConstVar(value))
|
make nested jit stage out full inner jit bodies
Before this change, inner jitted functions wouldn't necessarily be fully
staged out into an outer-jit trace; instead, as much as possible would
be hoisted out of the inner jit. That led to extra constants getting
materialized in #1640.
For example:
```python
@jit
def f(x, y):
z = 2 * x
return y + z
@jit
def g(x):
return f(2, x)
g(3)
```
would lead to these XLA computations being compiled and executed:
```
HloModule jit_f.7
ENTRY jit_f.7 {
parameter.2 = () parameter(1)
tuple.3 = () tuple()
parameter.1 = s32[] parameter(0)
constant.4 = s32[] constant(2)
multiply.5 = s32[] multiply(parameter.1, constant.4)
ROOT tuple.6 = ((), s32[]) tuple(tuple.3, multiply.5)
}
HloModule jit_g.14
jaxpr_subcomputation.4 {
parameter.6 = () parameter(1)
tuple.8 = () tuple()
parameter.7 = s32[] parameter(2)
parameter.5 = s32[] parameter(0)
add.9 = s32[] add(parameter.7, parameter.5)
ROOT tuple.10 = (s32[]) tuple(add.9)
}
ENTRY jit_g.14 {
constant.1 = s32[] constant(4)
tuple.3 = () tuple()
parameter.2 = s32[] parameter(0)
call.11 = (s32[]) call(constant.1, tuple.3, parameter.2), to_apply=jaxpr_subcomputation.4
get-tuple-element.12 = s32[] get-tuple-element(call.11), index=0
ROOT tuple.13 = (s32[]) tuple(get-tuple-element.12)
}
```
Notice that the `multiply` is separated out from the `add`, and in
particular the XLA computation underlying `g` only has the `add` in it.
This behavior was desirable when using partial evaluation for
reverse-mode autodiff, since in that case we want to partially evaluate
all the primal values underneath a call while staging out a jaxpr for
the tangent values. But it was undesirable for the other use of partial
evaluation, namely forming jaxprs under `jit` (and `pmap`).
The solution was just to tag jaxpr traces differently in the two cases.
2019-12-11 18:39:16 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def process_primitive(self, primitive, tracers, params):
  """Partially evaluate `primitive` on `tracers`.

  A rule registered in `custom_partial_eval_rules` takes precedence; otherwise
  `default_process_primitive` handles it.
  """
  try:
    rule = custom_partial_eval_rules[primitive]
  except KeyError:
    return self.default_process_primitive(primitive, tracers, params)
  # The rule is called outside the `try` so its own KeyErrors still propagate.
  return rule(self, *tracers, **params)
|
|
|
|
|
|
|
|
def default_process_primitive(self, primitive, tracers, params):
  """By default, if all the input tracers are known, then execute the primitive
  and all the outputs are known. Otherwise, all the outputs are unknown."""
  consts = [t.pval.get_known() for t in tracers]
  if all(c is not None for c in consts):
    # Every input is known: evaluate eagerly, nothing is staged out.
    return primitive.bind(*consts, **params)
  # At least one input is unknown: stage out an equation. Known inputs are
  # instantiated (as Literals or ConstVars) so they can feed that equation.
  tracers = map(self.instantiate_const, tracers)
  avals = [t.aval for t in tracers]
  out_aval = primitive.abstract_eval(*avals, **params)
  source = source_info_util.current()
  if primitive.multiple_results:
    out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                   for aval in out_aval]
    eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
    # All outputs share the single equation as their recipe.
    for t in out_tracers: t.recipe = eqn
    return out_tracers
  else:
    out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
    out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
                                       params, source)
    return out_tracer
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
# We use process_call to handle both call and map primitives.
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
  """Partially evaluate a call (or map) primitive applied to `tracers`.

  `f` is traced on the inputs' partial values via `self.partial_eval`.
  Outputs that come out known are returned as constant tracers. The unknown
  remainder of the traced jaxpr is staged out as one new `primitive`
  equation whose outputs are the unknown tracers.
  """
  if not config.omnistaging_enabled:
    # Without omnistaging, staged-out calls under a StagingJaxprTrace have all
    # inputs abstracted (known values become ConstVars, never stay known).
    if (self.main.trace_type is StagingJaxprTrace  # type: ignore
        and primitive in staged_out_calls):  # type: ignore
      tracers = map(self.instantiate_const_abstracted, tracers)

  # A registered per-primitive rule bypasses the generic logic below.
  if primitive in call_partial_eval_rules:
    return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

  in_pvals = [t.pval for t in tracers]
  if primitive.map_primitive:
    # Unknown, mapped inputs get their aval converted by core.mapped_aval
    # (presumably dropping the mapped axis of size `axis_size`); known or
    # unmapped inputs pass through unchanged.
    mapped_aval = partial(core.mapped_aval, params['axis_size'])
    in_pvals = [pval if pval.is_known() or not is_mapped
                else PartialVal.unknown(mapped_aval(pval[0]))
                for pval, is_mapped in zip(in_pvals, params['mapped_invars'])]
  jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      f, in_pvals, partial(primitive.bind, **params))
  if primitive.map_primitive:
    # Inverse conversion for unknown outputs via core.unmapped_aval
    # (presumably restoring the mapped axis).
    unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
    out_pvals = [pval if pval.is_known()
                 else PartialVal.unknown(unmapped_aval(pval[0]))
                 for pval in out_pvals]

  # Avoid staging out trivial calls, but maps may involve broadcasting.
  if not jaxpr.eqns and not primitive.map_primitive:
    # With no equations, every output is a Literal, a known value, or one of
    # the jaxpr's invars/constvars (or unitvar), so it is resolved via `env`.
    # The jaxpr's invars are ordered env tracers first, then the call args.
    env = {core.unitvar: core.unit}
    map(env.setdefault, jaxpr.invars, (*env_tracers, *tracers))
    map(env.setdefault, jaxpr.constvars, consts)
    return [v.val if type(v) is Literal
            else pval.get_known() if pval.is_known()
            else env[v]
            for v, pval in zip(jaxpr.outvars, out_pvals)]

  # Skip known invars and outvars, and lift constants as regular invars
  in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
  out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
  jaxpr = _drop_invars(jaxpr, in_knowns)
  jaxpr = _dce_untyped_jaxpr(jaxpr, out_unknowns, drop_outputs=True)

  # Known tracers get propagated as if they were constants
  known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
                       if pval.is_known()]

  # Unknown tracers need to have the jaxpr set up as their recipe
  unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
                         if not pval.is_known()]
  unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
  const_tracers = map(self.new_instantiated_const, consts)
  # Input order of the staged-out equation: consts, env tracers, unknown args
  # (matching `convert_constvars_jaxpr`, which turns constvars into leading
  # invars — NOTE(review): confirm against its definition).
  in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)

  # Set up new params
  new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
  if primitive.map_primitive:
    # Lifted constants are marked mapped, env tracers unmapped, and the
    # remaining unknown args keep their original mapped flags.
    mapped_invars = params['mapped_invars']
    new_mapped_invars = ((True,) * len(const_tracers) +
                         (False,) * len(env_tracers) +
                         tuple(v for v, t in zip(mapped_invars, tracers)
                               if not t.pval.is_known()))
    new_params = dict(new_params, mapped_invars=new_mapped_invars)
  # Optional per-primitive hook that adjusts params given the mask of which
  # original args are unknown.
  update_params = call_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])

  eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
                       source_info_util.current())
  for t in unknown_tracers_out: t.recipe = eqn
  # Presumably re-interleaves known and unknown outputs into the original
  # output order using the `out_unknowns` mask.
  return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
# Map primitives share the call implementation: process_call branches on
# `primitive.map_primitive` for the map-specific aval and param handling.
process_map = process_call
|
|
|
|
|
|
|
|
# We use post_process_call to handle both call and map primitives.
|
|
|
|
def post_process_call(self, primitive, out_tracers, params):
|
2020-04-21 18:12:02 -07:00
|
|
|
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
|
|
|
|
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
|
|
|
|
out = out_pv_consts + consts
|
|
|
|
del consts, out_pv_consts
|
2020-08-30 12:38:14 +03:00
|
|
|
main = self.main
|
2020-04-21 18:12:02 -07:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
if primitive.map_primitive:
|
|
|
|
sz = params['axis_size']
|
|
|
|
out_pvs = [None if pv is None else core.unmapped_aval(sz, pv)
|
|
|
|
for pv in out_pvs]
|
2019-07-27 15:46:14 -07:00
|
|
|
|
2019-09-20 07:01:01 -07:00
|
|
|
def todo(x):
|
|
|
|
n = len(jaxpr.outvars)
|
|
|
|
out_pv_consts, consts = x[:n], x[n:]
|
2020-08-30 12:38:14 +03:00
|
|
|
trace = JaxprTrace(main, core.cur_sublevel())
|
2019-09-20 07:01:01 -07:00
|
|
|
const_tracers = map(trace.new_instantiated_const, consts)
|
|
|
|
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
|
|
|
|
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
|
2020-06-23 09:39:45 -07:00
|
|
|
in_tracers = (*const_tracers, *map(trace.full_raise, env))
|
|
|
|
|
|
|
|
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
|
|
|
if primitive.map_primitive:
|
|
|
|
new_mapped_invars = (True,) * len(const_tracers) + (False,) * len(env)
|
|
|
|
new_params = dict(new_params, mapped_invars=new_mapped_invars)
|
|
|
|
update_params = call_param_updaters.get(primitive)
|
|
|
|
if update_params:
|
|
|
|
new_params = update_params(new_params, [])
|
|
|
|
|
|
|
|
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
|
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.
Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
...: z = jax.numpy.cos(x)
...: z = z * jax.numpy.tanh(y)
...: return z + 2
...:
In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
in (f, j) }
In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15
ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
%constant.3 = pred[] constant(false)
%parameter.1 = f32[] parameter(0)
%cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%parameter.2 = f32[] parameter(1)
%tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 19:35:36 -04:00
|
|
|
source_info_util.current())
|
2019-09-20 07:01:01 -07:00
|
|
|
for t in out_tracers:
|
|
|
|
t.recipe = eqn
|
|
|
|
return out_tracers
|
|
|
|
return out, todo
|
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
post_process_map = post_process_call
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
|
|
|
|
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]]):
|
|
|
|
"""Partially evaluate f on a sequence of PartialVals."""
|
|
|
|
in_avals, in_consts = unzip2(pvals)
|
2020-08-30 01:16:51 -07:00
|
|
|
f = trace_to_subjaxpr(f, self.main, False)
|
2020-07-30 12:59:36 -07:00
|
|
|
f, aux = partial_eval_wrapper(f, tuple(in_avals))
|
|
|
|
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
|
|
|
|
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
|
|
|
out_pvs = map(PartialVal, zip(out_avals, out_consts))
|
|
|
|
env_tracers = map(self.full_raise, env)
|
|
|
|
return jaxpr, out_pvs, consts, env_tracers
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
|
2020-06-23 09:39:45 -07:00
|
|
|
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
|
|
|
|
py_args = map(PartialVal, zip(pvs, consts))
|
|
|
|
jaxpr, (out_pvals, consts, env) = yield (py_args,), {}
|
2019-07-26 16:48:17 -04:00
|
|
|
out_pvs, out_consts = unzip2(out_pvals)
|
2020-06-23 09:39:45 -07:00
|
|
|
out = tuple(out_consts) + tuple(consts)
|
2019-07-26 16:48:17 -04:00
|
|
|
yield out, (out_pvs, jaxpr, env)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-05-28 17:39:13 +02:00
|
|
|
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
|
|
|
|
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
|
2020-06-23 09:39:45 -07:00
|
|
|
call_param_updaters: Dict[core.Primitive, Callable] = {}
|
2020-05-28 17:39:13 +02:00
|
|
|
|
|
|
|
|
2019-02-13 14:28:30 -08:00
|
|
|
def abstract_eval_fun(fun, *avals, **params):
|
2020-07-30 12:59:36 -07:00
|
|
|
if config.omnistaging_enabled:
|
2020-08-25 05:38:41 -07:00
|
|
|
_, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
|
2020-07-30 12:59:36 -07:00
|
|
|
else:
|
2020-08-25 05:38:41 -07:00
|
|
|
pvals_in = [PartialVal.unknown(a) for a in avals]
|
2020-07-30 12:59:36 -07:00
|
|
|
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
|
2020-09-15 08:06:46 -07:00
|
|
|
instantiate=True, stage_out=True) # type: ignore
|
2020-08-25 05:38:41 -07:00
|
|
|
avals_out, _ = unzip2(pvals_out)
|
2019-07-26 16:48:17 -04:00
|
|
|
for aval_out in avals_out:
|
|
|
|
assert isinstance(aval_out, AbstractValue) # instantiate=True
|
|
|
|
return avals_out
|
2019-02-13 14:28:30 -08:00
|
|
|
|
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar',
|
|
|
|
'ConstVar', Literal, core.Unit]
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTracer(Tracer):
|
2019-01-16 16:51:54 +00:00
|
|
|
__slots__ = ['pval', 'recipe']
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def __init__(self, trace: JaxprTrace, pval: PartialVal,
|
|
|
|
recipe: Optional[JaxprTracerRecipe]):
|
2018-11-17 18:03:33 -08:00
|
|
|
assert isinstance(pval, PartialVal)
|
|
|
|
pv, const = pval
|
2020-02-15 06:35:49 +01:00
|
|
|
if isinstance(const, Tracer) and const._trace.level >= trace.level:
|
2020-03-28 14:55:58 -07:00
|
|
|
raise core.escaped_tracer_error(
|
|
|
|
"Tracer from a higher level: {} in trace {}".format(const, trace))
|
2020-01-29 16:23:27 -05:00
|
|
|
self._trace = trace
|
2018-11-17 18:03:33 -08:00
|
|
|
self.pval = pval
|
|
|
|
self.recipe = recipe
|
|
|
|
|
|
|
|
def __repr__(self):
|
2020-01-29 16:23:27 -05:00
|
|
|
return 'Traced<{}:{}>'.format(self.aval, self._trace)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@property
|
2020-06-01 21:45:36 -04:00
|
|
|
def aval(self) -> AbstractValue:
|
2020-03-18 07:11:44 +01:00
|
|
|
return self.pval.get_aval()
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
@property
|
2020-06-01 21:45:36 -04:00
|
|
|
def parents(self) -> Sequence['JaxprTracer']:
|
2019-11-19 12:26:30 -08:00
|
|
|
if isinstance(self.recipe, JaxprEqnRecipe):
|
2020-02-03 20:58:56 +01:00
|
|
|
return self.recipe.invars
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
|
|
|
return []
|
|
|
|
|
|
|
|
def full_lower(self):
|
2020-03-18 07:11:44 +01:00
|
|
|
known = self.pval.get_known()
|
|
|
|
if known is not None:
|
|
|
|
return core.full_lower(known)
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
|
|
|
return self
|
|
|
|
|
2020-06-12 15:03:26 +02:00
|
|
|
def is_known(self):
|
|
|
|
return self.pval.is_known()
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
# TODO(necula): this should return a TypedJaxpr
|
2020-03-21 13:54:30 +01:00
|
|
|
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
2020-09-15 08:06:46 -07:00
|
|
|
instantiate: Union[bool, Sequence[bool]] = False,
|
|
|
|
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
2020-03-18 07:11:44 +01:00
|
|
|
"""Traces a function into a Jaxpr, given PartialVals for inputs.
|
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
|
|
|
|
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
|
|
|
|
for the outputs. The intermediate values that depend only on known inputs and
|
|
|
|
are needed to compute the output of `jaxpr` are in `consts` and are passed in
|
|
|
|
as the constvars of the `jaxpr`. The handling of the known outputs depends on
|
|
|
|
`instantiate`.
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
For example, given `fun` defined as follows::
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def fun(ki, ui): # ki will be a known input in this example
|
|
|
|
ka = ki + 2
|
|
|
|
kb = ka + 3
|
|
|
|
return (kb, ui + ka)
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
|
|
|
|
computation that depends on unknown inputs is `ui + ka` and will be the only
|
2020-06-23 09:39:45 -07:00
|
|
|
computation in the body of the `jaxpr`. This computation depends on the known
|
|
|
|
intermediate value `ka`, which will be computed statically. Currently, such
|
|
|
|
constants are either embedded in the Jaxpr if they are scalars, or passed as a
|
|
|
|
constvar to `jaxpr`, and then the value of the actual constant will be in
|
|
|
|
`consts`:
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
When `instantiate=False` we get::
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr =
|
2020-03-18 07:11:44 +01:00
|
|
|
{ lambda ka ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
in (*, c) } # known outputs are `*`
|
2020-07-30 12:59:36 -07:00
|
|
|
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3] # the constant for `ka`
|
2020-03-18 07:11:44 +01:00
|
|
|
|
|
|
|
When `instantiate=True` we get::
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr =
|
2020-03-18 07:11:44 +01:00
|
|
|
{ lambda ka kb ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
in (kb, c) } # known output are explicit
|
2020-07-30 12:59:36 -07:00
|
|
|
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3, 6] # values for `ka` and `kb` constvars
|
2020-03-18 07:11:44 +01:00
|
|
|
"""
|
2020-09-15 08:06:46 -07:00
|
|
|
with core.new_main(JaxprTrace) as main:
|
2020-08-30 12:38:14 +03:00
|
|
|
fun = trace_to_subjaxpr(fun, main, instantiate)
|
2019-07-26 16:48:17 -04:00
|
|
|
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
2018-11-17 18:03:33 -08:00
|
|
|
assert not env
|
2020-08-30 12:38:14 +03:00
|
|
|
del main
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-26 16:48:17 -04:00
|
|
|
return jaxpr, out_pvals, consts
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation
|
2020-08-30 12:38:14 +03:00
|
|
|
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
|
2020-03-18 07:11:44 +01:00
|
|
|
pvals: Sequence[PartialVal]):
|
2018-11-17 18:03:33 -08:00
|
|
|
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
|
2020-08-30 12:38:14 +03:00
|
|
|
trace = JaxprTrace(main, core.cur_sublevel())
|
2018-11-17 18:03:33 -08:00
|
|
|
in_tracers = map(trace.new_arg, pvals)
|
2019-06-23 15:31:13 -07:00
|
|
|
ans = yield in_tracers, {}
|
2020-03-18 07:11:44 +01:00
|
|
|
instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
|
2019-07-27 15:46:14 -07:00
|
|
|
out_tracers = map(trace.full_raise, map(core.full_lower, ans))
|
|
|
|
out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
|
2019-07-26 16:48:17 -04:00
|
|
|
jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
|
|
|
|
out_pvals = [t.pval for t in out_tracers]
|
|
|
|
del trace, in_tracers, out_tracers
|
|
|
|
yield jaxpr, (out_pvals, consts, env)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
|
2019-07-27 15:46:14 -07:00
|
|
|
if instantiate:
|
|
|
|
return trace.instantiate_const(trace.full_raise(tracer))
|
2019-05-10 08:58:05 -07:00
|
|
|
else:
|
2019-07-27 15:46:14 -07:00
|
|
|
return tracer
|
2019-05-10 08:58:05 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
FreeVar = namedtuple('FreeVar', ['val'])
|
|
|
|
ConstVar = namedtuple('ConstVar', ['val'])
|
|
|
|
LambdaBinding = namedtuple('LambdaBinding', [])
|
2020-06-01 21:45:36 -04:00
|
|
|
class JaxprEqnRecipe(NamedTuple):
|
|
|
|
eqn_id: object
|
|
|
|
invars: Sequence[JaxprTracer]
|
|
|
|
outvars: 'Sequence[ref[JaxprTracer]]'
|
|
|
|
primitive: core.Primitive
|
|
|
|
params: Dict[str, Any]
|
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.
Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
...: z = jax.numpy.cos(x)
...: z = z * jax.numpy.tanh(y)
...: return z + 2
...:
In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
in (f, j) }
In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15
ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
%constant.3 = pred[] constant(false)
%parameter.1 = f32[] parameter(0)
%cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%parameter.2 = f32[] parameter(1)
%tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 19:35:36 -04:00
|
|
|
source_info: Optional[source_info_util.Traceback]
|
2020-06-01 21:45:36 -04:00
|
|
|
|
|
|
|
def new_eqn_recipe(invars: Sequence[JaxprTracer],
|
|
|
|
outvars: Sequence[JaxprTracer],
|
|
|
|
primitive: core.Primitive,
|
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.
Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
...: z = jax.numpy.cos(x)
...: z = z * jax.numpy.tanh(y)
...: return z + 2
...:
In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
in (f, j) }
In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15
ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
%constant.3 = pred[] constant(false)
%parameter.1 = f32[] parameter(0)
%cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%parameter.2 = f32[] parameter(1)
%tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 19:35:36 -04:00
|
|
|
params: Dict[str, Any],
|
|
|
|
source_info: Optional[source_info_util.Traceback]
|
|
|
|
) -> JaxprEqnRecipe:
|
2020-01-07 13:11:32 -08:00
|
|
|
"""Constructs a new JaxEqnRecipe.
|
|
|
|
|
|
|
|
Params:
|
|
|
|
invars: the tracers for the primitive inputs.
|
|
|
|
outvars: the tracers for the primitive outputs.
|
|
|
|
primitive: the primitive.
|
|
|
|
params: the primitive params
|
|
|
|
"""
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
|
|
|
|
if primitive.call_primitive or primitive.map_primitive:
|
2020-02-05 15:38:25 +01:00
|
|
|
assert "call_jaxpr" in params
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
if primitive.map_primitive:
|
2020-06-23 09:39:45 -07:00
|
|
|
assert ("mapped_invars" in params and
|
|
|
|
len(params["mapped_invars"]) == len(params["call_jaxpr"].invars))
|
|
|
|
assert ("donated_invars" in params and
|
|
|
|
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
|
2020-01-07 13:11:32 -08:00
|
|
|
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
|
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.
Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
...: z = jax.numpy.cos(x)
...: z = z * jax.numpy.tanh(y)
...: return z + 2
...:
In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
in (f, j) }
In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15
ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
%constant.3 = pred[] constant(false)
%parameter.1 = f32[] parameter(0)
%cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%parameter.2 = f32[] parameter(1)
%tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 19:35:36 -04:00
|
|
|
params, source_info)
|
2020-02-05 15:38:25 +01:00
|
|
|
|
2019-11-19 12:26:30 -08:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  """Materializes a `JaxprEqnRecipe` as a `core.JaxprEqn`.

  Args:
    getvar: maps a tracer to the atom (var or literal) that names it in the
      jaxpr being built.
    recipe: the equation recipe; its first field is not used here.

  Returns:
    A new equation. Any output whose tracer reference now yields `None` (the
    tracer no longer exists) is bound to `core.dropvar`.
  """
  (_, in_tracers, out_tracer_refs,
   primitive, params, source_info) = recipe
  # Dereference every output up front; holding strong references keeps the
  # live outputs alive while `getvar` runs below.
  live_outs = [out_ref() for out_ref in out_tracer_refs]
  in_atoms = [getvar(tracer) for tracer in in_tracers]
  out_atoms = []
  for tracer in live_outs:
    if tracer is None:
      out_atoms.append(core.dropvar)
    else:
      out_atoms.append(cast(core.Var, getvar(tracer)))
  return new_jaxpr_eqn(in_atoms, out_atoms, primitive, params, source_info)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-01 21:45:36 -04:00
|
|
|
def tracers_to_jaxpr(
  in_tracers: List[JaxprTracer],
  out_tracers: List[JaxprTracer]
  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
  """Constructs a Jaxpr given tracers for inputs and outputs.

  Args:
    in_tracers: the tracers that were created for the function inputs.
    out_tracers: the tracers that were output by the function.

  Returns:
    A triple of a `Jaxpr`, the constant values corresponding to the
    `constvars` of the returned Jaxpr, and the environment values. The vars
    for the environment values have been prepended to the Jaxpr's `invars`.

  Raises:
    The error built by `core.escaped_tracer_error` if a lambda-bound tracer
    reachable from `out_tracers` is not one of `in_tracers`.
    `TypeError` if a tracer carries an unrecognized recipe.
  """
  newvar = core.gensym()
  t_to_var: Dict[int, core.Atom] = {}
  def getvar(t: JaxprTracer) -> core.Atom:
    var = t_to_var.get(id(t))
    if var is None:
      # Known tracers that still appear in the jaxpr are typed as units.
      aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  sorted_tracers = toposort(out_tracers)
  invars = map(getvar, in_tracers)
  # Identity set of the inputs (they stay alive via `in_tracers`, so ids are
  # stable). This makes the LambdaBinding check O(1) instead of a linear scan.
  in_tracer_ids = {id(t) for t in in_tracers}
  eqns: List[core.JaxprEqn] = []
  env: Dict[core.Var, Any] = {}
  consts: Dict[core.Var, Any] = {}
  const_to_var: Dict[int, core.Var] = {}
  def getconstvar(c):
    # The same constant object (by identity) always maps to the same constvar.
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      # A multi-output equation is shared by several tracers; emit it once.
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      if id(t) not in in_tracer_ids:
        raise core.escaped_tracer_error(
            "Tracer not among input tracers {}".format(t))
    elif isinstance(recipe, FreeVar):
      env[cast(core.Var, getvar(t))] = recipe.val
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
|
|
|
|
|
2020-02-05 13:55:59 +01:00
|
|
|
@cache()
def convert_constvars_jaxpr(jaxpr: Jaxpr):
  """Moves the constvars to the start of invars.

  The returned Jaxpr has no constvars; its invars are the original constvars
  followed by the original invars. Both the input and the result are
  checked unless `core.skip_checks` is set.

  NOTE(review): results are memoized by `@cache()`, so a caller that mutates
  the returned Jaxpr affects later calls made with the same key.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  converted = Jaxpr(constvars=(),
                    invars=jaxpr.constvars + jaxpr.invars,
                    outvars=jaxpr.outvars,
                    eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(converted)
  return converted
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
  """Returns the (known-side, unknown-side) avals for one value.

  The side that does not carry the value gets `abstract_unit`.
  """
  if unknown:
    return abstract_unit, aval
  return aval, abstract_unit
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       ) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
  """Specializes a Jaxpr given an indication of which inputs are known.

  Args:
    jaxpr: the jaxpr to split.
    unknowns: one flag per input of `jaxpr`; True marks an unknown input.
    instantiate: forwarded to `trace_to_jaxpr` when tracing the unknown part
      (see the note at the end).

  Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).

  `out_unknowns` specifies which outputs are unknown (depend on some unknown
  inputs).

  `jaxpr_known` takes the same inputs as `jaxpr` (the unknown ones are passed
  as units and ignored) and performs *all* the computation in `jaxpr` that
  depends only on the known inputs. Its outputs correspond to those of
  `jaxpr`, with the unknown ones replaced by `*`, followed by the known
  residuals: the intermediate values of `jaxpr` that depend only on known
  inputs and are needed to compute the unknown outputs.

  `jaxpr_unknown` takes the same inputs as `jaxpr` (the known ones are passed
  as units), followed by the residuals computed by `jaxpr_known`. It returns
  the same outputs as `jaxpr` with the known outputs replaced by `*`.

  Roughly, writing `ki` and `ui` for the known and unknown inputs
  respectively, `jaxpr(ki, ui)` is decomposed as:

    jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(ki, *)
                    let _, uout = jaxpr_unknown(ki, ui, kresidual)
                    in (kout, uout)

  For example, if `jaxpr` is

    lambda ki, ui: let ka = ki + 2
                   in (ki + 3, ui + ka)

  then

    `jaxpr_known`   = lambda ki, ui: let ka = ki + 2
                                     in (ki + 3, *, ka)
    `jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)

  Note that if `instantiate` is True for a given output, then `jaxpr_known`
  always returns a unit in its place. So when `instantiate` is True, the
  expectation is that `jaxpr_known` is not run for any of its outputs. It is
  run only to generate the residuals that allow the full outputs to be
  obtained once `jaxpr_unknown` is run. Outputs known ahead of time simply
  get passed as residual constants and returned immediately.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # `fun` is traced below; while running, it traces the unknown part of `f`
  # and stashes that jaxpr (plus output pvals and residual count) in `cell`.
  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
             for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
  # known inputs.
  in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  # NOTE(review): the next lines mutate the Jaxpr returned by the memoized
  # `convert_constvars_jaxpr`; this relies on `jaxpr_2` having been freshly
  # traced above, so no other caller shares that cached object.
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
  out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
  # out_avals_1 and in_avals_2 need the residuals added
  res_avals = out_avals[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res
  out_avals_1 = [*out_avals_1, *res_avals]
  in_avals_2 = [*in_avals_2, *res_avals]

  typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
  typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
  return typed_jaxpr_1, typed_jaxpr_2, uk_out
|
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
|
2020-06-23 09:39:45 -07:00
|
|
|
# Call primitive used for rematerialization. Its impl is the generic
# `core.call_impl`; its partial-evaluation rule is registered separately in
# `call_partial_eval_rules`.
remat_call_p = core.CallPrimitive('remat_call')
remat_call_p.def_impl(core.call_impl)
remat_call = remat_call_p.bind
|
|
|
|
|
2020-08-11 11:45:58 +02:00
|
|
|
def _remat_partial_eval(trace, _, f, tracers, params):
  """Partial-evaluation rule for `remat_call_p`.

  The rule proceeds in four steps:
  1. Trace the *whole* callee into a jaxpr, with every input instantiated.
  2. Use `partial_eval_jaxpr` to decide which outputs are known.
  3. Evaluate the known outputs now and return them as constant tracers.
  4. Produce the unknown outputs from one staged-out `remat_call_p` equation.
     Its `call_jaxpr` is the full traced jaxpr, DCE'd to the unknown outputs,
     so the known computation they need is repeated inside it.

  Args:
    trace: the partial-eval trace; it must provide `instantiate_const`,
      `instantiate_const_abstracted`, `partial_eval`, `new_instantiated_const`
      and `new_const`.
    _: unused; presumably the call primitive itself (TODO confirm against
      the `call_partial_eval_rules` calling convention).
    f: the called function, passed through to `trace.partial_eval`.
    tracers: the call's input tracers.
    params: the call's parameters; must contain `'concrete'`.

  Returns:
    The output tracers, in the callee's output order.
  """
  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  instantiate = (trace.instantiate_const if params['concrete']
                 else trace.instantiate_const_abstracted)
  instantiated_tracers = map(instantiate, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvals = [t.pval for t in instantiated_tracers]
  bind_remat = partial(remat_call_p.bind, **params)
  if config.omnistaging_enabled:
    jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
        f, in_pvals, bind_remat)
  else:
    with core.initial_style_staging():  # type: ignore
      jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
          f, in_pvals, bind_remat)

  # Convert consts to inputs, since they may contain Tracer instances.
  jaxpr = convert_constvars_jaxpr(jaxpr)
  const_tracers = map(trace.new_instantiated_const, consts)

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  in_avals = [raise_to_shaped(t.pval.get_aval())
              for t in it.chain(const_tracers, env_tracers)]
  in_avals += [raise_to_shaped(pval.get_aval()) for pval in in_pvals]

  def out_aval(var, pval):
    if var is unitvar:
      return raise_to_shaped(abstract_unit)
    if type(var) is Literal:
      return raise_to_shaped(get_aval(var.val))
    return raise_to_shaped(pval.get_aval())
  out_avals = [out_aval(var, pval)
               for var, pval in zip(jaxpr.outvars, eval_out_pvals)]
  typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals)

  in_unknowns = [False] * len(consts)
  in_unknowns += [not t.is_known() for t in it.chain(env_tracers, tracers)]
  # The non-omnistaging path also passes `trace_type`. NOTE(review): the
  # `partial_eval_jaxpr` visible in this module takes no `trace_type`;
  # presumably another definition is in effect when omnistaging is disabled.
  extra_kwargs = ({} if config.omnistaging_enabled
                  else dict(trace_type=trace.main.trace_type))
  jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
      typed_jaxpr, in_unknowns, instantiate=False, **extra_kwargs)  # type: ignore
  known_pvals, unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_known,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of inverse-mode ad in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  to_compute = [type(pval[0]) is not ConcreteArray for pval in known_pvals]
  num_res = len(jaxpr_known.out_avals) - len(jaxpr_unknown.out_avals)
  keep_known = [not uk for uk in out_unknowns] + [False] * num_res
  jaxpr_known_nores = _dce_jaxpr(jaxpr_known, keep_known, drop_outputs=True)
  jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
  _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
  reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
  known_pvals = map(_reconstruct_pval, known_pvals, reconstructed_consts)

  # Known outputs should keep propagating as constants
  assert all(pv.is_known() for pv in known_pvals)
  known_out_tracers = [trace.new_const(pval.get_known()) for pval in known_pvals]
  # Unknown outputs get wrapped in tracers with the appropriate recipe
  unknown_out_tracers = [JaxprTracer(trace, pval, None) for pval in unknown_pvals]

  # dce jaxpr outputs
  new_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
  new_params = dict(params, call_jaxpr=new_jaxpr)

  # set up eqn for unknown outputs
  in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
  eqn = new_eqn_recipe(in_tracers, unknown_out_tracers, remat_call_p, new_params,
                       source_info_util.current())
  for tracer in unknown_out_tracers:
    tracer.recipe = eqn
  return _zip_knowns(known_out_tracers, unknown_out_tracers, out_unknowns)
|
|
|
|
# Register `_remat_partial_eval` as the partial-evaluation rule for remat calls.
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
2020-06-15 18:42:53 -07:00
|
|
|
|
|
|
|
def _partition_knowns(pvals, unknowns: Sequence[bool]):
|
|
|
|
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
|
|
|
|
[e for e, unknown in zip(pvals, unknowns) if unknown])
|
|
|
|
|
|
|
|
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
|
|
|
|
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
|
|
|
|
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
|
2020-06-12 15:03:26 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _dce_jaxpr(typed_jaxpr: TypedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> TypedJaxpr:
  """Dead-code-eliminates a `TypedJaxpr` down to the selected outputs.

  Args:
    typed_jaxpr: the jaxpr to prune.
    outputs: one flag per output of `typed_jaxpr`; truthy means keep it.
    drop_outputs: if True, unselected outputs are removed from the result;
      otherwise they stay in place with a `core.abstract_unit` aval.

  Returns:
    A new `TypedJaxpr` with the same literals and input avals.
  """
  paired = zip(typed_jaxpr.out_avals, outputs)
  if drop_outputs:
    out_avals = [aval for aval, keep in paired if keep]
  else:
    out_avals = [aval if keep else core.abstract_unit for aval, keep in paired]
  pruned = _dce_untyped_jaxpr(typed_jaxpr.jaxpr, tuple(outputs), drop_outputs)
  return core.TypedJaxpr(pruned, typed_jaxpr.literals, typed_jaxpr.in_avals,
                         out_avals)
|
|
|
|
|
|
|
|
@cache()
def _dce_untyped_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
  """Untyped worker for `_dce_jaxpr`; memoized via `cache()`.

  Selects the outputs as `_dce_jaxpr` describes (dropping them, or replacing
  them with `unitvar`). It then walks the equations backwards, keeping each
  equation that defines at least one still-needed variable. Constvars and
  invars are passed through unchanged.
  """
  # This dead-code elimination is pretty rudimentary, and in particular doesn't
  # nontrivially DCE through scan, call, or other higher-order primitives.
  # TODO(mattjj): better DCE
  def is_var(atom):
    return type(atom) is not Literal

  if drop_outputs:
    new_outvars = [var for var, keep in zip(jaxpr.outvars, outputs) if keep]
  else:
    new_outvars = [var if keep else unitvar
                   for var, keep in zip(jaxpr.outvars, outputs)]

  live = {atom for atom in new_outvars if is_var(atom)}
  kept_eqns = []
  for eqn in reversed(jaxpr.eqns):
    if live.isdisjoint(eqn.outvars):
      continue  # defines nothing that is still needed
    kept_eqns.append(eqn)
    live.update(atom for atom in eqn.invars if is_var(atom))
  kept_eqns.reverse()
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, kept_eqns)
|
|
|
|
|
|
|
|
@cache()
def _drop_invars(jaxpr: Jaxpr, drop: Tuple[bool, ...]):
  """Returns a Jaxpr like `jaxpr` but without the invars flagged in `drop`.

  Constvars, outvars and equations are passed through unchanged; results are
  memoized via `cache()`.
  """
  kept_invars = [var for var, dropped in zip(jaxpr.invars, drop) if not dropped]
  return core.Jaxpr(jaxpr.constvars, kept_invars, jaxpr.outvars, jaxpr.eqns)
|
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
|
2020-06-12 15:03:26 +02:00
|
|
|
def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
  """Returns a known `PartialVal` for an output.

  The result is chosen in this order:
  - `pval1` itself, if it is already known;
  - the value held by its aval, if that aval is a `ConcreteArray`;
  - otherwise, the computed value `const2`.
  """
  pv1, _ = pval1
  if pval1.is_known():
    return pval1
  if type(pv1) is ConcreteArray:
    return PartialVal.known(pv1.val)  # pytype: disable=attribute-error
  return PartialVal.known(const2)
|
2019-11-27 14:28:13 -08:00
|
|
|
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def move_binders_to_front(typed_jaxpr: TypedJaxpr, to_move: Sequence[bool]) -> TypedJaxpr:
  """Reorder the `invars` to move to front the ones for which `to_move` is True.

  The input avals are permuted the same way. Relative order is kept within
  the moved group and within the remaining group. `typed_jaxpr` must have no
  constvars, and `to_move` must have one flag per input.
  """
  assert not typed_jaxpr.jaxpr.constvars
  assert len(typed_jaxpr.in_avals) == len(to_move)
  inner = typed_jaxpr.jaxpr
  reordered = core.Jaxpr((), _move_to_front(inner.invars, to_move),
                         inner.outvars, inner.eqns)
  return core.TypedJaxpr(reordered, typed_jaxpr.literals,
                         _move_to_front(typed_jaxpr.in_avals, to_move),
                         typed_jaxpr.out_avals)
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
|
2019-11-27 14:28:13 -08:00
|
|
|
return ([elt for elt, move in zip(lst, to_move) if move] +
|
|
|
|
[elt for elt, move in zip(lst, to_move) if not move])
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
|
|
|
|
class DynamicJaxprTracer(core.Tracer):
  """Tracer whose trace records equations into a `JaxprStackFrame`.

  Holds only an abstract value plus, optionally, the source location where it
  was created (`line_info`). That location is reported if the tracer is used
  after its jaxpr has been finished (see `_assert_live`).
  """
  __slots__ = ['aval', 'line_info']

  def __init__(self, trace, aval, line_info=None):
    self._trace = trace
    self.aval = aval
    self.line_info = line_info

  def full_lower(self):
    return self

  def _contents(self):
    return ()

  def _origin_msg(self):
    """Describes the operations that turned this value into a tracer."""
    progenitor_eqns = self._trace.frame.find_progenitors(self)
    msgs = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
            f" from line {source_info_util.summarize(eqn.source_info)}"
            for eqn in progenitor_eqns]
    if msgs:
      origin = (f"While tracing the function {self._trace.main.source_info}, "
                "this value became a tracer due to JAX operations on these lines:"
                "\n\n" + "\n\n".join(msgs))
    else:
      # Fixed the misspelling "occured" in this user-facing message.
      origin = ("The error occurred while tracing the function "
                f"{self._trace.main.source_info}.")
    return origin

  def _assert_live(self) -> None:
    """Raises an escaped-tracer error if the main trace's jaxpr stack is empty."""
    if not self._trace.main.jaxpr_stack:  # type: ignore
      msg = f"tracer created on line {source_info_util.summarize(self.line_info)}"
      raise core.escaped_tracer_error(msg)
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
class JaxprStackFrame:
  """Mutable state for a single jaxpr being built by `DynamicJaxprTrace`.

  Attributes:
    newvar: fresh-variable generator from `core.gensym()`.
    tracer_to_var: maps `id(tracer)` to the tracer's jaxpr `Var`.
    constid_to_var: maps `id(const)` to its constvar, so repeated uses of the
      same constant object share one constvar.
    constvar_to_val: maps each constvar to its constant value.
    tracers: every tracer registered in this frame. Holding references keeps
      the `id`s used as keys in `tracer_to_var` from being reused.
    eqns: equations recorded so far, in emission order.
  """
  __slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
               'tracers', 'eqns']

  def __init__(self):
    self.newvar = core.gensym()
    self.tracer_to_var = {}
    self.constid_to_var = {}
    self.constvar_to_val = {}
    # Creates a reference cycle: frame -> tracer -> trace -> main -> frame.
    self.tracers = []
    # NOTE(review): an earlier comment said this list is cleared when the
    # frame is popped from main; nothing visible here does so -- confirm.
    self.eqns = []

  def to_jaxpr(self, in_tracers, out_tracers):
    """Assemble the recorded equations into a jaxpr.

    Returns:
      `(jaxpr, out_avals, constvals)`. Here `jaxpr` has had scalar constants
      inlined as literals by `_inline_literals`, `out_avals` are the avals of
      `out_tracers`, and `constvals` lines up with `jaxpr.constvars`.
    """
    invars = [self.tracer_to_var[id(t)] for t in in_tracers]
    outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
    constvars, constvals = unzip2(self.constvar_to_val.items())
    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    # Type checking is currently disabled:
    # core.skip_checks or core.check_jaxpr(jaxpr)
    out_avals = [t.aval for t in out_tracers]
    return jaxpr, out_avals, constvals

  def find_progenitors(self, tracer):
    """Return the equations reading the root inputs that `tracer` depends on.

    Walks the recorded equations backwards from `tracer`'s variable. Each
    variable produced by an equation is replaced by that equation's inputs.
    The variables left at the end are not produced by any equation in this
    frame. Returns, in program order, every equation that reads one of them.
    Returns `[]` if `tracer` has no variable in this frame.
    """
    var = self.tracer_to_var.get(id(tracer))
    if var is None:
      return []
    active_vars = {var}
    for eqn in reversed(self.eqns):
      produced = active_vars.intersection(eqn.outvars)
      if produced:
        active_vars.difference_update(produced)
        active_vars.update(eqn.invars)
    return [eqn for eqn in self.eqns if active_vars.intersection(eqn.invars)]
|
|
|
|
|
|
|
|
def _inline_literals(jaxpr, constvals):
  """Inline scalar constants of `jaxpr` as `Literal`s and rename all variables.

  A constvar is inlined when its value's type is in `core.literalable_types`
  and `np.shape(value)` is empty. Inlined constvars are dropped from the
  constvars and replaced by a `Literal` wherever they are read (equation
  inputs and jaxpr outputs). Every other variable is renamed to a fresh one
  from a new `core.gensym()`. Equation outputs that are never read become
  `dropvar`.

  Returns:
    `(new_jaxpr, new_constvals)`, where `new_constvals` lines up with
    `new_jaxpr.constvars`.
  """
  consts = dict(zip(jaxpr.constvars, constvals))
  newvar = core.gensym()

  class _FreshVars(dict):
    # Allocates a fresh Var with the same aval the first time an old Var is
    # looked up, so each old Var is renamed consistently everywhere.
    def __missing__(self, v):
      new_v = self[v] = newvar(v.aval)
      return new_v
  fresh = _FreshVars()

  def lit(v: core.Var) -> Optional[Any]:
    # A Literal for `v` if it is an inlinable scalar constant, else None.
    val = consts.get(v)
    if type(val) in core.literalable_types and not np.shape(val):
      return Literal(val)
    else:
      return None

  used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
  # Constvars that are not inlined, with their values (checked once each).
  kept_consts = [(v, c) for v, c in zip(jaxpr.constvars, constvals)
                 if not lit(v)]
  # Fresh vars are allocated in order: constvars, invars, then equations.
  new_constvars = [fresh[v] for v, _ in kept_consts]
  new_constvals = [c for _, c in kept_consts]
  new_invars = [fresh[v] for v in jaxpr.invars]
  new_eqns = [new_jaxpr_eqn([lit(v) or fresh[v] for v in eqn.invars],
                            [fresh[v] if v in used else dropvar
                             for v in eqn.outvars],
                            eqn.primitive, eqn.params, eqn.source_info)
              for eqn in jaxpr.eqns]
  new_outvars = [lit(v) or fresh[v] for v in jaxpr.outvars]
  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
  return new_jaxpr, new_constvals
|
|
|
|
|
|
|
|
class DynamicJaxprTrace(core.Trace):
  """Trace that records operations into the innermost `JaxprStackFrame`.

  Primitive and call applications append equations to `self.frame.eqns` and
  return new `DynamicJaxprTracer`s for their outputs.
  """
  __slots__ = []  # type: ignore

  @property
  def frame(self):
    # The innermost frame on the main trace's jaxpr stack.
    return self.main.jaxpr_stack[-1]  # pytype: disable=attribute-error

  def new_arg(self, aval):
    """Create a tracer for a new jaxpr input with abstract value `aval`."""
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(aval)
    return tracer

  def new_const(self, val):
    """Create a tracer for constant `val`, bound to a (shared) constvar."""
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val))
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
    # Storing `val` here also keeps it alive, so its id stays unique.
    self.frame.constvar_to_val[var] = val
    return tracer

  pure = lift = sublift = new_const

  def getvar(self, tracer):
    """Return `tracer`'s jaxpr var, registering a fresh one if it has none."""
    var = self.frame.tracer_to_var.get(id(tracer))
    if var is None:
      self.frame.tracers.append(tracer)
      var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
    return var

  def getconstvar(self, c):
    """Return the constvar for constant object `c`, allocating on first use."""
    var = self.frame.constid_to_var.get(id(c))
    if var is None:
      var = self.frame.constid_to_var[id(c)] = self.frame.newvar(get_aval(c))
    return var

  def instantiate_const(self, val):
    """Return `val` if it is a tracer of this exact trace, else a new const."""
    if (isinstance(val, Tracer) and val._trace.main is self.main
        and val._trace.sublevel == self.sublevel):
      return val
    else:
      return self.new_const(val)

  def process_primitive(self, primitive, tracers, params):
    """Stage `primitive` applied to `tracers` as a single equation.

    Output avals come from `primitive.abstract_eval`. Returns a list of
    tracers if `primitive.multiple_results`, otherwise a single tracer.
    """
    avals = [t.aval for t in tracers]
    out_avals = primitive.abstract_eval(*avals, **params)
    out_avals = [out_avals] if not primitive.multiple_results else out_avals
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.getvar, out_tracers)
    eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers if primitive.multiple_results else out_tracers.pop()

  def process_call(self, call_primitive, f, tracers, params):
    """Trace `f` to a sub-jaxpr and stage one `call_primitive` equation.

    If the sub-jaxpr has no equations, it is evaluated directly on `tracers`
    instead. Otherwise the sub-jaxpr's constants become leading inputs of the
    staged equation (see `convert_constvars_jaxpr`). Any updater registered
    in `call_param_updaters` is applied to the params.
    """
    in_avals = [t.aval for t in tracers]
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
    if not jaxpr.eqns:
      return core.eval_jaxpr(jaxpr, consts, *tracers)
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.getvar, out_tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    update_params = call_param_updaters.get(call_primitive)
    if update_params:
      # One True flag per original input: none of `tracers` is dropped here.
      new_params = update_params(new_params, [True] * len(tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_call(self, call_primitive, out_tracers, params):
    assert False  # unreachable

  def process_map(self, map_primitive, f, tracers, params):
    """Trace the per-shard body `f` and stage one `map_primitive` equation.

    Inputs flagged in `params['mapped_invars']` have their mapped axis of size
    `params['axis_size']` removed before tracing. Outputs get that axis added
    back. Constants of the body become leading, unmapped inputs.
    """
    in_avals = [t.aval for t in tracers]
    axis_name, axis_size = params['axis_name'], params['axis_size']
    mapped_invars = params['mapped_invars']
    reduced_in_avals = [core.mapped_aval(axis_size, a) if m else a
                        for m, a in zip(mapped_invars, in_avals)]
    with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
      jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
          f, self.main, reduced_in_avals)
    out_avals = [core.unmapped_aval(axis_size, a) for a in reduced_out_avals]
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.getvar, out_tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    # Constants are prepended as inputs and are never mapped.
    new_mapped_invars = (False,) * len(consts) + mapped_invars
    new_params = dict(params, mapped_invars=new_mapped_invars,
                      call_jaxpr=convert_constvars_jaxpr(jaxpr))
    update_params = call_param_updaters.get(map_primitive)
    if update_params:
      # One True flag per original input: none of `tracers` is dropped here.
      new_params = update_params(new_params, [True] * len(tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_map(self, map_primitive, out_tracers, params):
    assert False  # unreachable
|
|
|
|
|
|
|
|
|
|
|
|
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
  """Stage `fun` on inputs with avals `in_avals` under a new dynamic main trace.

  Returns `(jaxpr, out_avals, consts)` from `trace_to_subjaxpr_dynamic`.
  Requires omnistaging to be enabled.
  """
  assert config.omnistaging_enabled
  with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    main.source_info = fun_sourceinfo(fun.f)  # type: ignore
    traced = trace_to_subjaxpr_dynamic(fun, main, in_avals)
    # Drop our reference to the main trace before the context exits.
    del main
  jaxpr, out_avals, consts = traced
  return jaxpr, out_avals, consts
|
|
|
|
|
2020-08-30 01:16:51 -07:00
|
|
|
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
                              in_avals: Sequence[AbstractValue]):
  """Trace `fun` into a fresh `JaxprStackFrame` pushed onto `main`.

  One input tracer is created per aval in `in_avals`. `fun` is called on
  them, and its outputs are raised into the dynamic trace.

  Returns:
    `(jaxpr, out_avals, consts)` as produced by `JaxprStackFrame.to_jaxpr`.
  """
  frame = JaxprStackFrame()
  with extend_jaxpr_stack(main, frame):
    trace = DynamicJaxprTrace(main, core.cur_sublevel())
    in_tracers = [trace.new_arg(aval) for aval in in_avals]
    outputs = fun.call_wrapped(*in_tracers)
    out_tracers = [trace.full_raise(x) for x in outputs]
    jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
  return jaxpr, out_avals, consts
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
  """Make `frame` the innermost jaxpr frame of `main` for the context's body.

  `main.jaxpr_stack` is a tuple that is rebound, never mutated.
  - Entering pushes `frame` onto the stack.
  - Exiting pops one frame, whether the body returns normally or raises.
  - Before popping, it asserts that `frame` is still on top, i.e. that
    pushes and pops were properly nested.
  """
  stack_before = main.jaxpr_stack
  main.jaxpr_stack = stack_before + (frame,)
  try:
    yield
  finally:
    assert main.jaxpr_stack[-1] is frame
    main.jaxpr_stack = main.jaxpr_stack[:-1]
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
  """Stage `fun` on inputs with avals `in_avals` under a new base main trace.

  This is the same as `trace_to_jaxpr_dynamic`, except that the main trace
  comes from `core.new_base_main`. Returns `(jaxpr, out_avals, consts)`.
  Requires omnistaging to be enabled.
  """
  assert config.omnistaging_enabled
  with core.new_base_main(DynamicJaxprTrace) as main:  # type: ignore
    main.jaxpr_stack = ()  # type: ignore
    main.source_info = fun_sourceinfo(fun.f)  # type: ignore
    traced = trace_to_subjaxpr_dynamic(fun, main, in_avals)
    # Drop our reference to the main trace before the context exits.
    del main
  jaxpr, out_avals, consts = traced
  return jaxpr, out_avals, consts
|
|
|
|
|
2020-08-10 07:20:09 -07:00
|
|
|
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
  """Partially evaluate `fun` via `trace_to_jaxpr` inside a dynamic EvalTrace.

  This function provides a partial evaluation behavior used by Flax. We can't
  use trace_to_jaxpr directly because of an interaction with the current
  custom_derivatives.py, which we work around by adding the EvalTrace.
  Requires omnistaging to be enabled.
  """
  # TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
  assert config.omnistaging_enabled
  # The main trace is deliberately not bound to a name. Like the sibling
  # functions, which `del main` before exiting, we keep no reference to it
  # while the context manager exits.
  with core.new_main(core.EvalTrace, dynamic=True):  # type: ignore
    return trace_to_jaxpr(fun, in_pvals)
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def fun_sourceinfo(fun):
  """Return a short "<name> at <file>:<line>" description of `fun`.

  `functools.partial` wrappers are peeled off repeatedly, in case partials
  are nested, so the underlying function is described. Callables without
  `__code__`/`__name__` (e.g. builtins or callable instances) yield
  "<unknown>".
  """
  while isinstance(fun, functools.partial):
    fun = fun.func
  try:
    filename = fun.__code__.co_filename
    lineno = fun.__code__.co_firstlineno
    return f"{fun.__name__} at {filename}:{lineno}"
  except AttributeError:
    return "<unknown>"
|
|
|
|
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
@config.register_omnistaging_disabler
|
|
|
|
@no_type_check
|
|
|
|
def omnistaging_disabler() -> None:
|
|
|
|
global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
2020-09-15 08:06:46 -07:00
|
|
|
instantiate: Union[bool, Sequence[bool]] = False,
|
|
|
|
stage_out=False, bottom=False,
|
|
|
|
trace_type: Optional[Type[Trace]] = None,
|
|
|
|
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
|
|
|
"""Traces a function into a Jaxpr, given PartialVals for inputs.
|
|
|
|
|
|
|
|
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
|
|
|
|
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
|
|
|
|
for the outputs. The intermediate values that depend only on known inputs and
|
|
|
|
are needed to compute the output of `jaxpr` are in `consts` and are passed in
|
|
|
|
as the constvars of the `jaxpr`. The handling of the known outputs depends on
|
|
|
|
`instantiate`.
|
|
|
|
|
|
|
|
For example, given `fun` defined as follows::
|
|
|
|
|
|
|
|
def fun(ki, ui): # ki will be a known input in this example
|
|
|
|
ka = ki + 2
|
|
|
|
kb = ka + 3
|
|
|
|
return (kb, ui + ka)
|
|
|
|
|
|
|
|
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
|
|
|
|
computation that depends on unknown inputs is `ui + ka` and will be the only
|
|
|
|
computation in the body of the `jaxpr`. This computation depends on the known
|
|
|
|
intermediate value `ka`, which will be computed statically. Currently, such
|
|
|
|
constants are either embedded in the Jaxpr if they are scalars, or passed as a
|
|
|
|
constvar to `jaxpr`, and then the value of the actual constant will be in
|
|
|
|
`consts`:
|
|
|
|
|
|
|
|
When `instantiate=False` we get::
|
|
|
|
|
|
|
|
jaxpr =
|
|
|
|
{ lambda ka ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
in (*, c) } # known outputs are `*`
|
|
|
|
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3] # the constant for `ka`
|
|
|
|
|
|
|
|
When `instantiate=True` we get::
|
|
|
|
|
|
|
|
jaxpr =
|
|
|
|
{ lambda ka kb ; ki ui.
|
|
|
|
let c = add ui ka
|
|
|
|
in (kb, c) } # known output are explicit
|
|
|
|
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
|
|
|
|
consts = [3, 6] # values for `ka` and `kb` constvars
|
|
|
|
"""
|
|
|
|
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
|
|
|
|
with core.new_main(trace_type, bottom=bottom) as main:
|
2020-08-30 12:38:14 +03:00
|
|
|
fun = trace_to_subjaxpr(fun, main, instantiate)
|
2020-07-30 12:59:36 -07:00
|
|
|
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
|
|
|
assert not env
|
2020-08-30 12:38:14 +03:00
|
|
|
del main
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
return jaxpr, out_pvals, consts
|
|
|
|
|
|
|
|
def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
|
|
|
|
instantiate: Union[bool, Sequence[bool]],
|
2020-09-15 08:06:46 -07:00
|
|
|
trace_type: Optional[Type[core.Trace]]
|
2020-07-30 12:59:36 -07:00
|
|
|
) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
|
|
|
|
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
|
|
|
|
|
|
|
cell = []
|
|
|
|
def fun(*vals):
|
|
|
|
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
|
|
|
|
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
|
2020-09-15 08:06:46 -07:00
|
|
|
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
|
|
|
|
trace_type=trace_type)
|
2020-07-30 12:59:36 -07:00
|
|
|
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
|
|
|
|
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
|
|
|
|
return out_consts_2 + consts_2
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
# The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
|
|
|
|
# the avals, and it will never actually reach any primitives, because the `fun` above will
|
|
|
|
# execute the jaxpr with the right avals (it reconstructs `pvals` inside).
|
|
|
|
pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
|
|
|
|
for aval, uk in zip(jaxpr.in_avals, unknowns)]
|
|
|
|
jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
|
2020-07-30 12:59:36 -07:00
|
|
|
(out_pvs_2, jaxpr_2, num_res), = cell
|
|
|
|
assert len(jaxpr_2.constvars) == num_res
|
|
|
|
|
|
|
|
# jaxpr :: a -> b
|
|
|
|
# jaxpr_1 :: a1 -> [b1, res]
|
|
|
|
# jaxpr_2 :: res | a2 -> b2
|
|
|
|
# jaxpr_2 :: [a2, res] -> b2
|
|
|
|
jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
|
|
|
|
jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
|
|
|
|
for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
|
|
|
|
if not unknown:
|
|
|
|
var.aval = abstract_unit
|
|
|
|
|
|
|
|
uk_out = [pv is not None for pv in out_pvs_2]
|
|
|
|
|
|
|
|
in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
|
|
|
|
out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
|
|
|
|
# out_avals_1 and in_avals_2 need the residuals added
|
2020-09-15 08:06:46 -07:00
|
|
|
out_pvs, _ = unzip2(out_pvals)
|
|
|
|
res_avals = out_pvs[len(jaxpr.out_avals):]
|
2020-07-30 12:59:36 -07:00
|
|
|
assert len(res_avals) == num_res
|
2020-09-15 08:06:46 -07:00
|
|
|
out_avals_1 = out_avals_1 + res_avals
|
|
|
|
in_avals_2 = in_avals_2 + res_avals
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
|
|
|
|
typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
|
|
|
|
return typed_jaxpr_1, typed_jaxpr_2, uk_out
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
|
|
|
# See comment at top of `JaxprTrace`. This method should be reachable
|
|
|
|
# only when we stage out, and in that case we drop the custom differentiation
|
|
|
|
# rules, because we do not need them.
|
|
|
|
if not config.omnistaging_enabled:
|
|
|
|
assert self.main.trace_type is StagingJaxprTrace
|
|
|
|
return fun.call_wrapped(*tracers)
|
|
|
|
JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
|
|
|
|
|
|
|
|
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
|
|
|
# See comment in the above process_custom_jvp_call method.
|
|
|
|
if not config.omnistaging_enabled:
|
|
|
|
assert self.main.trace_type is StagingJaxprTrace
|
|
|
|
return fun.call_wrapped(*tracers)
|
|
|
|
JaxprTrace.process_custom_vjp_call = process_custom_vjp_call
|
|
|
|
|
|
|
|
staged_out_calls = set()
|
|
|
|
|
|
|
|
class StagingJaxprTrace(JaxprTrace): pass
|