2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
import itertools as it
|
2020-03-09 20:42:08 +01:00
|
|
|
from collections import namedtuple
|
2019-11-27 19:15:53 -08:00
|
|
|
import contextlib
|
|
|
|
import threading
|
2020-04-17 20:08:24 +03:00
|
|
|
from typing import Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union
|
2019-11-19 12:26:30 -08:00
|
|
|
from weakref import ref
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-05-28 22:38:06 -07:00
|
|
|
import numpy as onp
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from .. import core
|
|
|
|
from .. import linear_util as lu
|
2019-11-22 10:53:11 -08:00
|
|
|
from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
|
2020-05-01 09:16:31 +03:00
|
|
|
from ..ad_util import zero
|
2020-01-26 23:27:56 -08:00
|
|
|
from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
|
2020-02-05 13:55:59 +01:00
|
|
|
wrap_name, cache)
|
2019-11-19 12:26:30 -08:00
|
|
|
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
|
2020-03-18 07:11:44 +01:00
|
|
|
AbstractValue, unit, unitvar, abstract_unit,
|
|
|
|
TypedJaxpr, new_jaxpr_eqn)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Deliberately shadow the builtins: safe_map/safe_zip from ..util behave like
# map/zip but assert that all argument sequences have equal length, so length
# mismatches fail loudly during tracing instead of being silently truncated.
map = safe_map
zip = safe_zip
|
2019-02-15 06:35:54 -08:00
|
|
|
def identity(x):
  """Return the argument unchanged."""
  return x
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
class PartialVal(tuple):
  """Partial value: either a known value or an unknown (abstract) value.

  Represented as a pair `(aval_opt, const)` of one of two kinds:
    * `(None, <Constant>)` indicates a known value, either a Python regular
      value, or a Tracer.
    * `(<AbstractValue>, *)` indicates an unknown value characterized by an
      abstract value.
  """
  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    pv, const = xs
    if not core.skip_checks:
      # type checks
      assert isinstance(pv, (AbstractValue, type(None))), xs
      assert isinstance(const, core.Tracer) or const is zero or core.valid_jaxtype(const), xs
      # invariant check: an unknown value carries only the unit placeholder
      if isinstance(pv, AbstractValue):
        assert const == core.unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    """Wrap a concrete value as a known PartialVal."""
    return PartialVal((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    """Wrap an abstract value as an unknown PartialVal."""
    return PartialVal((aval, core.unit))

  def is_known(self):
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Get the known value, if known, else None."""
    pv, const = self
    return const if pv is None else None

  def get_aval(self) -> AbstractValue:
    """Get the AbstractValue: directly for unknown values, or computed from the known constant."""
    const = self.get_known()
    return self[0] if const is None else get_aval(const)

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Either the stored known value, or the given 'val'."""
    const = self.get_known()
    return val if const is None else const
|
2019-05-09 07:23:39 -07:00
|
|
|
|
|
|
|
|
2020-04-17 20:08:24 +03:00
|
|
|
# We form Jaxprs using `JaxprTrace` for three distinct purposes:
#  (1) to stage program representations completely out of the JAX system
#      (e.g. for XLA using jit or pmap). In this case we are using the
#      `StagingJaxprTrace` subclass.
#  (2) to linearize a function for reverse-mode AD. In this case we are
#      using the `JaxprTrace` subclass.
#  (3) to build a representation of a function that may require further JAX
#      transformations (e.g. in "initial-style" higher-order primitives, like
#      for control flow). In this case we use the `JaxprTrace` class.
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTrace(Trace):
|
|
|
|
  def pure(self, val):
    # A non-tracer value entering this trace is recorded as a known constant.
    return self.new_const(val)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
  def lift(self, val):
    # A tracer from a lower trace level is likewise treated as a known constant.
    return self.new_const(val)
|
|
|
|
|
|
|
|
  def sublift(self, val):
    # Re-wrap a tracer from a lower sublevel, recording it as a free variable.
    return JaxprTracer(self, val.pval, FreeVar(val))
|
|
|
|
|
|
|
|
def new_const(self, val):
|
2020-01-29 16:23:27 -05:00
|
|
|
if isinstance(val, Tracer) and val._trace.level == self.level:
|
2018-11-17 18:03:33 -08:00
|
|
|
raise Exception
|
2020-03-18 07:11:44 +01:00
|
|
|
return JaxprTracer(self, PartialVal.known(val), unit)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-05-13 08:48:13 -07:00
|
|
|
  def new_instantiated_literal(self, val):
    # Unknown tracer whose recipe embeds `val` directly as a jaxpr Literal.
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), Literal(val))
|
2019-05-13 08:48:13 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
  def new_instantiated_const(self, val):
    # Unknown tracer whose recipe records `val` as a jaxpr constant (ConstVar).
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
  def new_arg(self, pval: PartialVal):
    # Tracer standing for a function argument (a lambda binder in the jaxpr).
    return JaxprTracer(self, pval, LambdaBinding())
