rocm_jax/jax/interpreters/partial_eval.py

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import contextlib
import functools
from functools import partial
import inspect
import itertools as it
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
                    List, Union, cast)
from weakref import ref
import numpy as np
from .. import core
from .._src import dtypes
from .. import linear_util as lu
from jax._src.ad_util import Zero
from .._src.api_util import flattened_fun_in_tree
from .._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
from .._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
                         partition_list, cache, OrderedSet,
                         as_hashable_function)
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
                    unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
                    dropvar, ConcreteArray, raise_to_shaped, Var, Atom,
                    JaxprEqn, Primitive)
from jax._src import source_info_util
from ..config import config
map = safe_map
zip = safe_zip
def identity(x): return x
class PartialVal(tuple):
"""Partial value: either a known value or an unknown (abstract) value.
Represented as a pair `(aval_opt, const)` of one of two kinds:
* `(None, <Constant>)` indicates a known value, either a Python regular
value, or a Tracer.
* `(<AbstractValue>, *)` indicates an unknown value characterized by an
abstract value.
"""
  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    pv, const = xs
    if config.jax_enable_checks:
      # type checks
      assert isinstance(pv, (AbstractValue, type(None))), xs
      assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
      # invariant checks
      if isinstance(pv, AbstractValue):
        assert get_aval(const) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    return PartialVal((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    return PartialVal((aval, core.unit))

  def is_known(self) -> bool:
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Get the known value, if known, else None."""
    return self[1] if self[0] is None else None

  def get_aval(self) -> AbstractValue:
    """Get AbstractValue directly (if unknown) or from the constant (known)."""
    known = self.get_known()
    if known is not None:
      return get_aval(known)
    else:
      return self[0]

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Either the stored known value, or the given 'val'."""
    known = self.get_known()
    return known if known is not None else val
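
# Example (illustrative sketch): how a PartialVal is built and queried.
# `core.ShapedArray` is assumed here purely for illustration.
#
#   known   = PartialVal.known(np.float32(3.0))   # -> (None, 3.0)
#   unknown = PartialVal.unknown(core.ShapedArray((), np.dtype('float32')))
#   known.is_known(), known.get_known()      # -> True, 3.0
#   unknown.is_known(), unknown.get_known()  # -> False, None
#   unknown.merge_with_known(7.0)            # -> 7.0 (no known value stored)
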
class JaxprTrace(Trace):
  def pure(self, val) -> 'JaxprTracer':
    return self.new_const(val)

  def lift(self, val) -> 'JaxprTracer':
    return self.new_const(val)

  def sublift(self, val) -> 'JaxprTracer':
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val) -> 'JaxprTracer':
    if isinstance(val, Tracer) and val._trace.level == self.level:
      raise Exception
    return JaxprTracer(self, PartialVal.known(val), unit)

  def new_instantiated_literal(self, val) -> 'JaxprTracer':
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), Literal(val))

  def new_instantiated_const(self, val) -> 'JaxprTracer':
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))

  def new_arg(self, pval: PartialVal) -> 'JaxprTracer':
    const = pval.get_known()
    # XXX: Think twice before changing this constant argument pruning!
    # This has really important consequences for partial_eval_jaxpr.
    # Most importantly, this guarantees that the unknown jaxpr never uses
    # known inputs (if it needs them, then they get passed through residuals).
    if const is None:
      return JaxprTracer(self, pval, LambdaBinding())
    else:
      return self.new_const(const)

  def instantiate_const(self, tracer) -> Tracer:
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      if type(const) in core.literalable_types and np.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)

  def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))

  def process_primitive(self, primitive, tracers, params):
    if primitive in custom_partial_eval_rules:
      return custom_partial_eval_rules[primitive](self, *tracers, **params)
    else:
      return self.default_process_primitive(primitive, tracers, params)

  def default_process_primitive(self, primitive, tracers, params):
    """By default, if all the input tracers are known, then execute the primitive
    and all the outputs are known. Otherwise, all the outputs are unknown."""
    consts = [t.pval.get_known() for t in tracers]
    if all(c is not None for c in consts):
      return primitive.bind(*consts, **params)
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval = primitive.abstract_eval(*avals, **params)
    source = source_info_util.current()
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in out_aval]
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    else:
      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
                                         params, source)
      return out_tracer
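
  # Example (illustrative sketch): when only some inputs are known, the branch
  # above instantiates the known operands as constants and asks the primitive
  # for its output aval, e.g. conceptually for a hypothetical binary `mul_p`:
  #
  #   mul_p.abstract_eval(ShapedArray((), float32), ShapedArray((), float32))
  #     # -> ShapedArray((), float32); the result becomes an unknown
  #     #    JaxprTracer whose recipe is the new equation.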

  # We use process_call to handle both call and map primitives.
  def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
    if primitive in call_partial_eval_rules:
      return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

    in_pvals = [t.pval for t in tracers]
    ctx: Any
    if primitive.map_primitive:
      ctx = core.extend_axis_env(params['axis_name'], params['axis_size'], None)
      mapped_aval = partial(core.mapped_aval, params['axis_size'])
      in_pvals = [pval if pval.is_known() or in_axis is None
                  else PartialVal.unknown(mapped_aval(in_axis, pval[0]))
                  for pval, in_axis in zip(in_pvals, params['in_axes'])]

      def app(f, *args):
        f, num_outputs = count_outputs(f)
        out_axes_thunk = params['out_axes_thunk']
        @as_hashable_function(closure=out_axes_thunk)
        def new_out_axes_thunk():
          out_axes = out_axes_thunk()
          return out_axes + (0,) * (num_outputs() - len(out_axes))
        pe_params = dict(params, out_axes_thunk=new_out_axes_thunk)
        return primitive.bind(f, *args, **pe_params)
    else:
      ctx = contextlib.suppress()  # This is a no-op
      app = partial(primitive.bind, **params)
    with ctx:
      jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
          f, in_pvals, app, instantiate=False)

    if primitive.map_primitive:
      unmapped_aval = partial(core.unmapped_aval, params['axis_size'], params['axis_name'])
      out_axes = params['out_axes_thunk']()
      out_pvals = [pval if pval.is_known() else
                   PartialVal.unknown(unmapped_aval(out_axis, pval[0])) if out_axis is not None else
                   PartialVal.unknown(pval[0])
                   for pval, out_axis in zip(out_pvals, out_axes)]

    # Skip known invars and outvars, and lift constants as regular invars
    in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
    out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
    jaxpr = _drop_vars(jaxpr, in_knowns, (False,) * len(jaxpr.outvars))
    jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)

    # Known tracers get propagated as if they were constants
    known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
                         if pval.is_known()]

    # Unknown tracers need to have the jaxpr set up as their recipe
    unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
                           if not pval.is_known()]
    unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)

    # Set up new params
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    if primitive.map_primitive:
      in_axes = params['in_axes']
      # NOTE: const_tracers are added as map outputs, and we always map them
      #       along axis 0 (see `new_out_axes_thunk` above).
      new_in_axes = ((0,) * len(const_tracers) +
                     (None,) * len(env_tracers) +
                     tuple(axis for axis, t in zip(in_axes, tracers)
                           if not t.pval.is_known()))
      new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals)
                           if not pval.is_known())
      new_params = dict(new_params, in_axes=new_in_axes, out_axes=new_out_axes)
      del new_params['out_axes_thunk']
    update_params = call_param_updaters.get(primitive)
    if update_params:
      new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])

    eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
                         source_info_util.current())
    for t in unknown_tracers_out: t.recipe = eqn
    return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
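
  # Example (illustrative sketch): for a call primitive such as the one behind
  # `jit`, the partial evaluation above splits the callee into a known part,
  # evaluated now and returned via `new_const`, and an unknown part, staged
  # into `call_jaxpr` and represented by `unknown_tracers_out`, whose shared
  # recipe is the single call equation built just above.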

  process_map = process_call

  # We use post_process_call to handle both call and map primitives.
  def post_process_call(self, primitive, out_tracers, params):
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
    out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
    out = out_pv_consts + consts
    nconsts = len(consts)
    del consts, out_pv_consts
    main = self.main

    if primitive.map_primitive:
      out_axes = params['out_axes_thunk']()
      sz = params['axis_size']
      out_pvs = [None if pv is None else core.unmapped_aval(sz, params['axis_name'], ax, pv)
for pv, ax in zip(out_pvs, out_axes)]
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(main, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
in_tracers = (*const_tracers, *map(trace.full_raise, env))
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
if primitive.map_primitive:
# NOTE: We've assigned axis 0 to const tracers below, in out_axes_transform.
new_in_axes = (0,) * len(const_tracers) + (None,) * len(env)
new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes)
del new_params['out_axes_thunk']
update_params = call_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, [])
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
source_info_util.current())
for t in out_tracers:
t.recipe = eqn
return out_tracers
if primitive.map_primitive:
def out_axes_transform(out_axes):
return out_axes + (0,) * nconsts
todo = (todo, out_axes_transform)
return out, todo
post_process_map = post_process_call
def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
instantiate: bool):
"""Partially evaluate f on a sequence of PartialVals."""
in_avals, in_consts = unzip2(pvals)
f = trace_to_subjaxpr(f, self.main, instantiate)
f, aux = partial_eval_wrapper(f, tuple(in_avals))
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvs = map(PartialVal, zip(out_avals, out_consts))
env_tracers = map(self.full_raise, env)
return jaxpr, out_pvs, consts, env_tracers
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
tracers = map(self.instantiate_const_abstracted, tracers)
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
fun = trace_to_subjaxpr(fun, self.main, True)
fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
out_flat = prim.bind(fun, jvp, *in_consts)
out_avals, jaxpr, env = aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units
env_tracers = map(self.full_raise, env)
out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
const_tracers = map(self.new_instantiated_const, consts)
in_tracers = (*const_tracers, *env_tracers, *tracers)
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
@_memoize
def jvp_jaxpr_thunk():
jvp_ = trace_to_subjaxpr(jvp, self.main, True)
jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
with core.new_sublevel():
out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units
out_avals, jaxpr, env = aux()
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
return converted_jaxpr, (*consts, *env)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
dict(fun_jaxpr=closed_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts) + len(env)),
source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
def post_process_custom_jvp_call(self, out_tracers, params):
# This path should only be reachable if we expose a partial eval API
# unrelated to autodiff, since we raise an error when differentiation with
# respect to values over which a custom_jvp function closes is detected.
raise NotImplementedError # TODO(mattjj)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
tracers = map(self.instantiate_const_abstracted, tracers)
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
fun = trace_to_subjaxpr(fun, self.main, True)
fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
out_avals, jaxpr, env = aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units
env_tracers = map(self.full_raise, env)
out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
const_tracers = map(self.new_instantiated_const, consts)
in_tracers = (*const_tracers, *env_tracers, *tracers)
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
@_memoize
def fwd_jaxpr_thunk():
fwd_ = trace_to_subjaxpr(fwd, self.main, True)
fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
with core.new_sublevel():
out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units
out_avals, jaxpr, env = aux()
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
return converted_jaxpr, (*consts, *env)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
dict(fun_jaxpr=closed_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(consts) + len(env),
bwd=bwd, out_trees=out_trees),
source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
def post_process_custom_vjp_call(self, out_tracers, params):
# This path should only be reachable if we expose a partial eval API
# unrelated to autodiff, since we raise an error when differentiation with
# respect to values over which a custom_vjp function closes is detected.
raise NotImplementedError # TODO(mattjj)
@lu.transformation_with_aux
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
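  """Auxiliary wrapper for tracing: rebuilds PartialVals from `(pvs, consts)`,
  calls the wrapped (sub-jaxpr-tracing) function, and returns the known output
  constants followed by the jaxpr constants, with `(out_pvs, jaxpr, env)` as
  the auxiliary output."""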
py_args = map(PartialVal, zip(pvs, consts))
jaxpr, (out_pvals, consts, env) = yield (py_args,), {}
out_pvs, out_consts = unzip2(out_pvals)
out = tuple(out_consts) + tuple(consts)
yield out, (out_pvs, jaxpr, env)
@lu.transformation_with_aux
def count_outputs(*args, **kwargs):
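  """Transformation-with-aux that also reports the number of outputs of the
  wrapped function."""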
ans = yield args, kwargs
yield ans, len(ans)
custom_partial_eval_rules: Dict[Primitive, Callable] = {}
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
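  """Abstractly evaluates `fun` on the given avals, returning the output avals."""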
_, avals_out, _ = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params), avals, debug_info)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar',
'ConstVar', Literal, core.Unit]
class JaxprTracer(Tracer):
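  """Tracer for partial evaluation, carrying a `PartialVal` and the recipe
  that describes how its value was produced."""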
__slots__ = ['pval', 'recipe']
def __init__(self, trace: JaxprTrace, pval: PartialVal,
recipe: Optional[JaxprTracerRecipe]):
assert isinstance(pval, PartialVal)
pv, const = pval
if isinstance(const, Tracer) and const._trace.level >= trace.level:
raise core.escaped_tracer_error(
const, "Tracer from a higher level: {} in trace {}".format(const, trace))
self._trace = trace
self.pval = pval
self.recipe = recipe
def __repr__(self):
return 'Traced<{}:{}>'.format(self.aval, self._trace)
@property
def aval(self) -> AbstractValue:
return self.pval.get_aval()
@property
def parents(self) -> Sequence['JaxprTracer']:
if isinstance(self.recipe, JaxprEqnRecipe):
return self.recipe.invars
else:
return []
def full_lower(self):
known = self.pval.get_known()
if known is not None:
return core.full_lower(known)
else:
return self
def is_known(self):
return self.pval.is_known()
# TODO(necula): this could return a ClosedJaxpr with out_pvals
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
"""Traces a function into a Jaxpr, given PartialVals for inputs.
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
  computation that depends on unknown inputs. The `out_pvals` are the PartialVals
  for the outputs. The intermediate values that depend only on known inputs and
are needed to compute the output of `jaxpr` are in `consts` and are passed in
as the constvars of the `jaxpr`. The handling of the known outputs depends on
`instantiate`.
For example, given `fun` defined as follows::
def fun(ki, ui): # ki will be a known input in this example
ka = ki + 2
kb = ka + 3
return (kb, ui + ka)
  with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
  computation that depends on unknown inputs is `ui + ka`, so it is the only
  computation in the body of the `jaxpr`. This computation depends on the known
intermediate value `ka`, which will be computed statically. Currently, such
constants are either embedded in the Jaxpr if they are scalars, or passed as a
constvar to `jaxpr`, and then the value of the actual constant will be in
`consts`:
When `instantiate=False` we get::
jaxpr =
{ lambda ka ; ki ui.
let c = add ui ka
in (*, c) } # known outputs are `*`
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
consts = [3] # the constant for `ka`
When `instantiate=True` we get::
jaxpr =
{ lambda ka kb ; ki ui.
let c = add ui ka
          in (kb, c) }   # known outputs are explicit
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
consts = [3, 6] # values for `ka` and `kb` constvars
"""
with core.new_main(JaxprTrace) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
del main, fun, env
return jaxpr, out_pvals, consts
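
# A minimal usage sketch for `trace_to_jaxpr` (illustrative only; this helper
# is not part of the module API and is never called). It partially evaluates
# the `fun` from the docstring above, with the first input known and the
# second unknown. The scalar float32 aval is an assumption made just for this
# example.
def _trace_to_jaxpr_example():
  def fun(ki, ui):  # ki is known, ui is unknown
    ka = ki + 2
    kb = ka + 3
    return kb, ui + ka
  pvals = [PartialVal.known(1.),
           PartialVal.unknown(core.ShapedArray((), np.dtype('float32')))]
  # Returns a jaxpr computing only `ui + ka`, the output PartialVals, and the
  # consts (the known intermediate values the jaxpr needs).
  jaxpr, out_pvals, consts = trace_to_jaxpr(lu.wrap_init(fun), pvals)
  return jaxpr, out_pvals, consts
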
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
pvals: Sequence[PartialVal]):
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
trace = JaxprTrace(main, core.cur_sublevel())
in_tracers = map(trace.new_arg, pvals)
ans = yield in_tracers, {}
assert isinstance(ans, (list, tuple)), (
f"Got unexpected return type when tracing function to jaxpr: {ans}")
assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), (
f"Got unexpected return type when tracing function to jaxpr: {ans}")
instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
out_tracers = map(trace.full_raise, map(core.full_lower, ans))
out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
out_pvals = [t.pval for t in out_tracers]
del trace, in_tracers, out_tracers
yield jaxpr, (out_pvals, consts, env)
def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
if instantiate:
return trace.instantiate_const(trace.full_raise(tracer))
else:
return tracer
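
# Recipes recording how a JaxprTracer was produced: as a function argument
# (LambdaBinding), as a constant (ConstVar), or as a value closed over from an
# outer trace (FreeVar). JaxprEqnRecipe below records a primitive application.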
FreeVar = namedtuple('FreeVar', ['val'])
ConstVar = namedtuple('ConstVar', ['val'])
LambdaBinding = namedtuple('LambdaBinding', [])
class JaxprEqnRecipe(NamedTuple):
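  """Recipe for a primitive application; shared by all of the application's
  output tracers via their `recipe` fields."""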
eqn_id: object
invars: Sequence[JaxprTracer]
outvars: 'Sequence[ref[JaxprTracer]]'
primitive: Primitive
params: Dict[str, Any]
source_info: Optional[source_info_util.Traceback]
def new_eqn_recipe(invars: Sequence[JaxprTracer],
outvars: Sequence[JaxprTracer],
primitive: Primitive,
params: Dict[str, Any],
source_info: Optional[source_info_util.Traceback]
) -> JaxprEqnRecipe:
"""Constructs a new JaxEqnRecipe.
Params:
invars: the tracers for the primitive inputs.
outvars: the tracers for the primitive outputs.
primitive: the primitive.
    params: the primitive params.
    source_info: the source provenance to attach to the new equation.
"""
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
# assert len(invars) == len(params["call_jaxpr"].invars) # TODO constvars?
assert len(outvars) == len(params["call_jaxpr"].outvars)
if primitive.map_primitive:
assert ("in_axes" in params and
len(params["in_axes"]) == len(params["call_jaxpr"].invars))
assert ("donated_invars" in params and
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
params, source_info)
def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
recipe: JaxprEqnRecipe) -> core.JaxprEqn:
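  """Converts a `JaxprEqnRecipe` into a `core.JaxprEqn`, using `getvar` to map
  tracers to jaxpr atoms; outputs whose tracers are no longer live become
  `core.dropvar`."""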
_, in_tracers, out_tracer_refs, primitive, params, source_info = recipe
out_tracers = [t_ref() for t_ref in out_tracer_refs]
invars = [getvar(t) for t in in_tracers]
outvars = [core.dropvar if t is None else cast(Var, getvar(t))
for t in out_tracers]
return new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer]
) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.
Params:
in_tracers: the tracers that were created for the function inputs
out_tracers: the tracers that were output by the function.
Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
  the `constvars` in the returned Jaxpr, and a list of environment values.
The vars for the environment values have been prepended to the Jaxpr's
`invars`.
"""
newvar = core.gensym()
t_to_var: Dict[int, Atom] = {}
def getvar(t: JaxprTracer) -> Atom:
var = t_to_var.get(id(t))
if var is None:
aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
var = t_to_var[id(t)] = newvar(aval)
return var
sorted_tracers = toposort(out_tracers)
invars = map(getvar, in_tracers)
eqns: List[core.JaxprEqn] = []
env: Dict[Var, Any] = {}
consts: Dict[Var, Any] = {}
const_to_var: Dict[int, Var] = {}
def getconstvar(c):
var = const_to_var.get(id(c))
if var is None:
var = const_to_var[id(c)] = newvar(get_aval(c))
return var
processed_eqn_ids = set()
for t in sorted_tracers:
recipe = t.recipe
if isinstance(recipe, JaxprEqnRecipe):
if recipe.eqn_id not in processed_eqn_ids:
eqns.append(recipe_to_eqn(getvar, recipe))
processed_eqn_ids.add(recipe.eqn_id)
elif isinstance(recipe, LambdaBinding):
if not any(t is in_tracer for in_tracer in in_tracers):
raise core.escaped_tracer_error(
t, "Tracer not among input tracers {}".format(t))
assert in_tracers, "Lambda binding with no args"
elif isinstance(recipe, FreeVar):
env[cast(Var, getvar(t))] = recipe.val
elif isinstance(recipe, ConstVar):
v = t_to_var[id(t)] = getconstvar(recipe.val)
consts[v] = recipe.val
elif isinstance(recipe, Literal):
t_to_var[id(t)] = recipe
elif recipe is unit:
t_to_var[id(t)] = unitvar
else:
raise TypeError(recipe)
env_vars, env_vals = unzip2(env.items())
const_vars, const_vals = unzip2(consts.items())
# The env_vars are prepended to the invars
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
return jaxpr, const_vals, env_vals
@cache()
def convert_constvars_jaxpr(jaxpr: Jaxpr):
"""Moves the constvars to the start of invars."""
config.jax_enable_checks and core.check_jaxpr(jaxpr)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
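# As a rough illustration (schematic jaxpr pretty-printing, not exact output),
# convert_constvars_jaxpr turns a jaxpr like
#   { lambda c ; a. let b = add c a in (b,) }
# into
#   { lambda ; c a. let b = add c a in (b,) }
# so that former constants can be passed in as ordinary arguments.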
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
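  """Converts the first `num_env_vars` invars of `jaxpr` into trailing constvars."""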
config.jax_enable_checks and core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
instantiate: Union[bool, Sequence[bool]],
) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
"""Specializes a Jaxpr given an indication of which inputs are known.
Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).
`out_unknowns` specifies which outputs are unknown (depend on some unknown inputs).
`jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs,
and performs *all* the computation in `jaxpr` that depends only on the known inputs.
Outputs correspond to those of `jaxpr`, with the unknown ones replaced with `*`,
appended with the known residuals (the intermediate computations in `jaxpr`
that depend only on known inputs and that are needed to compute the unknown outputs).
`jaxpr_unknown` takes the same inputs as `jaxpr` along with the known residuals
computed by `jaxpr_known` and returns the same outputs as `jaxpr` with the known
outputs replaced by `*`.
Roughly, `jaxpr(ki, ui)` is decomposed, taking `ki` and `ui` to be the known
and unknown inputs respectively, into:
jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(ki, *)
                let _, uout = jaxpr_unknown(ki, ui, kresidual)
                in (kout, uout)
For example, if `jaxpr` is lambda ki, ui: let ka = ki + 2
                                          in (ki + 3, ui + ka)
then
`jaxpr_known` = lambda ki, ui: let ka = ki + 2
                               in (ki + 3, *, ka)
`jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)
Note that if instantiate is True for a given output, then jaxpr_known always returns a
unit in its place. So when instantiate is True, the expectation is that one doesn't
run `jaxpr_known` for any of its outputs, but only to generate residuals that allow
the full outputs to be obtained once `jaxpr_unknown` is run. Outputs known ahead of
time will simply get passed as residual constants and returned immediately.
"""
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
cell = []
def fun(*vals):
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
return out_consts_2 + consts_2
# For jaxpr_known (jaxpr_1) we use abstract_unit avals for the unknown inputs and
# the original avals for the known inputs.
in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
(out_pvs_2, jaxpr_2, num_res), = cell
assert len(jaxpr_2.constvars) == num_res
# jaxpr :: a -> b
# jaxpr_1 :: a1 -> [b1, res]
# jaxpr_2 :: res | a2 -> b2
# jaxpr_2 :: [a2, res] -> b2
jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
if not unknown:
var.aval = abstract_unit
uk_out = [pv is not None for pv in out_pvs_2]
in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
# out_avals_1 and in_avals_2 need the residuals added
res_avals = out_avals[len(jaxpr.out_avals):]
assert len(res_avals) == num_res
out_avals_1 = [*out_avals_1, *res_avals]
in_avals_2 = [*in_avals_2, *res_avals]
return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
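# A rough usage sketch (illustrative only; the example function and the unknowns
# pattern below are made up):
#
#   import jax
#   from jax.interpreters import partial_eval as pe
#
#   closed_jaxpr = jax.make_jaxpr(lambda k, u: (k + 3., u * (k + 2.)))(1., 1.)
#   known, unknown, out_unks = pe.partial_eval_jaxpr(
#       closed_jaxpr, unknowns=[False, True], instantiate=False)
#   # `known` computes everything depending only on the first (known) input,
#   # plus residuals; `unknown` consumes those residuals to produce the rest.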
remat_call_p: Primitive = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)
def _remat_partial_eval(trace, _, f, tracers, params):
concrete = params['concrete']
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
# the function being called, not just for the unknown parts. To do that, we
# instantiate all the input tracers as constants in the jaxpr being formed.
# Those tracers might have concrete avals, and doing abstract interpretation
# on concrete avals engenders a tradeoff: it allows data-dependent Python
# control flow to work, but it can in some cases lead to redundant FLOPs (done
# both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
# `concrete` parameter to switch this behavior, and if `concrete` is False
# then we raise the avals to the Shaped level.
if concrete:
instantiated_tracers = map(trace.instantiate_const, tracers)
else:
instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvals = [PartialVal.unknown(t.pval[0]) for t in instantiated_tracers]
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
jaxpr = convert_constvars_jaxpr(jaxpr)
# We traced with everything marked as unknown, but we still need to know which
# outputs are known vs unknown, so we partially evaluate the jaxpr below to get out_unknowns.
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
if params['policy']:
# unzip into jaxpr_known and jaxpr_unknown
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = _partial_eval_jaxpr_custom(
jaxpr, in_unknowns, params['policy'])
jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
_, used_outs_unknown = partition_list(out_inst, out_unknowns)
jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown)
# compute known outputs and residuals (hoisted out of a remat_call)
if concrete: raise NotImplementedError # TODO(mattjj)
_, in_consts_ = unzip2(t.pval for t in it.chain(env_tracers, tracers)
if t.pval.is_known())
_, in_consts = partition_list(in_used_known, [*consts, *in_consts_])
out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
out_consts_ = iter(out_consts)
# reconstruct known outs, inserting units
outs1 = [pval.get_known() if x.aval is abstract_unit else next(out_consts_)
for uk, pval, x in zip(out_unknowns, eval_out_pvals, jaxpr.outvars)
if not uk]
# form known outputs and collect residual tracers
out_known_tracers = [JaxprTracer(trace, PartialVal.known(c), None)
for c in outs1]
residuals = list(out_consts_)
# set up unknown outputs with a recipe to call remat
res_tracers = map(trace.new_instantiated_const, residuals)
const_tracers = map(trace.new_instantiated_const, consts)
in_jaxpr_tracers = [*res_tracers, *const_tracers, *env_tracers,
*instantiated_tracers]
_, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
out_jaxpr_tracers = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
for x in jaxpr_unknown.outvars]
new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True)
recipe = new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_call_p,
new_params, source_info_util.current())
for t in out_jaxpr_tracers: t.recipe = recipe
return _zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
else:
# TODO(mattjj): this is an old parallel code path, to be deleted once the
# new path is fully functional
closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
closed_jaxpr, in_unknowns, instantiate=False) # type: ignore
out_knowns = [not b for b in out_unknowns]
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
# Next, we need values for the outputs that should be known. Since consts
# weren't passed through Python for evaluation, we need to evaluate
# jaxpr_known, minus the residual outputs that we don't need. When
# `concrete=True`, as an optimization we can avoid redoing *some* redundant
# FLOPs, namely those that produced concrete avals at the output, simply by
# using those as computed values. For the use case of reverse-mode ad in
# op-by-op ("eager mode") evaluation, all the primal outputs should be
# concrete (thus not recomputed).
to_compute = [type(pval[0]) is not ConcreteArray
for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
num_outputs = len(jaxpr_unknown.out_avals)
num_res = len(jaxpr_known.out_avals) - num_outputs
jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res,
drop_outputs=True)
jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
_, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)
# Known outputs should keep propagating as constants
assert all(pv.is_known() for pv in out_known_pvals)
known_output_tracers = [trace.new_const(pval.get_known())
for pval in out_known_pvals]
# Unknown outputs get wrapped in tracers with the appropriate recipe
unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
for out_pval in out_unknown_pvals]
# dce jaxpr outputs
new_jaxpr = _dce_jaxpr(closed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
new_params = dict(params, call_jaxpr=new_jaxpr, differentiated=True)
# set up eqn for unknown outputs
const_tracers = map(trace.new_instantiated_const, consts)
in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p,
new_params, source_info_util.current())
for t in unknown_output_tracers: t.recipe = eqn
return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
def _partition_knowns(pvals, unknowns: Sequence[bool]):
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
[e for e, unknown in zip(pvals, unknowns) if unknown])
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
assert len(known_list) + len(unknown_list) == len(which_unknown)
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
def _partial_eval_jaxpr_custom(
jaxpr: Jaxpr, in_unknowns: Sequence[bool], saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, Sequence[bool], Sequence[bool], int]:
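  """Unzips `jaxpr` into known and staged-out parts, guided by `saveable`.

  Each variable is tracked as an (unknown, instantiated) pair: unknown values
  depend on unknown inputs and must be staged out; known values produced by
  primitives that `saveable` rejects are also recomputed ("instantiated") in
  the staged jaxpr rather than saved as residuals. Returns
  (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_residuals), where
  `jaxpr_known` maps the known invars to the known outputs plus residuals, and
  `jaxpr_staged` maps the residuals plus all original invars to the staged-out
  outputs.
  """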
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
env: Dict[Var, Tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
def read(x: Atom) -> Tuple[bool, bool]:
if type(x) is Var:
return env[x]
return (False, True)
def write(unk: bool, inst: bool, v: Var) -> None:
assert (unk, inst) != (True, False)
env[v] = (unk, inst)
def ensure_instantiated(inst: bool, x: Atom) -> Atom:
if type(x) is Var and not inst:
residuals.add(x)
return x
known_eqns, staged_eqns = [], []
write(False, True, unitvar)
map(write, in_unknowns, [True] * len(in_unknowns), jaxpr.invars)
for eqn in jaxpr.eqns:
unks_in, inst_in = unzip2(map(read, eqn.invars))
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
if rule:
eqn1, eqn2, unks_out, inst_out, res = rule(saveable, unks_in, inst_in, eqn)
eqn1 and known_eqns.append(eqn1); eqn2 and staged_eqns.append(eqn2) # type: ignore
residuals.update(res)
map(write, unks_out, inst_out, eqn.outvars)
elif any(unks_in):
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(new_jaxpr_eqn(inputs, eqn.outvars, eqn.primitive,
eqn.params, eqn.source_info))
map(partial(write, True, True), eqn.outvars)
else:
known_eqns.append(eqn)
if saveable(eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params):
map(partial(write, False, False), eqn.outvars)
else:
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(new_jaxpr_eqn(inputs, eqn.outvars, eqn.primitive,
eqn.params, eqn.source_info))
map(partial(write, False, True), eqn.outvars)
out_unknowns, out_inst = unzip2(map(read, jaxpr.outvars))
assert all(type(v) is Var for v in residuals), residuals
ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
outs_known_, _ = partition_list(out_unknowns, jaxpr.outvars)
outs_known = [x for x in outs_known_ if x.aval is not abstract_unit]
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns)
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged, staged_eqns)
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
# A primitive rule for policy-driven partial evaluation returns a 5-tuple
# with the components representing, respectively:
# * the JaxprEqn for the 'known' side (or None if there is no known component),
# * the JaxprEqn for the 'unknown' side (or None),
# * a list of booleans indicating which of the original outputs are unknown,
# * a list of booleans indicating which of the original outputs are
# instantiated (i.e. available) in the 'unknown' side,
# * a list of Var instances representing residuals to be added (i.e. to be
# plumbed as outputs of the 'known' side jaxpr and added as input binders to
# the 'unknown' jaxpr).
PartialEvalCustomResult = Tuple[Optional[JaxprEqn], Optional[JaxprEqn],
Sequence[bool], Sequence[bool], List[Var]]
PartialEvalCustomRule = Callable[
[Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn],
PartialEvalCustomResult]
partial_eval_jaxpr_custom_rules: Dict[Primitive, PartialEvalCustomRule] = {}
def partial_eval_jaxpr_custom_rule_not_implemented(
saveable: Callable[..., bool], unks_in: Sequence[bool], inst_in: Sequence[bool],
eqn: JaxprEqn) -> PartialEvalCustomResult:
raise NotImplementedError
ParamsUpdater = Callable[[List[bool], int, dict, dict], Tuple[dict, dict]]
def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
eqn: JaxprEqn
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = _partial_eval_jaxpr_custom(
eqn.params[jaxpr_param_name], unks_in, saveable)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
out_binders_known = [v for v in out_binders_known if v is not dropvar]
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym([jaxpr_known, jaxpr_staged])
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
params_known = dict(eqn.params, call_jaxpr=jaxpr_known)
params_staged = dict(eqn.params, call_jaxpr=jaxpr_staged)
params_known, params_staged = params_updater(unks_in, num_res, params_known, params_staged)
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, eqn.source_info)
eqn_staged = new_jaxpr_eqn([*residuals, *eqn.invars], out_binders_staged,
eqn.primitive, params_staged, eqn.source_info)
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is Var and not inst]
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
partial_eval_jaxpr_custom_rules[core.call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, x, y: (x, y))
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, x, y: (x, y))
partial_eval_jaxpr_custom_rules[remat_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, p1, p2: (p1, dict(p2, differentiated=True)))
# TODO(mattjj): unify with dce code below
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: List[bool]
) -> Tuple[Jaxpr, List[bool]]:
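  """Dead-code elimination: prunes `jaxpr` down to the outputs marked used.

  Returns the pruned jaxpr along with a list of booleans indicating which of
  the original invars are still used.
  """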
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
env: Dict[Var, bool] = {}
def read(v: Var) -> bool:
return env.get(v, False)
def write(x: Atom, b: bool) -> None:
if type(x) is Var:
env[x] = read(x) or b
new_eqns = []
map(write, jaxpr.outvars, used_outputs)
for eqn in jaxpr.eqns[::-1]:
used_outs = map(read, eqn.outvars)
rule = dce_rules.get(eqn.primitive)
if rule:
used_ins, new_eqn = rule(used_outs, eqn)
if any(used_ins): new_eqns.append(new_eqn)
elif any(used_outs):
new_eqns.append(eqn)
used_ins = [True] * len(eqn.invars)
else:
used_ins = [False] * len(eqn.invars)
map(write, eqn.invars, used_ins)
used_inputs = map(read, jaxpr.invars)
new_jaxpr = Jaxpr((),
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
[v for v, b in zip(jaxpr.outvars, used_outputs) if b],
new_eqns[::-1])
config.jax_enable_checks and core.check_jaxpr(new_jaxpr)
return new_jaxpr, used_inputs
DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], JaxprEqn]]
dce_rules: Dict[Primitive, DCERule] = {}
def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
update_params = call_param_updaters.get(eqn.primitive)
if update_params:
new_params = update_params(new_params, used_inputs)
new_eqn = new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, eqn.source_info)
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule
def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJaxpr:
new_jaxpr = _dce_open_jaxpr(closed_jaxpr.jaxpr, tuple(outputs), drop_outputs)
return core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
@cache()
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan, call, or other higher-order primitives.
# TODO(mattjj): better DCE (i.e. use above dce_jaxpr)
if drop_outputs:
new_outvars = [var for var, output in zip(jaxpr.outvars, outputs) if output]
else:
new_outvars = [var if output else unitvar
for var, output in zip(jaxpr.outvars, outputs)]
needed_vars = {v for v in new_outvars if type(v) is not Literal}
new_eqns = []
for eqn in jaxpr.eqns[::-1]:
if set(eqn.outvars) & needed_vars:
new_eqns.append(eqn)
needed_vars.update(v for v in eqn.invars if type(v) is not Literal)
new_eqns = new_eqns[::-1]
return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars, new_eqns)
@cache()
def _drop_vars(jaxpr: Jaxpr, drop_ins: Tuple[bool, ...], drop_outs: Tuple[bool, ...]):
return Jaxpr(jaxpr.constvars,
[v for v, d in zip(jaxpr.invars, drop_ins) if not d],
[v for v, d in zip(jaxpr.outvars, drop_outs) if not d],
jaxpr.eqns)
def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
pv1, _ = pval1
if pval1.is_known():
return pval1
else:
if type(pv1) is ConcreteArray:
return PartialVal.known(pv1.val) # pytype: disable=attribute-error
else:
return PartialVal.known(const2)
def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
"""Reorder the `invars` to move to front the ones for which `to_move` is True."""
assert not closed_jaxpr.jaxpr.constvars
assert len(closed_jaxpr.in_avals) == len(to_move)
new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = Jaxpr((), new_invars, closed_jaxpr.jaxpr.outvars,
closed_jaxpr.jaxpr.eqns)
new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
return new_closed_jaxpr
def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])
class DynamicJaxprTracer(core.Tracer):
__slots__ = ['aval']
def __init__(self, trace, aval, line_info=None):
self._trace = trace
self._line_info = line_info
self.aval = aval
def full_lower(self):
return self
def _contents(self):
return ()
def _origin_msg(self):
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
dbg = self._trace.main.debug_info
if dbg is None:
return ""
if invar_pos:
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this concrete value was not available in Python because it "
f"depends on the value{'s' if len(invar_pos) > 1 else ''} "
f"of {dbg.arg_info(invar_pos)}.")
elif progenitor_eqns:
msts = [" operation "
f"{core.pp_eqn(eqn, core.JaxprPpContext(), print_shapes=True)}\n"
f" from line {source_info_util.summarize(eqn.source_info)}"
for eqn in progenitor_eqns]
origin = (f"While tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}, "
"this value became a tracer due to JAX operations on these lines:"
"\n\n" + "\n\n".join(msts))
else:
origin = (f"The error occurred while tracing the function {dbg.func_src_info} "
f"for {dbg.traced_for}.")
return "\n" + origin
def _assert_live(self) -> None:
if not self._trace.main.jaxpr_stack: # type: ignore
raise core.escaped_tracer_error(self, None)
class JaxprStackFrame:
__slots__ = ['gensym', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
'tracers', 'eqns', 'invars']
def __init__(self):
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_var = {}
self.constvar_to_val = {}
self.tracers = [] # circ refs, frame->tracer->trace->main->frame,
self.eqns = [] # cleared when we pop frame from main
self.invars = []
def to_jaxpr(self, in_tracers, out_tracers):
invars = [self.tracer_to_var[id(t)] for t in in_tracers]
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
jaxpr, constvals = _prune_convert_element_types(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
out_avals = [t.aval for t in out_tracers]
return jaxpr, out_avals, constvals
def newvar(self, aval):
return self.gensym(aval)
def find_progenitors(self, tracer):
var = self.tracer_to_var.get(id(tracer))
if not var:
return None, None
active_vars = {var}
for eqn in self.eqns[::-1]:
produced = set(eqn.outvars) & active_vars
if produced:
active_vars.difference_update(produced)
active_vars.update(eqn.invars)
invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars]
constvars = active_vars & set(self.constvar_to_val)
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns
def _prune_convert_element_types(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
new_eqns = []
for eqn in jaxpr.eqns:
if eqn.primitive is core.convert_element_type_p:
c = consts.get(eqn.invars[0])
if type(c) in core.literalable_types and not np.shape(c):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(c, eqn.params['new_dtype'])
continue
if c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't stage out no-op convert_element_type calls as clutter
consts[eqn.outvars[0]] = c
continue
new_eqns.append(eqn)
new_constvars, new_constvals = unzip2(consts.items())
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
return new_jaxpr, new_constvals
def _inline_literals(jaxpr, constvals):
# This function also ensures variables are labeled in a canonical ordering,
# prunes unused constants, and inserts `dropvar` symbols.
consts = dict(zip(jaxpr.constvars, constvals))
newvar = core.gensym()
newvars = {}
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
def lit(var: Var) -> Optional[Any]:
val = consts.get(var)
if type(val) in core.literalable_types and not np.shape(val):
return Literal(val)
else:
return None
used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals)
if v in used and not lit(v)]
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var(v) for v in eqn.invars]
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
__slots__ = [] # type: ignore
@property
def frame(self):
return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
def new_arg(self, aval):
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
self.frame.invars.append(var)
return tracer
def new_const(self, val):
aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
self.frame.constvar_to_val[var] = val
return tracer
pure = lift = sublift = new_const
def getvar(self, tracer):
var = self.frame.tracer_to_var.get(id(tracer))
if var is None:
raise core.escaped_tracer_error(tracer)
return var
def makevar(self, tracer):
var = self.frame.tracer_to_var.get(id(tracer))
assert var is None, "a jaxpr variable must be created only once per tracer"
self.frame.tracers.append(tracer)
var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
return var
def getconstvar(self, c):
var = self.frame.constid_to_var.get(id(c))
if var is None:
var = self.frame.constid_to_var[id(c)] = self.frame.newvar(get_aval(c))
return var
def instantiate_const(self, val):
if (isinstance(val, Tracer) and val._trace.main is self.main
and val._trace.sublevel == self.sublevel):
return val
else:
return self.new_const(val)
def process_primitive(self, primitive, tracers, params):
avals = [t.aval for t in tracers]
out_avals = primitive.abstract_eval(*avals, **params)
out_avals = [out_avals] if not primitive.multiple_results else out_avals
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
self.frame.eqns.append(eqn)
return out_tracers if primitive.multiple_results else out_tracers.pop()
def process_call(self, call_primitive, f, tracers, params):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *tracers)
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
update_params = call_param_updaters.get(call_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
new_params, source_info)
self.frame.eqns.append(eqn)
return out_tracers
def post_process_call(self, call_primitive, out_tracers, params):
assert False # unreachable
def process_map(self, map_primitive, f, tracers, params):
in_avals = [t.aval for t in tracers]
axis_name, axis_size = params['axis_name'], params['axis_size']
reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
if in_axis is not None else a
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals)
out_axes = params['out_axes_thunk']()
out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a)
if out_axis is not None else a
for a, out_axis in zip(reduced_out_avals, out_axes)]
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
new_in_axes = (None,) * len(consts) + params['in_axes']
new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
call_jaxpr=convert_constvars_jaxpr(jaxpr))
del new_params['out_axes_thunk']
update_params = call_param_updaters.get(map_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
new_params, source_info)
self.frame.eqns.append(eqn)
return out_tracers
def post_process_map(self, map_primitive, out_tracers, params):
assert False # unreachable
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
jvp_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
dict(fun_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts)),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
def post_process_custom_jvp_call(self, out_tracers, params):
assert False # unreachable
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
fwd_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
dict(fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(consts),
bwd=bwd, out_trees=out_trees),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
def post_process_custom_vjp_call(self, out_tracers, params):
assert False # unreachable
def _memoize(thunk):
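  # Caches the thunk's result. The first call runs the thunk under the trace
  # state captured when _memoize was called (so that e.g. jvp/fwd jaxpr thunks
  # are staged in the context in which they were set up), restoring the
  # caller's trace state afterwards; subsequent calls return the cached value.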
cell = []
saved_state = core.thread_local_state.trace_state.copy()
def memoized():
if not cell:
prev_state = core.thread_local_state.trace_state
core.thread_local_state.trace_state = saved_state
try:
cell.append(thunk())
finally:
core.thread_local_state.trace_state = prev_state
return cell[0]
return memoized
class DebugInfo(NamedTuple):
func_src_info: str
traced_for: str
arg_info: Callable[[int], str]
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
in_tree, has_kwargs = flattened_fun_in_tree(fn) or (None, False)
return debug_info(fn.f, in_tree, has_kwargs, traced_for)
def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], has_kwargs: bool,
traced_for: str) -> DebugInfo:
func_src_info = fun_sourceinfo(fn)
if in_tree is not None:
arg_info = partial(arg_info_pytree, fn, in_tree, has_kwargs)
else:
arg_info = arg_info_flattened # type: ignore
return DebugInfo(func_src_info, traced_for, arg_info)
def fun_sourceinfo(fun: Callable):
while isinstance(fun, functools.partial):
fun = fun.func
fun = inspect.unwrap(fun)
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
line_info = f"{fun.__name__} at {filename}:{lineno}"
return line_info
except AttributeError:
return "<unknown>"
def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
flat_pos: List[int]) -> str:
dummy_args = [False] * in_tree.num_leaves
for i in flat_pos: dummy_args[i] = True
if has_kwargs:
args, kwargs = tree_unflatten(in_tree, dummy_args)
else:
args, kwargs = tree_unflatten(in_tree, dummy_args), {}
try:
ba = inspect.signature(fn).bind(*args, **kwargs)
except (TypeError, ValueError):
return arg_info_flattened(flat_pos)
arg_names = [f"'{name}'" for name, x in ba.arguments.items()
if any(tree_leaves(x))]
if len(arg_names) == 1:
return f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
return f"the arguments {arg_names[0]} and {arg_names[1]}"
else:
*rest, last = arg_names
return f"the arguments {', '.join(rest)}, and {last}"
def arg_info_flattened(flat_pos: List[int]) -> str:
if len(flat_pos) > 1:
return f"the argument passed at flattened positions {flat_pos}"
else:
return f"the argument passed at flattened position {flat_pos[0]}"
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del main, fun
return jaxpr, out_avals, consts
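# A rough usage sketch (illustrative only; the wrapped function and aval are
# made up):
#
#   import numpy as np
#   from jax import core, linear_util as lu
#   from jax.interpreters import partial_eval as pe
#
#   fun = lu.wrap_init(lambda x: x * 2.)
#   jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
#       fun, [core.ShapedArray((), np.float32)])
#   # `jaxpr` is the staged-out program, `out_avals` its output avals, and
#   # `consts` the values for its constvars.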
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue]):
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame):
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = map(trace.new_arg, in_avals)
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.full_raise, ans)
jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
del fun, main, trace, frame, in_tracers, out_tracers, ans
return jaxpr, out_avals, consts
@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
main.jaxpr_stack = main.jaxpr_stack + (frame,)
try:
yield
finally:
assert frame is main.jaxpr_stack[-1]
main.jaxpr_stack = main.jaxpr_stack[:-1]
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del fun, main
return jaxpr, out_avals, consts
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
# This function provides a partial evaluation behavior used by Flax. We can't
# use trace_to_jaxpr directly because of an interaction with the current
# custom_derivatives.py, which we work around by adding the EvalTrace.
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
return trace_to_jaxpr(fun, in_pvals)