|
|
|
|
|
|
|
|
def instantiate_const(self, tracer):
|
2020-03-18 07:11:44 +01:00
|
|
|
const = tracer.pval.get_known()
|
|
|
|
if const is None:
|
2018-11-17 18:03:33 -08:00
|
|
|
return tracer
|
2020-03-18 07:11:44 +01:00
|
|
|
else:
|
2019-09-20 15:35:43 -07:00
|
|
|
if type(const) in core.literalable_types and onp.shape(const) == ():
|
|
|
|
return self.new_instantiated_literal(const)
|
2019-05-13 08:48:13 -07:00
|
|
|
else:
|
|
|
|
return self.new_instantiated_const(const)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
make nested jit stage out full inner jit bodies
Before this change, inner jitted functions wouldn't necessarily be fully
staged out into an outer-jit trace; instead, as much as possible would
be hoisted out of the inner jit. That led to extra constants getting
materialized in #1640.
For example:
```python
@jit
def f(x, y):
z = 2 * x
return y + z
@jit
def g(x):
return f(2, x)
g(3)
```
would lead to these XLA computations being compiled and executed:
```
HloModule jit_f.7
ENTRY jit_f.7 {
parameter.2 = () parameter(1)
tuple.3 = () tuple()
parameter.1 = s32[] parameter(0)
constant.4 = s32[] constant(2)
multiply.5 = s32[] multiply(parameter.1, constant.4)
ROOT tuple.6 = ((), s32[]) tuple(tuple.3, multiply.5)
}
HloModule jit_g.14
jaxpr_subcomputation.4 {
parameter.6 = () parameter(1)
tuple.8 = () tuple()
parameter.7 = s32[] parameter(2)
parameter.5 = s32[] parameter(0)
add.9 = s32[] add(parameter.7, parameter.5)
ROOT tuple.10 = (s32[]) tuple(add.9)
}
ENTRY jit_g.14 {
constant.1 = s32[] constant(4)
tuple.3 = () tuple()
parameter.2 = s32[] parameter(0)
call.11 = (s32[]) call(constant.1, tuple.3, parameter.2), to_apply=jaxpr_subcomputation.4
get-tuple-element.12 = s32[] get-tuple-element(call.11), index=0
ROOT tuple.13 = (s32[]) tuple(get-tuple-element.12)
}
```
Notice that the `multiply` is separated out from the `add`, and in
particular the XLA computation underlying `g` only has the `add` in it.
This behavior was desirable when using partial evaluation for
reverse-mode autodiff, since in that case we want to partially evaluate
all the primal values underneath a call while staging out a jaxpr for
the tangent values. But it was undesirable for the other use of partial
evaluation, namely forming jaxprs under `jit` (and `pmap`).
The solution was just to tag jaxpr traces differently in the two cases.
2019-12-11 18:39:16 -08:00
|
|
|
def instantiate_const_abstracted(self, tracer):
|
2020-03-18 07:11:44 +01:00
|
|
|
const = tracer.pval.get_known()
|
|
|
|
if const is None:
|
make nested jit stage out full inner jit bodies
Before this change, inner jitted functions wouldn't necessarily be fully
staged out into an outer-jit trace; instead, as much as possible would
be hoisted out of the inner jit. That led to extra constants getting
materialized in #1640.
For example:
```python
@jit
def f(x, y):
z = 2 * x
return y + z
@jit
def g(x):
return f(2, x)
g(3)
```
would lead to these XLA computations being compiled and executed:
```
HloModule jit_f.7
ENTRY jit_f.7 {
parameter.2 = () parameter(1)
tuple.3 = () tuple()
parameter.1 = s32[] parameter(0)
constant.4 = s32[] constant(2)
multiply.5 = s32[] multiply(parameter.1, constant.4)
ROOT tuple.6 = ((), s32[]) tuple(tuple.3, multiply.5)
}
HloModule jit_g.14
jaxpr_subcomputation.4 {
parameter.6 = () parameter(1)
tuple.8 = () tuple()
parameter.7 = s32[] parameter(2)
parameter.5 = s32[] parameter(0)
add.9 = s32[] add(parameter.7, parameter.5)
ROOT tuple.10 = (s32[]) tuple(add.9)
}
ENTRY jit_g.14 {
constant.1 = s32[] constant(4)
tuple.3 = () tuple()
parameter.2 = s32[] parameter(0)
call.11 = (s32[]) call(constant.1, tuple.3, parameter.2), to_apply=jaxpr_subcomputation.4
get-tuple-element.12 = s32[] get-tuple-element(call.11), index=0
ROOT tuple.13 = (s32[]) tuple(get-tuple-element.12)
}
```
Notice that the `multiply` is separated out from the `add`, and in
particular the XLA computation underlying `g` only has the `add` in it.
This behavior was desirable when using partial evaluation for
reverse-mode autodiff, since in that case we want to partially evaluate
all the primal values underneath a call while staging out a jaxpr for
the tangent values. But it was undesirable for the other use of partial
evaluation, namely forming jaxprs under `jit` (and `pmap`).
The solution was just to tag jaxpr traces differently in the two cases.
2019-12-11 18:39:16 -08:00
|
|
|
return tracer
|
|
|
|
else:
|
2020-03-18 07:11:44 +01:00
|
|
|
aval = raise_to_shaped(get_aval(const), onp.isscalar(const))
|
|
|
|
return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))
|
make nested jit stage out full inner jit bodies
Before this change, inner jitted functions wouldn't necessarily be fully
staged out into an outer-jit trace; instead, as much as possible would
be hoisted out of the inner jit. That led to extra constants getting
materialized in #1640.
For example:
```python
@jit
def f(x, y):
z = 2 * x
return y + z
@jit
def g(x):
return f(2, x)
g(3)
```
would lead to these XLA computations being compiled and executed:
```
HloModule jit_f.7
ENTRY jit_f.7 {
parameter.2 = () parameter(1)
tuple.3 = () tuple()
parameter.1 = s32[] parameter(0)
constant.4 = s32[] constant(2)
multiply.5 = s32[] multiply(parameter.1, constant.4)
ROOT tuple.6 = ((), s32[]) tuple(tuple.3, multiply.5)
}
HloModule jit_g.14
jaxpr_subcomputation.4 {
parameter.6 = () parameter(1)
tuple.8 = () tuple()
parameter.7 = s32[] parameter(2)
parameter.5 = s32[] parameter(0)
add.9 = s32[] add(parameter.7, parameter.5)
ROOT tuple.10 = (s32[]) tuple(add.9)
}
ENTRY jit_g.14 {
constant.1 = s32[] constant(4)
tuple.3 = () tuple()
parameter.2 = s32[] parameter(0)
call.11 = (s32[]) call(constant.1, tuple.3, parameter.2), to_apply=jaxpr_subcomputation.4
get-tuple-element.12 = s32[] get-tuple-element(call.11), index=0
ROOT tuple.13 = (s32[]) tuple(get-tuple-element.12)
}
```
Notice that the `multiply` is separated out from the `add`, and in
particular the XLA computation underlying `g` only has the `add` in it.
This behavior was desirable when using partial evaluation for
reverse-mode autodiff, since in that case we want to partially evaluate
all the primal values underneath a call while staging out a jaxpr for
the tangent values. But it was undesirable for the other use of partial
evaluation, namely forming jaxprs under `jit` (and `pmap`).
The solution was just to tag jaxpr traces differently in the two cases.
2019-12-11 18:39:16 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def process_primitive(self, primitive, tracers, params):
|
2019-04-01 16:03:56 -04:00
|
|
|
if primitive in custom_partial_eval_rules:
|
2019-07-27 15:46:14 -07:00
|
|
|
return custom_partial_eval_rules[primitive](self, *tracers, **params)
|
2019-04-01 16:03:56 -04:00
|
|
|
else:
|
2020-01-30 15:03:00 -08:00
|
|
|
return self.default_process_primitive(primitive, tracers, params)
|
|
|
|
|
|
|
|
  def default_process_primitive(self, primitive, tracers, params):
    """By default, if all the input tracers are known, then execute the primitive
    and all the outputs are known. Otherwise, all the outputs are unknown."""
    consts = tuple(t.pval.get_known() for t in tracers)
    if all(c is not None for c in consts):
      # Every input is known: evaluate eagerly and return concrete values.
      return primitive.bind(*consts, **params)
    # At least one input is unknown: stage the primitive out as an equation
    # over fully-instantiated (unknown) tracers.
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval = primitive.abstract_eval(*avals, **params)
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in out_aval]
      # One shared equation recipe links all outputs to the same primitive.
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    else:
      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, params)
      return out_tracer
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-09 20:41:01 +01:00
|
|
|
  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    """Partially evaluate a call primitive, staging the unknown part into an eqn."""
    name = params.get('name', f.__name__)
    # When forming a jaxpr to stage out in full (jit/pmap under a
    # StagingJaxprTrace), abstract even the known inputs so the whole inner
    # call body is staged out rather than partially hoisted.
    if (self.master.trace_type is StagingJaxprTrace
        and call_primitive in staged_out_calls):
      tracers = map(self.instantiate_const_abstracted, tracers)
    params = dict(params, name=name)

    if call_primitive in call_partial_eval_rules:
      return call_partial_eval_rules[call_primitive](self, call_primitive, f, tracers, params)
    in_pvs, in_consts = unzip2([t.pval for t in tracers])
    fun, aux = partial_eval(f, self, in_pvs)
    out_flat = call_primitive.bind(fun, *in_consts, **params)
    out_pvs, jaxpr, env = aux()
    env_tracers = map(self.full_raise, env)
    # `out_flat` packs the known output consts followed by the jaxpr's consts.
    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    if not jaxpr.eqns:
      # The staged-out jaxpr is trivial (no equations): forward inputs and
      # constants straight to the outputs instead of emitting a call eqn.
      env = {core.unitvar: core.unit}
      map(env.setdefault, jaxpr.invars, (*env_tracers, *tracers))
      map(env.setdefault, jaxpr.constvars, consts)
      return [pv_const if pv is None else v.val if type(v) is Literal else env[v]
              for v, pv, pv_const in zip(jaxpr.outvars, out_pvs, out_pv_consts)]
    const_tracers = map(self.new_instantiated_const, consts)
    lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
    out_tracers = [JaxprTracer(self, PartialVal((out_pv, out_pv_const)), None)
                   for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
    new_params = dict(params, call_jaxpr=lifted_jaxpr)
    # The `jaxpr` already contains the env_vars at start of invars
    eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
                         out_tracers, call_primitive, new_params)
    for t in out_tracers:
      t.recipe = eqn
    return out_tracers
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-21 18:12:02 -07:00
|
|
|
  def post_process_call(self, call_primitive, out_tracers, params):
    """Handle tracers escaping through a call processed at an outer trace.

    Packages the residual jaxpr now and returns `(out, todo)` where `todo`
    rebuilds the output tracers at the outer level from the flat outputs.
    """
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
    out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
    out = out_pv_consts + consts
    # Avoid accidentally capturing stale bindings in the `todo` closure.
    del consts, out_pv_consts
    master = self.master
    def todo(x):
      # `x` packs the known output consts followed by the jaxpr consts.
      n = len(jaxpr.outvars)
      out_pv_consts, consts = x[:n], x[n:]
      trace = JaxprTrace(master, core.cur_sublevel())
      const_tracers = map(trace.new_instantiated_const, consts)
      env_tracers = map(trace.full_raise, env)
      lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
      out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
                     for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
      new_params = dict(params, call_jaxpr=lifted_jaxpr)
      # The `jaxpr` already contains the env_vars at start of invars
      eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
                           out_tracers, call_primitive, new_params)
      for t in out_tracers:
        t.recipe = eqn
      return out_tracers
    return out, todo
|
|
|
|
|
2020-03-09 20:41:01 +01:00
|
|
|
  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    """Partially evaluate a map primitive (e.g. pmap), tracking mapped_invars."""
    name = params.get('name', f.__name__)
    if self.master.trace_type is StagingJaxprTrace:
      # Staging out in full: abstract known inputs so the body is staged whole.
      tracers = map(self.instantiate_const_abstracted, tracers)
    else:
      name = wrap_name(name, 'pe')

    params = dict(params, name=name)
    in_pvs, in_consts = unzip2([t.pval for t in tracers])
    # Strip the mapped axis from the avals of unknown inputs that are mapped
    # (per params['mapped_invars']) before partially evaluating over a slice.
    reduced_pvs = [None if pv is None else
                   core.mapped_aval(params['axis_size'], pv) if m else pv
                   for pv, m in zip(in_pvs, params['mapped_invars'])]
    fun, aux = partial_eval(f, self, reduced_pvs)
    out_flat = map_primitive.bind(fun, *in_consts, **params)
    out_pvs_reduced, jaxpr, env = aux()
    out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    # Re-attach the mapped axis to the unknown outputs.
    out_pvs = [None if pv is None else core.unmapped_aval(params['axis_size'], pv)
               for pv in out_pvs_reduced]
    const_tracers = map(self.new_instantiated_const, consts)
    env_tracers = map(self.full_raise, env)
    lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
    out_tracers = [JaxprTracer(self, PartialVal((out_pv, out_pv_const)), None)
                   for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
    # The `jaxpr` already contains the env_vars at start of invars.
    # Residual consts are mapped; env values are broadcast (not mapped);
    # original inputs keep their caller-specified mapped-ness.
    new_params = dict(params,
                      mapped_invars=((True,) * len(const_tracers) +
                                     (False,) * len(env_tracers) +
                                     params['mapped_invars']),
                      call_jaxpr=lifted_jaxpr)
    assert (len(new_params['mapped_invars'])
            == len(const_tracers) + len(env_tracers) + len(tracers))
    eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
                         out_tracers, map_primitive, new_params)
    for t in out_tracers:
      t.recipe = eqn
    return out_tracers
|
|
|
|
|
2019-09-20 07:01:01 -07:00
|
|
|
def post_process_map(self, map_primitive, out_tracers, params):
  """Lifts map outputs that escaped this trace back into it.

  Standard post-process contract: returns the flat instantiated output
  values `out` together with a `todo` callback that, given those values,
  reconstructs JaxprTracers in this trace (staging out a map equation whose
  inputs are the captured constants and environment values).
  """
  jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
  out_pvs_reduced, out_pv_consts = unzip2(t.pval for t in out_tracers)
  # The tracers' pvs are for the per-device (mapped) values; re-attach the
  # leading mapped axis to the unknown outputs.
  out_pvs = [None if pv is None
             else core.unmapped_aval(params['axis_size'], pv)
             for pv in out_pvs_reduced]
  out = out_pv_consts + consts
  del consts, out_pv_consts
  master = self.master
  def todo(x):
    # `x` is the flat list of values: first the output pv consts, then the
    # residual consts collected by tracers_to_jaxpr.
    n = len(jaxpr.outvars)
    out_pv_consts, consts = x[:n], x[n:]
    trace = JaxprTrace(master, core.cur_sublevel())
    const_tracers = map(trace.new_instantiated_const, consts)
    # The `jaxpr` already contains the env_vars at start of invars
    lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
    out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
                   for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
    # Constants become mapped invars; environment values are unmapped.
    new_params = dict(params,
                      mapped_invars=tuple([True] * len(const_tracers) +
                                          [False] * len(env)),
                      call_jaxpr=lifted_jaxpr)
    env_tracers = map(trace.full_raise, env)
    eqn = new_eqn_recipe(it.chain(const_tracers, env_tracers),
                         out_tracers, map_primitive, new_params)
    for t in out_tracers:
      t.recipe = eqn
    return out_tracers
  return out, todo
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
  """Traces straight through `fun`, discarding its custom JVP rule.

  See the comment at the top of `JaxprTrace`: this method should only be
  reachable when staging out, and in that case the custom differentiation
  rule is not needed.
  """
  assert self.master.trace_type is StagingJaxprTrace
  del prim, jvp  # unused: the custom rule is dropped when staging out
  return fun.call_wrapped(*tracers)
|
|
|
|
|
|
|
|
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
  """Traces straight through `fun`, discarding its custom VJP rule.

  Same reasoning as `process_custom_jvp_call`: only reachable when staging
  out, where the custom differentiation rule is unnecessary.
  """
  assert self.master.trace_type is StagingJaxprTrace
  del prim, fwd, bwd, out_trees  # unused: the custom rule is dropped
  return fun.call_wrapped(*tracers)
|
|
|
|
|
2020-04-17 20:08:24 +03:00
|
|
|
# This subclass is used just for its type tag (see comment for `JaxprTrace`)
|
|
|
|
# This switches the behavior of process_call to stage out into the jaxpr any
|
|
|
|
# call primitives encountered (rather than doing partial evaluation into the call).
|
make nested jit stage out full inner jit bodies
Before this change, inner jitted functions wouldn't necessarily be fully
staged out into an outer-jit trace; instead, as much as possible would
be hoisted out of the inner jit. That led to extra constants getting
materialized in #1640.
For example:
```python
@jit
def f(x, y):
z = 2 * x
return y + z
@jit
def g(x):
return f(2, x)
g(3)
```
would lead to these XLA computations being compiled and executed:
```
HloModule jit_f.7
ENTRY jit_f.7 {
parameter.2 = () parameter(1)
tuple.3 = () tuple()
parameter.1 = s32[] parameter(0)
constant.4 = s32[] constant(2)
multiply.5 = s32[] multiply(parameter.1, constant.4)
ROOT tuple.6 = ((), s32[]) tuple(tuple.3, multiply.5)
}
HloModule jit_g.14
jaxpr_subcomputation.4 {
parameter.6 = () parameter(1)
tuple.8 = () tuple()
parameter.7 = s32[] parameter(2)
parameter.5 = s32[] parameter(0)
add.9 = s32[] add(parameter.7, parameter.5)
ROOT tuple.10 = (s32[]) tuple(add.9)
}
ENTRY jit_g.14 {
constant.1 = s32[] constant(4)
tuple.3 = () tuple()
parameter.2 = s32[] parameter(0)
call.11 = (s32[]) call(constant.1, tuple.3, parameter.2), to_apply=jaxpr_subcomputation.4
get-tuple-element.12 = s32[] get-tuple-element(call.11), index=0
ROOT tuple.13 = (s32[]) tuple(get-tuple-element.12)
}
```
Notice that the `multiply` is separated out from the `add`, and in
particular the XLA computation underlying `g` only has the `add` in it.
This behavior was desirable when using partial evaluation for
reverse-mode autodiff, since in that case we want to partially evaluate
all the primal values underneath a call while staging out a jaxpr for
the tangent values. But it was undesirable for the other use of partial
evaluation, namely forming jaxprs under `jit` (and `pmap`).
The solution was just to tag jaxpr traces differently in the two cases.
2019-12-11 18:39:16 -08:00
|
|
|
class StagingJaxprTrace(JaxprTrace):
  """Subclass used purely as a type tag (see the comment for `JaxprTrace`).

  Tracing under this tag switches call-processing behavior: call primitives
  encountered are staged out into the jaxpr in full, rather than being
  partially evaluated into the call.
  """
  pass
|
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
# Registry of primitive-specific partial-evaluation rules, keyed by primitive.
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
# Registry of partial-evaluation rules for call primitives (e.g. remat_call).
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
# Call primitives that should be staged out in full rather than partially
# evaluated (populated by registration elsewhere).
staged_out_calls: Set[core.Primitive] = set()
|
2019-02-21 11:47:26 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def partial_eval(f, trace, pvs: Sequence[Optional[AbstractValue]], instantiate=False):
  """Sets up partial evaluation of `f` under `trace`'s master.

  Returns a wrapped function (with auxiliary output — see
  `partial_eval_wrapper`) that accepts the known constants corresponding to
  the partial values `pvs`.
  """
  subtraced = trace_to_subjaxpr(f, trace.master, instantiate)
  return partial_eval_wrapper(subtraced, tuple(pvs))
|
|
|
|
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def partial_eval_wrapper(avals: Sequence[Optional[AbstractValue]], *consts):
  """Adapter around `trace_to_subjaxpr` for `partial_eval`.

  Re-pairs the abstract values with the incoming known constants into
  PartialVals, then yields `out_consts + residual_consts` as the primary
  result and `(out_pvs, jaxpr, env)` as the auxiliary result.
  """
  py_args = (map(PartialVal, zip(avals, consts)),)
  jaxpr, (out_pvals, consts, env) = yield py_args, {}
  out_pvs, out_consts = unzip2(out_pvals)
  out = tuple(out_consts) + tuple(consts)  # TODO: can consts be traced?
  yield out, (out_pvs, jaxpr, env)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-13 14:28:30 -08:00
|
|
|
def abstract_eval_fun(fun, *avals, **params):
  """Abstractly evaluates `fun` at the given input avals, returning output avals."""
  unknown_pvals = [PartialVal.unknown(aval) for aval in avals]
  _, out_pvals, _ = trace_to_jaxpr(lu.wrap_init(fun, params), unknown_pvals,
                                   instantiate=True, stage_out=True)
  out_avals, _ = unzip2(out_pvals)
  # With instantiate=True every output pv is a proper AbstractValue.
  assert all(isinstance(aval, AbstractValue) for aval in out_avals)
  return out_avals
|
2019-02-13 14:28:30 -08:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class JaxprTracer(Tracer):
  """Tracer for partial evaluation: a PartialVal plus the recipe that made it.

  `pval` holds either a known value or an abstract value (unknown).
  `recipe` records how the tracer was produced — one of JaxprEqnRecipe,
  LambdaBinding, FreeVar, ConstVar, a Literal, or `unit` — and is consumed
  by `tracers_to_jaxpr` to reconstruct the jaxpr.
  """
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace, pval: PartialVal, recipe):
    assert isinstance(pval, PartialVal)
    pv, const = pval
    # A known value that is itself a tracer at this level or deeper means a
    # tracer has escaped the dynamic extent of its trace.
    if isinstance(const, Tracer) and const._trace.level >= trace.level:
      raise core.escaped_tracer_error(
          "Tracer from a higher level: {} in trace {}".format(const, trace))
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return 'Traced<{}:{}>'.format(self.aval, self._trace)

  @property
  def aval(self):
    # The abstract value, derived from the partial value.
    return self.pval.get_aval()

  @property
  def parents(self):
    # Tracers this one depends on; used for topological sorting of tracers.
    if isinstance(self.recipe, JaxprEqnRecipe):
      return self.recipe.invars
    else:
      return []

  def full_lower(self):
    # Lower to the concrete known value when there is one.
    known = self.pval.get_known()
    if known is not None:
      return core.full_lower(known)
    else:
      return self
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
# TODO(necula): this should return a TypedJaxpr
|
2020-04-17 20:08:24 +03:00
|
|
|
# TODO(necula): remove stage_out, replace trace_type=pe.StagingJaxprTrace
|
2020-03-21 13:54:30 +01:00
|
|
|
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
                   instantiate: Union[bool, Sequence[bool]] = False,
                   stage_out=False, bottom=False,
                   trace_type: Optional[Type[Trace]] = None) \
    -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
  """Traces a function into a Jaxpr, given PartialVals for inputs.

  `trace_type` can be one of `StagingJaxprTrace` or `JaxprTrace` (see
  comments for that class).

  Returns (`jaxpr`, `out_pvals`, `consts`).
  The `jaxpr` contains only the computation that depends on unknown inputs.
  The `out_pvals` are the PartialVal for the outputs. The intermediate
  values that depend only on known inputs and are needed to compute the output
  of `jaxpr` are in `consts` and are passed in as the constvars of
  the `jaxpr`. The handling of the known outputs depends on `instantiate`.

  For example, given `fun` defined as follows::

     def fun(ki, ui):  # ki will be a known input in this example
       ka = ki + 2
       kb = ka + 3
       return (kb, ui + ka)

  with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
  computation that depends on unknown inputs is `ui + ka` and will be the only
  computation in the body of the `jaxpr`. This computation depends on the
  known intermediate value `ka`, which will be computed statically. Currently,
  such constants are either embedded in the Jaxpr if they are scalars, or
  passed as a constvar to `jaxpr`, and then the value of the actual constant
  will be in `consts`:

  When `instantiate=False` we get::

     jaxpr =
      { lambda ka ; ki ui.
          let c = add ui ka
          in (*, c) }   # known outputs are `*`
     out_pvals = [known(6), unknown(ShapedArray)]  # the known outputs are known PartialVal
     consts = [3]  # the constant for `ka`

  When `instantiate=True` we get::

     jaxpr =
      { lambda ka kb ; ki ui.
          let c = add ui ka
          in (kb, c) }   # known output are explicit
     out_pvals = [abstract(ConcreteArray(6)), abstract(ShapedArray)]  # all are unknown PartialVal
     consts = [3, 6]  # values for `ka` and `kb` constvars
  """
  # `stage_out` is a legacy alias for selecting StagingJaxprTrace; an explicit
  # `trace_type` takes precedence (see the TODO above the signature).
  trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
  with new_master(trace_type, bottom=bottom) as master:
    fun = trace_to_subjaxpr(fun, master, instantiate)
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    # At the top level there is no outer trace to close over, so no env values.
    assert not env
    # Drop our reference to `master` before exiting the context manager,
    # presumably so it is not kept alive past the trace's extent — TODO confirm.
    del master

  return jaxpr, out_pvals, consts
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation
def trace_to_subjaxpr(master: core.MasterTrace, instantiate: Union[bool, Sequence[bool]],
                      pvals: Sequence[PartialVal]):
  """Transformation that traces the wrapped function to a jaxpr under `master`.

  Yields the inputs as JaxprTracers of a fresh JaxprTrace, then converts the
  traced outputs into a jaxpr. The auxiliary result is
  `(out_pvals, consts, env)`.
  """
  assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
  trace = JaxprTrace(master, core.cur_sublevel())
  in_tracers = map(trace.new_arg, pvals)
  ans = yield in_tracers, {}
  # `instantiate` may be one bool for all outputs, or one bool per output.
  instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  # Drop tracer references eagerly so they do not outlive the trace.
  del trace, in_tracers, out_tracers
  yield jaxpr, (out_pvals, consts, env)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def instantiate_const_at(trace, instantiate: bool, tracer):
  """Returns `tracer`, instantiated as a constant in `trace` when `instantiate`."""
  return trace.instantiate_const(trace.full_raise(tracer)) if instantiate else tracer
|
2019-05-10 08:58:05 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Tracer recipes: records of how a JaxprTracer came to be. `tracers_to_jaxpr`
# dispatches on these to reconstruct the jaxpr.
FreeVar = namedtuple('FreeVar', ['val'])         # value closed over from an outer trace
ConstVar = namedtuple('ConstVar', ['val'])       # constant, becomes a constvar
LambdaBinding = namedtuple('LambdaBinding', [])  # a traced function input
# One recipe is shared by all outputs of a primitive application; `eqn_id` is a
# unique object used to deduplicate, and `outvars` holds weakrefs to the
# output tracers (see `new_eqn_recipe`).
JaxprEqnRecipe = namedtuple('JaxprEqnRecipe',
                            ['eqn_id', 'invars', 'outvars', 'primitive', 'params'])
|
2019-11-19 12:26:30 -08:00
|
|
|
|
2020-02-05 15:38:25 +01:00
|
|
|
def new_eqn_recipe(invars, outvars, primitive, params):
  """Builds a JaxprEqnRecipe for an application of `primitive`.

  Args:
    invars: the tracers for the primitive inputs.
    outvars: the tracers for the primitive outputs.
    primitive: the primitive applied.
    params: the primitive's params.
  """
  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
  is_call = primitive.call_primitive
  is_map = primitive.map_primitive
  if is_call or is_map:
    assert "call_jaxpr" in params
  if is_map:
    assert "mapped_invars" in params
  # `object()` serves as a fresh unique equation id; output tracers are held
  # through weakrefs so the recipe does not keep dead tracers alive.
  out_refs = [ref(t) for t in outvars]
  return JaxprEqnRecipe(object(), tuple(invars), out_refs, primitive, params)
|
|
|
|
|
2019-11-19 12:26:30 -08:00
|
|
|
|
2019-11-20 09:12:15 -08:00
|
|
|
def recipe_to_eqn(unused_var, getvar, recipe):
  """Converts a JaxprEqnRecipe into a jaxpr equation, resolving tracers to vars.

  Output tracers whose weakrefs have died are bound to a fresh unused var.
  """
  _, in_tracers, out_tracer_refs, primitive, params = recipe
  out_tracers = [weak() for weak in out_tracer_refs]
  invars = [getvar(tracer) for tracer in in_tracers]
  outvars = [getvar(tracer) if tracer is not None else unused_var()
             for tracer in out_tracers]
  return new_jaxpr_eqn(invars, outvars, primitive, params)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-26 16:48:17 -04:00
|
|
|
def tracers_to_jaxpr(in_tracers, out_tracers):
  """Constructs Jaxpr given tracers for inputs and outputs.

  Params:
    in_tracers: the tracers that were created for the function inputs
    out_tracers: the tracers that were output by the function.

  Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
    the `constvars` in the returned Jaxps, and a list of environment values.
    The vars for the environment values have been prepended to the Jaxpr's
    `invars`.
  """
  newvar = core.gensym()
  t_to_var = {}
  def getvar(t):
    # Looks up (or mints) the variable for tracer `t`; tracers whose pval is
    # known get a unit aval since their value is not carried by the jaxpr.
    var = t_to_var.get(id(t))
    if var is None:
      aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  # Topological order ensures each equation appears after its inputs' recipes.
  sorted_tracers = toposort(out_tracers)
  invars = map(getvar, in_tracers)
  eqns = []
  env = {}
  consts = {}
  const_to_var = {}
  def getconstvar(c):
    # Deduplicates constvars by constant identity.
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  # A multi-output equation shares one recipe among its output tracers; the
  # eqn_id set ensures each equation is emitted only once.
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(lambda: newvar(core.abstract_unit), getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      # A lambda-bound tracer must be one of the declared inputs, else it
      # escaped from some other function's trace.
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise core.escaped_tracer_error(
            "Tracer not among input tracers {}".format(t))
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[getvar(t)] = recipe.val
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      # Literals stand in directly for vars inside equations.
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = Jaxpr(const_vars, list(it.chain(env_vars, invars)), list(map(getvar, out_tracers)), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
|
|
|
|
|
2020-02-05 13:55:59 +01:00
|
|
|
@cache()
def convert_constvars_jaxpr(jaxpr):
  """Moves the constvars to the start of invars."""
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  converted = Jaxpr(constvars=(),
                    invars=jaxpr.constvars + jaxpr.invars,
                    outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(converted)
  return converted
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def partial_eval_jaxpr(jaxpr: TypedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       trace_type: Optional[Type[core.Trace]]
                       ) -> Tuple[TypedJaxpr, TypedJaxpr, Sequence[bool]]:
  """Specializes a Jaxpr given an indication of which inputs are known.

  Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).

  `out_unknowns` specifies which outputs are unknown (depend on some unknown inputs).

  `jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs,
  and performs *all* the computation in `jaxpr` that depends only on the known inputs.
  Outputs correspond to those of `jaxpr`, with the unknown ones replaced with `*`,
  appended with the known residuals (the intermediate computations in `jaxpr`
  that depend only on known inputs and that are needed to compute the unknown outputs).

  `jaxpr_unknown` takes the same inputs as `jaxpr` along with the known residuals
  computed by `jaxpr_known` and returns the same outputs as `jaxpr` with the known
  outputs replaced by `*`.

  Roughly, `jaxpr(ki, ui)` is decomposed assuming `ki` and `ui` are the known and respectively
  unknown inputs into:

    jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(kin, *)
                    let _, uout = jaxpr_unknown(ki, ui, kresidual)
                    in (kout, uout)

  For example, if `jaxpr` is lambda ki, ui: let ka = ki + 2
                                            in (ki + 3, ui + ka)"
  then
    `jaxpr_known` = lambda ki, ui: let ka = ki + 2
                                   in (ki + 3, *, ka)
    'jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # `cell` smuggles the unknown-side jaxpr and output pvs out of the inner
  # partial evaluation performed inside `fun`.
  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
             for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
                                                    trace_type=trace_type)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
  # known inputs.
  pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
           for aval, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
  # `fun` must have been called exactly once during tracing.
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  # Signatures of the pieces:
  #   jaxpr :: a -> b
  #   jaxpr_1 :: a1 -> [b1, res]
  #   jaxpr_2 :: res | a2 -> b2
  #   jaxpr_2 :: [a2, res] -> b2
  # NOTE(review): `convert_constvars_jaxpr` is `@cache()`d and its result is
  # mutated below; this is safe only if `jaxpr_2` is never shared — confirm.
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  # Move the residual invars from the front to the back: [res, a2] -> [a2, res].
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  # Known inputs are dead in jaxpr_2, so give their vars unit avals.
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
  out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
  # out_avals_1 and in_avals_2 need the residuals added
  out_pvs, _ = unzip2(out_pvals)
  res_avals = out_pvs[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res
  out_avals_1 = out_avals_1 + res_avals
  in_avals_2 = in_avals_2 + res_avals

  typed_jaxpr_1 = TypedJaxpr(jaxpr_1, consts_1, in_avals_1, out_avals_1)
  typed_jaxpr_2 = TypedJaxpr(jaxpr_2, (), in_avals_2, out_avals_2)
  return typed_jaxpr_1, typed_jaxpr_2, uk_out
|
|
|
|
|
|
|
|
def _split_aval(unknown, aval):
  """Routes `aval` into a (known, unknown) pair; the other slot gets unit."""
  if unknown:
    return abstract_unit, aval
  else:
    return aval, abstract_unit
|
2019-04-11 14:50:58 -07:00
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
# Rematerialization (checkpointing) call primitive: a call primitive whose
# partial-eval behavior stages out the full called computation (see
# `_remat_partial_eval` below).
remat_call_p = core.Primitive('remat_call')
remat_call_p.call_primitive = True
remat_call = partial(core.call_bind, remat_call_p)
remat_call_p.def_custom_bind(remat_call)
remat_call_p.def_impl(core.call_impl)  # untraced: evaluate the call directly
remat_call_p.multiple_results = True
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
def _remat_partial_eval(trace, _, f, tracers, params):
  """Partial-evaluation rule for `remat_call_p`.

  Traces the whole called function `f` into a jaxpr (rather than only the
  unknown parts, as `JaxprTrace.process_call` would), evaluates the known
  outputs eagerly, and stages out the full jaxpr as a `remat_call` equation
  so that known intermediates can be recomputed later.

  Args:
    trace: the `JaxprTrace` performing partial evaluation.
    _: unused (the primitive; see `call_partial_eval_rules` dispatch).
    f: the function being rematerialized, as a `lu.WrappedFun`.
    tracers: the input tracers at this trace level.
    params: call params; must contain the `concrete` flag.

  Returns:
    The output `JaxprTracer`s, with their recipe set to the staged equation.
  """
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
  fun, aux = partial_eval(f, trace, in_pvs)
  with core.initial_style_staging():
    out_flat = remat_call_p.bind(fun, *in_consts, **params)
  out_pvs, jaxpr, env = aux()
  env = map(trace.full_raise, env)
  # The trailing outputs of the bind are the jaxpr's constants; split them off.
  out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
  out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env]
              + [raise_to_shaped(pv) for pv in in_pvs])
  # An output aval comes from its pv when abstract; otherwise reconstruct it
  # from the outvar (unit / literal) or from the computed constant.
  out_avals = [raise_to_shaped(pv if pv is not None
                               else abstract_unit if var is unitvar
                               else get_aval(var.val) if type(var) is Literal
                               else get_aval(const))
               for var, pv, const in zip(jaxpr.outvars, out_pvs, out_pval_consts1)]
  typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
  # A tracer is unknown iff its pval still carries an abstract value.
  in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
  jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns,
                                                      instantiate=False,
                                                      trace_type=trace.master.trace_type)
  # jaxpr_1's extra outputs (beyond jaxpr_2's) are the residuals.
  num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

  # First, we prune the jaxpr to be staged out not to have too many outputs.
  typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of reverse-mode ad in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  to_compute = [not uk and type(pv) is not ConcreteArray
                for uk, pv in zip(out_unknowns, out_pvs)]
  jaxpr_1_primals = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
  _, in_consts = unzip2(t.pval for t in it.chain(env, tracers))
  # `[:-num_res or None]` drops the residual outputs; when num_res == 0 the
  # slice degenerates to [:None], i.e. keep everything.
  out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1_primals)(*in_consts)[:-num_res or None]
  out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)

  # Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
  instantiated_tracers = env + instantiated_tracers
  const_tracers = map(trace.new_instantiated_const, consts)
  # Turn the jaxpr's constvars into regular invars so constants can be passed
  # as ordinary call arguments.
  lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
  out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
  new_params = dict(params, call_jaxpr=lifted_jaxpr)
  eqn = new_eqn_recipe(tuple(it.chain(const_tracers, instantiated_tracers)),
                       out_tracers, remat_call_p, new_params)
  for t in out_tracers: t.recipe = eqn
  return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
|
|
|
|
2019-11-27 15:25:49 -08:00
|
|
|
def _dce_jaxpr(typed_jaxpr, outputs):
  """Return a TypedJaxpr keeping only what the flagged outputs need.

  Outputs whose flag in `outputs` is False are replaced with unit, and every
  equation whose results feed no remaining output is dropped. This dead-code
  elimination is pretty rudimentary, and in particular doesn't nontrivially
  DCE through scan, call, or other higher-order primitives.
  TODO(mattjj): better DCE
  """
  jaxpr = typed_jaxpr.jaxpr

  # Replace pruned outputs with unit, keeping positions aligned.
  new_outvars = []
  new_out_avals = []
  for var, aval, keep in zip(jaxpr.outvars, typed_jaxpr.out_avals, outputs):
    if keep:
      new_outvars.append(var)
      new_out_avals.append(aval)
    else:
      new_outvars.append(unitvar)
      new_out_avals.append(core.abstract_unit)

  # Backward pass: keep an equation iff any of its results is live, and then
  # mark its (non-literal) inputs live too.
  live = {v for v in new_outvars if type(v) is not Literal}
  kept_eqns = []
  for eqn in reversed(jaxpr.eqns):
    if not live.isdisjoint(eqn.outvars):
      kept_eqns.append(eqn)
      live.update(v for v in eqn.invars if type(v) is not Literal)
  kept_eqns.reverse()

  new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.invars,
                         new_outvars, kept_eqns)
  return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals,
                         new_out_avals)
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def _reconstruct_pval(pval1: PartialVal, const2: core.Value, unknown: bool):
  """Rebuild an output pval after re-evaluating the known parts.

  If the output is unknown, or `pval1` is already known, it is returned
  unchanged. Otherwise the output becomes known: its value is taken from a
  concrete aval's payload when available, else from the recomputed `const2`.
  """
  pv1 = pval1[0]
  if unknown or pval1.is_known():
    return pval1
  value = pv1.val if type(pv1) is ConcreteArray else const2
  return PartialVal.known(value)
|
2019-11-27 14:28:13 -08:00
|
|
|
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def move_binders_to_front(typed_jaxpr: TypedJaxpr, to_move: Sequence[bool]) -> TypedJaxpr:
  """Reorder the `invars` to move to front the ones for which `to_move` is True."""
  assert not typed_jaxpr.jaxpr.constvars
  assert len(typed_jaxpr.in_avals) == len(to_move)
  # Permute the binders and their avals with the same mask, keeping them
  # positionally aligned; outvars and eqns are untouched.
  reordered_invars = _move_to_front(typed_jaxpr.jaxpr.invars, to_move)
  reordered_avals = _move_to_front(typed_jaxpr.in_avals, to_move)
  reordered_jaxpr = core.Jaxpr((), reordered_invars, typed_jaxpr.jaxpr.outvars,
                               typed_jaxpr.jaxpr.eqns)
  return core.TypedJaxpr(reordered_jaxpr, typed_jaxpr.literals,
                         reordered_avals, typed_jaxpr.out_avals)
|
|
|
|
|
2020-03-18 07:11:44 +01:00
|
|
|
def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
|
2019-11-27 14:28:13 -08:00
|
|
|
return ([elt for elt, move in zip(lst, to_move) if move] +
|
|
|
|
[elt for elt, move in zip(lst, to_move) if not move])
|