From 6cc1bf54a111fe86971d33b6ec07d18258590a88 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 27 Mar 2023 13:29:59 -0700 Subject: [PATCH] Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval. Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters. PiperOrigin-RevId: 519813664 --- jax/__init__.py | 1 + jax/_src/ad_checkpoint.py | 8 +- jax/_src/api.py | 4 +- jax/_src/checkify.py | 4 +- jax/_src/custom_batching.py | 6 +- jax/_src/custom_derivatives.py | 4 +- jax/_src/custom_transpose.py | 6 +- jax/_src/debugging.py | 2 +- jax/_src/dispatch.py | 4 +- jax/_src/interpreters/ad.py | 2 +- jax/_src/interpreters/batching.py | 2 +- jax/_src/interpreters/mlir.py | 6 +- jax/_src/interpreters/partial_eval.py | 2506 ++++++++++++++++++++ jax/_src/interpreters/pxla.py | 2 +- jax/_src/lax/ann.py | 3 +- jax/_src/lax/control_flow/common.py | 2 +- jax/_src/lax/control_flow/conditionals.py | 13 +- jax/_src/lax/control_flow/for_loop.py | 12 +- jax/_src/lax/control_flow/loops.py | 9 +- jax/_src/lax/lax.py | 28 +- jax/_src/lax/linalg.py | 40 +- jax/_src/lax/slicing.py | 2 +- jax/_src/lax/utils.py | 3 +- jax/_src/maps.py | 4 +- jax/_src/pjit.py | 4 +- jax/_src/state/discharge.py | 2 +- jax/_src/state/primitives.py | 6 +- jax/experimental/custom_partitioning.py | 6 +- jax/experimental/host_callback.py | 4 +- jax/experimental/jet.py | 2 +- jax/experimental/shard_map.py | 6 +- jax/experimental/sparse/bcoo.py | 2 +- jax/experimental/sparse/transform.py | 6 +- jax/interpreters/partial_eval.py | 2579 +-------------------- setup.cfg | 1 + tests/api_test.py | 2 +- tests/core_test.py | 2 +- tests/jaxpr_effects_test.py | 4 +- tests/state_test.py | 2 +- 39 files changed, 2705 insertions(+), 2596 deletions(-) create mode 100644 jax/_src/interpreters/partial_eval.py diff --git a/jax/__init__.py b/jax/__init__.py index c3e55010e..3a2ee492e 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -127,6 +127,7 @@ from jax._src.api import xla_computation as xla_computation from jax.interpreters import ad # TODO(phawkins): update users to avoid this. from jax.interpreters import pxla # TODO(phawkins): update users to avoid this. +from jax.interpreters import partial_eval # TODO(phawkins): update users to avoid this. from jax.interpreters import xla # TODO(phawkins): update users to avoid this. 
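The import-path migration applied throughout this patch, as a minimal before/after sketch (the `jax.interpreters.partial_eval` alias added above remains importable for external users, per the TODO):

    # Before: internal JAX modules imported the interpreter via its public alias.
    from jax.interpreters import partial_eval as pe

    # After: internal modules use the canonical private location; the public
    # alias is re-exported from jax/__init__.py only for backward compatibility.
    from jax._src.interpreters import partial_eval as pe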
from jax._src.array import ( diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 57ebd0e90..c2405eb02 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -21,12 +21,10 @@ import types import numpy as np import jax -from jax.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr from jax._src import ad_util from jax._src import core +from jax._src import dispatch from jax._src import linear_util as lu from jax._src import effects from jax._src import source_info_util @@ -35,6 +33,8 @@ from jax._src import util from jax._src.api_util import flatten_fun, shaped_abstractify from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution from jax._src.lib.mlir.dialects import hlo @@ -737,7 +737,7 @@ def _optimization_barrier(arg): optimization_barrier_p = core.Primitive('optimization_barrier') optimization_barrier_p.multiple_results = True optimization_barrier_p.def_impl( - partial(xla.apply_primitive, optimization_barrier_p)) + partial(dispatch.apply_primitive, optimization_barrier_p)) optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) mlir.register_lowering(optimization_barrier_p, _optimization_barrier_lowering_rule) diff --git a/jax/_src/api.py b/jax/_src/api.py index 5b9a9bfe8..8dadd1422 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -79,8 +79,8 @@ from jax.custom_batching import custom_vmap from jax.custom_derivatives import (custom_gradient, custom_jvp, custom_vjp, linear_call) from jax.custom_transpose import custom_transpose -from jax.interpreters import partial_eval as pe -from jax.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.config import ( diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 80ac676d9..369de3845 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -26,8 +26,6 @@ import jax.tree_util as jtu from jax import lax from jax.api_util import flatten_fun from jax.experimental import pjit -from jax.interpreters import mlir -from jax.interpreters import partial_eval as pe from jax.tree_util import tree_flatten from jax.tree_util import tree_map from jax.tree_util import tree_unflatten @@ -42,6 +40,8 @@ from jax._src import traceback_util from jax._src.config import config from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, unzip3, weakref_lru_cache) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 740e67611..2bea7562c 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -18,9 +18,6 @@ from typing import Callable, Optional import jax from jax import tree_util -from jax.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla from jax.tree_util import (tree_flatten, tree_map, tree_structure, tree_unflatten, treedef_tuple) from jax._src import core @@ -33,6 +30,9 @@ from jax._src.api_util 
import flatten_fun_nokwargs from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters.batching import not_mapped +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla source_info_util.register_exclusion(__file__) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index d0087b5a9..6b3a9e83a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -21,8 +21,6 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves) from jax.errors import UnexpectedTracerError -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla from jax.config import config from jax._src import core @@ -38,6 +36,8 @@ from jax._src.core import raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla from jax._src.interpreters.batching import not_mapped from jax._src.lax import lax from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 8bb21f4f2..1218bdb3e 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -15,9 +15,6 @@ import functools from typing import Any, Callable, Optional, Tuple -from jax.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla from jax.tree_util import (tree_flatten, tree_leaves, tree_map, tree_structure, treedef_tuple, tree_unflatten) from jax._src import ad_util @@ -29,6 +26,9 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.interpreters import ad +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla source_info_util.register_exclusion(__file__) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 57ce4f3c7..8da760105 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -38,7 +38,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import Sharding from jax._src.sharding_impls import GSPMDSharding, NamedSharding -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe # pytype: disable=import-error try: diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 4a46c4510..77b601cdb 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -32,8 +32,6 @@ import numpy as np import jax from jax.monitoring import record_event_duration_secs -import jax.interpreters.mlir as mlir -import jax.interpreters.partial_eval as pe from jax._src import array from jax._src import core @@ -49,6 +47,8 @@ from jax._src import xla_bridge as xb from jax._src.config import config, flags from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.interpreters import pxla from jax._src.lib.mlir import ir diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index bdcf3b961..d6e6cf2b9 100644 --- a/jax/_src/interpreters/ad.py +++ 
b/jax/_src/interpreters/ad.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union import jax from jax._src import linear_util as lu -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe from jax.config import config from jax.tree_util import (tree_flatten, tree_unflatten, register_pytree_node, Partial) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 323104692..8db3d47dd 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -35,7 +35,7 @@ from jax._src import linear_util as lu from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache) -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe Array = Any map, unsafe_map = safe_map, map diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 76a9f6fa5..0d50c8231 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -32,9 +32,6 @@ import numpy as np from jax._src import linear_util as lu from jax.config import config -from jax._src.interpreters import ad -from jax.interpreters import partial_eval as pe -from jax._src.interpreters import xla from jax._src import ad_util from jax._src import core from jax._src import device_array @@ -43,6 +40,9 @@ from jax._src import effects as effects_lib from jax._src import source_info_util from jax._src import util from jax._src import xla_bridge as xb +from jax._src.interpreters import ad +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py new file mode 100644 index 000000000..60611e55c --- /dev/null +++ b/jax/_src/interpreters/partial_eval.py @@ -0,0 +1,2506 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
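For orientation before the implementation: the partial-evaluation machinery in this module is what lets transformations such as `jax.linearize` evaluate the known (primal) part of a computation eagerly while staging the unknown (tangent) part out into a jaxpr. A minimal sketch of that observable behavior, using only public JAX API:

    import jax
    import jax.numpy as jnp

    # The primal part runs now, from known inputs; the linear part is staged
    # out as a jaxpr and evaluated only later, when jvp_fn is applied.
    y, jvp_fn = jax.linearize(jnp.sin, 1.0)
    print(y)            # sin(1.0), computed eagerly
    print(jvp_fn(1.0))  # cos(1.0) * 1.0, evaluated from the staged-out jaxpr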
+from __future__ import annotations + +from collections import namedtuple +from contextlib import contextmanager, AbstractContextManager +import functools +from functools import partial +import inspect +import itertools as it +import operator as op +from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple, + List, Union, Hashable, Set) +from weakref import ref + +import numpy as np + +from jax._src import linear_util as lu +from jax.config import config +from jax._src import api_util +from jax._src import core +from jax._src import effects +from jax._src import dtypes +from jax._src import profiler +from jax._src import source_info_util +from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, + fun_sourceinfo) +from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, + AbstractValue, ClosedJaxpr, new_jaxpr_eqn, + ConcreteArray, Var, DropVar, raise_to_shaped, Atom, + JaxprEqn, Primitive, ShapedArray, DShapedArray, + mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, + InputType, OutputType, get_referent) +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, + KeyPath, generate_key_paths, keystr) +from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, + merge_lists, partition_list, OrderedSet, + as_hashable_function, weakref_lru_cache) + + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip +def identity(x): return x + +TracerId = int +AvalId = int +ConstId = int + +def _update_annotation_known( + f: lu.WrappedFun, + orig_type: Optional[InputType], + in_knowns: List[bool] + ) -> lu.WrappedFun: + if orig_type is None: return f + # orig_type might contain DBIdx, but we're tossing out some args so we have to + # re-index. moreover some of the implicit args may not be needed anymore. + # so we basically just re-infer the lambda input type + if (all(e for _, e in orig_type) and + not any(type(d) is DBIdx for a, _ in orig_type for d in a.shape + if type(a) is DShapedArray)): + new_type = [ty for ty, known in zip(orig_type, in_knowns) if known] + return lu.annotate(f, tuple(new_type)) + + # Replace DBIdx with names, prune down to explicit only. + class Name: + def __init__(self, a): self.a = a + names = [Name(a) for a, _ in orig_type] + avals = [a.update(shape=tuple(names[d.val] if type(d) is DBIdx else d # type: ignore + for d in a.shape)) + if type(a) is DShapedArray else a for a, e in orig_type if e] + avals = [a for a, known in zip(avals, in_knowns) if known] + # Figure out the implicit part: names which aren't explicit and known. + expl_names = [o for o, (_, e) in zip(names, orig_type) if e] + expl_names = [o for o, k in zip(expl_names, in_knowns) if k] + expl_names_ = set(expl_names) + impl_names = {d for a in avals if type(a) is DShapedArray for d in a.shape + if type(d) is Name and d not in expl_names_} + impl_part = [(n.a, False) for n in impl_names] # type: ignore + # Figure out the explicit part: known explicit avals, replacing names w/ dbidx + name_map = {n: DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))} + expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape)) + if type(a) is DShapedArray else a, True) for a in avals] + return lu.annotate(f, (*impl_part, *expl_part)) + +class PartialVal(tuple): + """Partial value: either a known value or an unknown (abstract) value. 
+ + Represented as a pair `(aval_opt, const)` of one of two kinds: + * `(None, <const>)` indicates a known value, where the constant is either a + Tracer or satisfies `core.valid_jaxtype(const)`; + * `(<AbstractValue>, None)` indicates an unknown value characterized by an + abstract value. + """ + def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]): + pv, const = xs + if config.jax_enable_checks: + # type checks + assert isinstance(pv, (AbstractValue, type(None))), xs + assert (const is None or isinstance(const, core.Tracer) or + core.valid_jaxtype(const)), const + # invariant checks + assert (pv is None) ^ (const is None) + return tuple.__new__(cls, xs) + + @classmethod + def known(cls, const: core.Value) -> PartialVal: + return PartialVal((None, const)) + + @classmethod + def unknown(cls, aval: AbstractValue) -> PartialVal: + return PartialVal((aval, None)) + + def is_known(self) -> bool: + return self[0] is None + + def get_known(self) -> Optional[core.Value]: + """Get the known value, if known, else None.""" + return self[1] if self[0] is None else None + + def get_aval(self) -> AbstractValue: + """Get AbstractValue directly (if unknown) or from the constant (known).""" + known = self.get_known() + if known is not None: + return get_aval(known) + else: + return self[0] + + +class JaxprTrace(Trace['JaxprTracer']): + + def __init__(self, *args, name_stack: source_info_util.NameStack): + super().__init__(*args) + self.name_stack = name_stack + + def pure(self, val: Any) -> JaxprTracer: + return self.new_const(val) + + def lift(self, val: Tracer) -> JaxprTracer: + return self.new_const(val) + + def sublift(self, val: JaxprTracer) -> JaxprTracer: + return JaxprTracer(self, val.pval, FreeVar(val)) + + def new_const(self, val) -> JaxprTracer: + if isinstance(val, Tracer) and val._trace.level == self.level: + raise Exception + return JaxprTracer(self, PartialVal.known(val), None) + + def new_instantiated_literal(self, val) -> JaxprTracer: + aval = get_aval(val) + return JaxprTracer(self, PartialVal.unknown(aval), + Literal(val, raise_to_shaped(aval))) + + def new_instantiated_const(self, val) -> JaxprTracer: + aval = get_aval(val) + if isinstance(aval, DShapedArray): + shape = [self.new_instantiated_const(d) + if isinstance(d, Tracer) and d._trace.level < self.level else d + for d in aval.shape] + aval = aval.update(shape=tuple(shape)) + return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(val)) + + def new_arg(self, pval: PartialVal) -> JaxprTracer: + const = pval.get_known() + # XXX: Think twice before changing this constant argument pruning! + # This has really important consequences for partial_eval_jaxpr. + # Most importantly, this guarantees that the unknown jaxpr never uses + # known inputs (if it needs them, then they get passed through residuals).
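A small usage sketch of the `PartialVal` API defined above (hypothetical values; `ShapedArray` comes from `jax._src.core`):

    import numpy as np
    from jax._src import core
    from jax._src.interpreters import partial_eval as pe

    known = pe.PartialVal.known(np.float32(3.0))      # (None, <const>)
    unknown = pe.PartialVal.unknown(
        core.ShapedArray((), np.dtype('float32')))    # (<AbstractValue>, None)
    assert known.is_known() and not unknown.is_known()
    assert known.get_known() == 3.0         # the constant itself
    assert unknown.get_known() is None      # no value available yet
    assert isinstance(unknown.get_aval(), core.ShapedArray)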
+ if const is None: + aval = pval.get_aval() + if type(aval) is DShapedArray: + shape = [self.new_instantiated_const(d) + if isinstance(d, Tracer) and d._trace.level < self.level else d + for d in aval.shape] + aval = aval.update(shape=tuple(shape)) + return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding()) + else: + return self.new_const(const) + + def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: + const = tracer.pval.get_known() + if const is None: + return tracer + else: + if type(const) in core.literalable_types and np.shape(const) == (): + return self.new_instantiated_literal(const) + else: + return self.new_instantiated_const(const) + + def instantiate_const_abstracted(self, tracer) -> JaxprTracer: + const = tracer.pval.get_known() + if const is None: + return tracer + else: + aval = raise_to_shaped(get_aval(const), np.isscalar(const)) + return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) + + def process_primitive(self, primitive, tracers, params): + if primitive in custom_partial_eval_rules: + return custom_partial_eval_rules[primitive](self, *tracers, **params) + else: + return self.default_process_primitive(primitive, tracers, params) + + def default_process_primitive(self, primitive, tracers, params): + # By default, if all the input tracers are known, then bind the primitive + # and consider all outputs known. Otherwise, stage the application into the + # jaxpr and consider all outputs unknown. + consts = [t.pval.get_known() for t in tracers] + if all(c is not None for c in consts): + return primitive.bind(*consts, **params) + tracers = map(self.instantiate_const, tracers) + avals = [t.aval for t in tracers] + out_aval, effects = primitive.abstract_eval(*avals, **params) + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) + if primitive.multiple_results: + out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) + for aval in out_aval] + eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, source) + for t in out_tracers: t.recipe = eqn + return out_tracers + else: + out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None) + out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, + params, effects, source) + return out_tracer + + def process_call(self, primitive, f, tracers, params): + rule = call_partial_eval_rules.get(primitive) + if rule: + return rule(self, primitive, f, tracers, params) + + update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) + in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) + # TODO(mattjj): check in_avals are consistent with f.in_type + + # We want to partially evaluate this call into two calls: one evaluated now + # taking known values (in_consts) as inputs and producing known values + # (out_consts) as outputs, and the other staged out as an eqn into the jaxpr + # being built. The latter takes as input residuals (res) produced as outputs + # of the first call, shared closed-over values (env), and explicit arguments + # which were unknown to the first call (corresponding to in_avals). + + # Wrap f to perform the partial evaluation and plumb out aux data. 
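# Hypothetical illustration of this split: for a call computing (x * 2., x * y)
# with x known and y unknown, the bind below evaluates x * 2. now and emits x
# as a residual; the staged-out eqn computes x * y later, from that residual,
# once y's value becomes available.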
+ if not config.jax_dynamic_shapes: + f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) + f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), + tuple(in_avals)) + else: + if f.in_type is None: + f = lu.annotate(f, tuple((a, True) for a in in_avals)) + f_, aux = trace_to_subjaxpr_nounits_dyn(f, self.main, tuple(in_knowns), + f.in_type, False) + # Adjust parameters (e.g. donated_invars) for the call to be evaluated now. + const_params = update_params(params, in_knowns, 0) + + # Run the call, getting known out vals and aux data used for staged-out call + out = primitive.bind(_update_annotation_known(f_, f.in_type, in_knowns), + *in_consts, **const_params) + fwds, out_knowns, out_type, jaxpr, env = aux() + # Split apart known outputs from the original call and non-fwded residuals. + out_consts, non_fwd_res_ = split_list(out, [sum(out_knowns)]) + + # Form the complete list of residuals by forwarding some inputs. + if config.jax_dynamic_shapes: + # With dynamic shapes, we may need to forward implicit arguments. + in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns) + in_consts_full = [None] * len(f.in_type) + for idx, (aval, explicit) in enumerate(f.in_type): + if explicit and next(in_knowns_): + c = in_consts_full[idx] = next(in_consts_) + if aval.shape: + for d1, d2 in zip(aval.shape, c.shape): + if type(d1) is DBIdx: + in_consts_full[d1.val] = d2 + else: + in_consts_full = in_consts + non_fwd_res = iter(non_fwd_res_) + res = [next(non_fwd_res) if i is None else in_consts_full[i] for i in fwds] + sentinel = object() + assert next(non_fwd_res, sentinel) is sentinel + + # Create the input tracers for the staged-out (unknown-value) call. + res_tracers = map(self.instantiate_const, map(self.new_const, res)) + env_tracers = map(self.full_raise, env) + unknown_arg_tracers = [t for t in tracers if not t.is_known()] + # Adjust parameters (e.g. donated_invars) for the staged-out call's args. + num_new_args = len(res_tracers) + len(env_tracers) + staged_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) + staged_params = update_params(staged_params, map(op.not_, in_knowns), + num_new_args) + # The outputs of the staged-out call are Tracers with the new eqn as recipe. + if config.jax_dynamic_shapes: + # With dynamic shapes, we may need to substitute Tracers into avals. + out_tracers = [] + for aval, _ in out_type: + assert not isinstance(aval, ConcreteArray) + if type(aval) is DShapedArray: + shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val] + if type(d) is InDBIdx else d for d in aval.shape] + aval = aval.update(shape=tuple(shape)) + out_tracers.append(JaxprTracer(self, PartialVal.unknown(aval), None)) + else: + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_type] + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) + eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers), + out_tracers, primitive, staged_params, jaxpr.effects, + source) + for t in out_tracers: t.recipe = eqn + return merge_lists(out_knowns, out_tracers, out_consts) + + def process_map(self, primitive, f: lu.WrappedFun, tracers, params): + update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) + in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) + + # This method is like process_call above, except: + # 1. 
we delete an axis from mapped-over input avals' shapes, and + # analogously add an axis to mapped-over output avals' shapes; + # 2. we update the in_axes and out_axes/out_axes_thunk parameters to + # reflect the inputs and outputs pruned from the unknown/known sides. + + # Map (delete an axis from) unknown inputs' avals as dictated by in_axes. + unk_in_axes, const_in_axes = partition_list(in_knowns, params['in_axes']) + in_avals_mapped = [mapped_aval(params['axis_size'], ax, aval) + for ax, aval in zip(unk_in_axes, in_avals)] + + # Wrap f to perform partial evaluation and plumb out aux data. + f = trace_to_subjaxpr_nounits(f, self.main, False) + f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), + tuple(in_avals_mapped)) + # Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk) + const_params = update_params(params, in_knowns, 0) # handles donated_invars + out_axes_thunk = params['out_axes_thunk'] + @as_hashable_function(closure=out_axes_thunk) + def const_out_axes_thunk(): + out_knowns, _, jaxpr, _ = aux() + _, out_axes = partition_list(out_knowns, out_axes_thunk()) + return tuple(out_axes) + (0,) * len(jaxpr.constvars) # res mapped axis 0 + const_params = dict(const_params, in_axes=tuple(const_in_axes), + out_axes_thunk=const_out_axes_thunk) + + # Run the map, getting known out vals and aux data used for staged-out map. + out = primitive.bind(f, *in_consts, **const_params) + out_knowns, out_avals_mapped, jaxpr, env = aux() + # Split apart known outputs from the original call and residuals. + out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) + + # We can only check_jaxpr with the dynamic axis environment extended: + with core.extend_axis_env(params['axis_name'], params['axis_size'], None): + call_jaxpr = convert_constvars_jaxpr(jaxpr) + + # Compute staged and const out_axes, taking into account residuals. + out_axes = params['out_axes_thunk']() + staged_out_axes, _ = partition_list(out_knowns, out_axes) + staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,) + + # Create the input tracers for the staged-out (unknown-value) call. + const_tracers = map(self.new_instantiated_const, res) + env_tracers = map(self.full_raise, env) + unknown_arg_tracers = [t for t in tracers if not t.is_known()] + # Adjust params for staged-out call on unknown values. + num_new_args = len(const_tracers) + len(env_tracers) + staged_params = update_params(params, map(op.not_, in_knowns), num_new_args) + staged_params = dict(staged_params, in_axes=staged_in_axes, + out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr) + del staged_params['out_axes_thunk'] + # The outputs of the staged-out call are Tracers with the new eqn as recipe.
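# Illustrative example: with axis_size 8, a mapped-out aval f32[3] whose staged
# out_axis is 0 becomes the unmapped aval f32[8,3] below; unmapped_aval
# reinserts the mapped axis that was deleted from the inputs on the way in.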
+ out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a) + for ax, a in zip(staged_out_axes, out_avals_mapped)] + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_avals] + eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type] + out_tracers, primitive, staged_params, + jaxpr.effects, + source_info_util.current()) + for t in out_tracers: t.recipe = eqn + + return merge_lists(out_knowns, out_tracers, out_consts) + + def post_process_call(self, primitive, out_tracers, params): + unknown_out_tracers = [t for t in out_tracers if not t.is_known()] + jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) + out_pvals = [t.pval for t in out_tracers] + out_knowns, out_avals, out_consts = partition_pvals(out_pvals) + out = [*out_consts, *res] + main = self.main + + def todo(out): + trace = main.with_cur_sublevel() + out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) + const_tracers = map(trace.new_instantiated_const, res) + in_tracers = (*const_tracers, *map(trace.full_raise, env)) + out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) + for a in out_avals] + update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) + new_params = update_params(params, [], len(in_tracers)) + new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) + eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, + jaxpr.effects, source) + for t in out_tracers: t.recipe = eqn + return merge_lists(out_knowns, out_tracers, out_consts) + + return out, todo + + def post_process_map(self, primitive, out_tracers, params): + unknown_out_tracers = [t for t in out_tracers if not t.is_known()] + jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) + out_pvals = [t.pval for t in out_tracers] + out_knowns, out_avals_mapped, out_consts = partition_pvals(out_pvals) + out = [*out_consts, *res] + main = self.main + + with core.extend_axis_env(params['axis_name'], params['axis_size'], None): + call_jaxpr = convert_constvars_jaxpr(jaxpr) + + def todo(out): + trace = main.with_cur_sublevel() + out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) + const_tracers = map(trace.new_instantiated_const, res) + env_tracers = map(trace.full_raise, env) + + staged_out_axes = tuple(out_axes_unknown) # set by out_axes_transform + staged_in_axes = (0,) * len(res) + (None,) * len(env) + + update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) + staged_params = update_params(params, [], len(res) + len(env)) + staged_params = dict(staged_params, in_axes=staged_in_axes, + out_axes=tuple(staged_out_axes), + call_jaxpr=call_jaxpr) + + out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], d, a) + for d, a in zip(staged_out_axes, out_avals_mapped)] + out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) + for a in out_avals] + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) + eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, + primitive, staged_params, jaxpr.effects, source) + for t in out_tracers: t.recipe = eqn + return merge_lists(out_knowns, out_tracers, out_consts) + + def out_axes_transform(out_axes): + nonlocal out_axes_unknown + out_axes_unknown, out_axes_known = partition_list(out_knowns, out_axes) + return 
tuple(out_axes_known) + (0,) * len(jaxpr.constvars) + out_axes_unknown: Optional[list] = None + + return out, (todo, out_axes_transform) + + def _current_truncated_name_stack(self): + return source_info_util.current_name_stack()[len(self.name_stack):] + + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + # We assume partial evaluation is only performed to build linear functions, + # and hence we don't need to keep the custom JVP rule around anymore. + del jvp, symbolic_zeros + assert not all(t.is_known() for t in tracers) + return fun.call_wrapped(*tracers) + + def post_process_custom_jvp_call(self, out_tracers, _): + # This path should only be reachable if we expose a partial eval API + # unrelated to autodiff, since we raise an error when differentiation with + # respect to values over which a custom_jvp function closes is detected. + raise NotImplementedError # TODO(mattjj) + + def process_custom_transpose(self, prim, call, tracers, **params): + res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves]) + assert all(t.is_known() for t in res_ts) + lin_all_known = all(t.is_known() for t in lin_ts) + if lin_all_known: + res_cvals = [t.pval[1] for t in res_ts] + lin_cvals = [t.pval[1] for t in lin_ts] + return prim.bind(call, *res_cvals, *lin_cvals, **params) + else: + out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) + for aval in params['out_types']] + in_tracers = map(self.instantiate_const, tracers) + new_params = dict(params, call=call) + eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params, + core.no_effects, source_info_util.current()) + for t in out_tracers: t.recipe = eqn + return out_tracers + + def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees): + # TODO(mattjj): after old remat is deleted, make this method trivial. + # Because we instantiate all tracers, in_knowns is all False. 
+ tracers = map(self.instantiate_const_abstracted, tracers) + in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) + f = trace_to_subjaxpr_nounits(f, self.main, True) + f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) + out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees) + out_knowns, out_avals, jaxpr, env = aux() + out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) + res_tracers = map(self.new_instantiated_const, res) + env_tracers = map(self.full_raise, env) + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_avals] + closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) + + @_memoize + def fwd_jaxpr_thunk(): + fwd_ = trace_to_subjaxpr_nounits(fwd, self.main, True) + fwd_, aux = partial_eval_wrapper_nounits( + fwd_, tuple(in_knowns), tuple(in_avals)) + with core.new_sublevel(): + out_flat = fwd_.call_wrapped() + out_knowns, out_avals, jaxpr, env = aux() + _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) + converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) + return converted_jaxpr, (*res, *env) + + name_stack = self._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) + eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers), + out_tracers, prim.initial_style, + dict(fun_jaxpr=closed_jaxpr, + fwd_jaxpr_thunk=fwd_jaxpr_thunk, + num_consts=len(res) + len(env), + bwd=bwd, out_trees=out_trees), + jaxpr.effects, source) + for t in out_tracers: t.recipe = eqn + return merge_lists(out_knowns, out_tracers, out_consts) + + def post_process_custom_vjp_call(self, out_tracers, _): + # This path should only be reachable if we expose a partial eval API + # unrelated to autodiff, since we raise an error when differentiation with + # respect to values over which a custom_vjp function closes is detected. + raise NotImplementedError # TODO(mattjj) + +def partition_pvals( + pvals: List[PartialVal] + ) -> Tuple[List[bool], List[AbstractValue], List[Any]]: + knowns = [pval.is_known() for pval in pvals ] + avals = [pval.get_aval() for pval in pvals if not pval.is_known()] + consts = [pval.get_known() for pval in pvals if pval.is_known()] + return knowns, avals, consts + +@lu.transformation_with_aux +def partial_eval_wrapper_nounits( + in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], + *in_consts: Any): + in_avals_, in_consts_ = iter(in_avals), iter(in_consts) + in_pvals = [PartialVal.known(next(in_consts_)) if known else + PartialVal.unknown(next(in_avals_)) for known in in_knowns] + sentinel = object() + assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel + jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {} + out_knowns, out_avals, out_consts = partition_pvals(out_pvals) + yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) + +@lu.transformation_with_aux +def trace_to_subjaxpr_nounits_dyn( + main: core.MainTrace, in_knowns: Sequence[bool], in_type: InputType, + instantiate: Union[bool, Sequence[bool]], + *in_consts: Any): + trace = main.with_cur_sublevel() + in_avals, which_explicit = unzip2(in_type) + + # To form input tracers from in_type, we need to first build ConstVar tracers + # for all axis sizes, so that we can then use those tracers in the shapes of + # avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly + # type agreement checking via get_referent. 
+ in_consts_full: List[Optional[JaxprTracer]] = [None] * len(in_type) + in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns) + for idx, (aval, explicit) in enumerate(in_type): + if explicit and next(in_knowns_iter): + constval = next(in_consts_iter) + if isinstance(aval, DShapedArray): + for i, d in enumerate(aval.shape): + if isinstance(d, DBIdx): + if in_consts_full[d.val] is None: + in_consts_full[d.val] = \ + JaxprTracer(trace, PartialVal.unknown(in_avals[d.val]), + ConstVar(constval.shape[i])) + assert core.same_referent(constval.shape[i], in_consts_full[d.val]) + shape = [in_consts_full[d.val] if type(d) is DBIdx else d # type: ignore + for d in aval.shape] + aval = aval.update(shape=tuple(shape)) + in_consts_full[idx] = JaxprTracer(trace, PartialVal.unknown(aval), + ConstVar(constval)) + # Check that we covered all axis sizes with ConstVar tracers. + for idx, (aval, explicit) in enumerate(in_type): + if not explicit: assert in_consts_full[idx] is not None + if isinstance(aval, DShapedArray): + assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None # type: ignore + for d in aval.shape) + + # Next, build tracers for all unknown inputs, using the in_consts_full list + # for axis size tracers when necessary. + in_tracers = [] + in_knowns_iter = iter(in_knowns) + for aval, explicit in in_type: + if explicit and not next(in_knowns_iter): + if isinstance(aval, DShapedArray): + shape = [in_consts_full[d.val] if type(d) is DBIdx else d # type: ignore + for d in aval.shape] + aval = aval.update(shape=tuple(shape)) + tracer = JaxprTracer(trace, PartialVal.unknown(aval), LambdaBinding()) + in_tracers.append(tracer) + + # Merge in_consts and in_tracers and call wrapped fn with explicit arguments. + in_args = merge_lists(in_knowns, in_tracers, in_consts) + ans = yield in_args, {} + + # Instantiate outputs and build jaxpr. + if isinstance(instantiate, bool): + instantiate = [instantiate] * len(ans) + out_tracers = map(trace.full_raise, map(core.full_lower, ans)) + out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t + for inst, t in zip(instantiate, out_tracers)] + + # Collect known outputs. + out_knowns: List[bool] = [t.is_known() for t in out_tracers] + out_consts: List[Any] = [t.pval.get_known() for t in out_tracers + if t.is_known()] + + # Build the jaxpr. + out_tracers = [t for t in out_tracers if not t.is_known()] + jaxpr, res, env = tracers_to_jaxpr(in_tracers, out_tracers) + out_avals = [v.aval for v in jaxpr.outvars] + idx_map = {v: InDBIdx(i) + for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))} + out_type = [(a.update(shape=tuple(idx_map.get(d, d) for d in a.shape)) # type: ignore + if type(a) is DShapedArray else a, True) for a in out_avals] + + # Which residuals are just forwarded inputs? Check obj id, then prune. 
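# E.g. if a residual is the very same Python object as one of the known consts
# received above, record its input position in `fwds` rather than returning it
# again as an extra output.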
+ id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore + if c is not None} + fwds: List[Optional[int]] = [id_map.get(id(c)) for c in res] + res = tuple([c for c, fwd in zip(res, fwds) if fwd is None]) + + del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \ + in_tracers, in_args, ans, out_tracers, out_avals + yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env) + + +custom_partial_eval_rules: Dict[Primitive, Callable] = {} +call_partial_eval_rules: Dict[Primitive, Callable] = {} +call_param_updaters: Dict[Primitive, Callable] = {} + +def _closed_call_param_updater(params, _, __): + jaxpr = params.get('call_jaxpr') + if jaxpr is None: return params + assert type(jaxpr) is core.Jaxpr + return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ())) +call_param_updaters[core.closed_call_p] = _closed_call_param_updater + +def abstract_eval_fun(fun, *avals, debug_info=None, **params): + _, avals_out, _ = trace_to_jaxpr_dynamic( + lu.wrap_init(fun, params), avals, debug_info) + assert all(isinstance(aval, AbstractValue) for aval in avals_out) + return avals_out + + +JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar', + 'ConstVar', Literal] + +class JaxprTracer(Tracer): + __slots__ = ['pval', 'recipe'] + + def __init__(self, trace: JaxprTrace, pval: PartialVal, + recipe: Optional[JaxprTracerRecipe]): + assert isinstance(pval, PartialVal) + pv, const = pval + if isinstance(const, Tracer) and const._trace.level >= trace.level: + raise core.escaped_tracer_error( + const, f"Tracer from a higher level: {const} in trace {trace}") + if isinstance(pv, DShapedArray): + assert all(not isinstance(d, Tracer) or isinstance(d, JaxprTracer) and + d._trace.level == trace.level for d in pv.shape) + self._trace = trace + self.pval = pval + self.recipe = recipe + + def __repr__(self): + return f'Traced<{self.aval}:{self._trace}>' + + @property + def aval(self) -> AbstractValue: + return self.pval.get_aval() + + @property + def parents(self) -> Sequence[JaxprTracer]: + if isinstance(self.recipe, JaxprEqnRecipe): + # TODO broadcast_in_dim can create a new tracer... + return self.recipe.in_tracers + elif isinstance(self.aval, DShapedArray): + return [d for d in self.aval.shape if isinstance(d, JaxprTracer)] + else: + return [] + + def full_lower(self): + known = self.pval.get_known() + if known is not None: + return core.full_lower(known) + else: + return self + + def is_known(self): + return self.pval.is_known() + + def get_referent(self): + if self.pval.is_known(): + return get_referent(self.pval.get_known()) + elif isinstance(self.recipe, (FreeVar, ConstVar, Literal)): + return get_referent(self.recipe.val) # pytype: disable=attribute-error + else: + return self + + +@profiler.annotate_function +def trace_to_jaxpr( + fun: lu.WrappedFun, pvals: Sequence[PartialVal], + instantiate: Union[bool, Sequence[bool]] = False, + ) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]: + """ + Partially evaluate a function, building a jaxpr for un-evaluated computation. + + Args: + fun: lu.WrappedFun representing the function to be partially evaluated. The + function must be flattened, in the sense of accepting jaxpr type arguments + and returning a flat list of jaxpr type outputs. + pvals: sequence of PartialVals of length equal to the number of inputs to + `fun` indicating which inputs are known or unknown. 
+ instantiate: optional bool or sequence of bools of length equal to the + number of outputs of `fun` indicating which outputs should be forced to be + treated as unknown and hence instantiated in the jaxpr. If a single bool, + the value is applied to all outputs. Default False. + + Returns: + A triple where the first element is a jaxpr representing the computation + which depends on unknown inputs; the second element is a list of PartialVals + of length equal to the length of the output of `fun` representing which + outputs are known and unknown (along with their values and abstract values, + respectively); the third element is a list of known residual values. The + returned jaxpr takes as inputs the known residual values followed by values + of the originally unknown inputs. + """ + current_name_stack = source_info_util.current_name_stack() + with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: + fun = trace_to_subjaxpr(fun, main, instantiate) + jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) + assert not env + del main, fun, env + + return jaxpr, out_pvals, consts + +@profiler.annotate_function +def trace_to_jaxpr_nounits( + fun: lu.WrappedFun, pvals: Sequence[PartialVal], + instantiate: Union[bool, Sequence[bool]] = False, + ) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]: + current_name_stack = source_info_util.current_name_stack() + with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: + fun = trace_to_subjaxpr_nounits(fun, main, instantiate) + jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) + assert not env + del main, fun, env + return jaxpr, out_pvals, consts + + +@lu.transformation +def trace_to_subjaxpr_nounits( + main: core.MainTrace, + instantiate: Union[bool, Sequence[bool]], + in_pvals: Sequence[PartialVal]): + assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals + out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( + main, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] + del out_tracers + yield jaxpr, (out_pvals, out_consts, env) + +def _trace_to_subjaxpr_nounits(main, instantiate, in_pvals): + trace = main.with_cur_sublevel() + in_knowns = [pval.is_known() for pval in in_pvals] + in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] + in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] + in_args = merge_lists(in_knowns, in_tracers, in_consts) + ans = yield in_args, {} + assert isinstance(ans, (list, tuple)), ( + f"Got unexpected return type when tracing function to jaxpr: {ans}") + assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( + f"Got unexpected return type when tracing function to jaxpr: {ans}") + if isinstance(instantiate, bool): + instantiate = [instantiate] * len(ans) + out_tracers = map(trace.full_raise, map(core.full_lower, ans)) + out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t + for inst, t in zip(instantiate, out_tracers)] + out_tracers_ = [t for t in out_tracers if not t.is_known()] + jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_) + return out_tracers, jaxpr, out_consts, env + +# The below variant implements an optimization where residuals which are also +# inputs are indicated in auxiliary data rather than passed as outputs. +# TODO(mattjj): update all callers to use this version, delete other version. 
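Before the forwarding variant below, a small sketch of driving `trace_to_jaxpr_nounits` (defined above) directly; the function and values here are hypothetical, and real callers reach this entry point via transformations such as linearization:

    import jax  # ensures numeric ops are registered on tracers
    import numpy as np
    from jax._src import core
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe

    def f(x, y):
      return x * 2.0, x * y

    in_pvals = [
        pe.PartialVal.known(np.float32(2.0)),       # x: known now
        pe.PartialVal.unknown(                      # y: unknown, staged out
            core.ShapedArray((), np.dtype('float32'))),
    ]
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(lu.wrap_init(f), in_pvals)
    # out_pvals[0] is known (x * 2.0 ran eagerly); out_pvals[1] is unknown, and
    # `jaxpr` computes it from the residual constants plus the unknown input y.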
+@lu.transformation +def trace_to_subjaxpr_nounits_fwd( + main: core.MainTrace, + instantiate: Union[bool, Sequence[bool]], + in_pvals: Sequence[PartialVal]): + assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals + out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( + main, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] + + # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. + in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] + id_map = {id(c): i for i, c in enumerate(in_consts)} + fwds: List[Optional[int]] = [id_map.get(id(c)) for c in out_consts] + pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] + + del out_tracers + yield jaxpr, (fwds, out_pvals, pruned_consts, env) + + +FreeVar = namedtuple('FreeVar', ['val']) +ConstVar = namedtuple('ConstVar', ['val']) +LambdaBinding = namedtuple('LambdaBinding', []) +class JaxprEqnRecipe(NamedTuple): + eqn_id: Any + in_tracers: Sequence[JaxprTracer] + out_tracer_refs: Sequence[ref[JaxprTracer]] + out_avals: Sequence[core.AbstractValue] + primitive: Primitive + params: Dict[str, Any] + effects: core.Effects + source_info: source_info_util.SourceInfo + +def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], + out_tracers: Sequence[JaxprTracer], + primitive: Primitive, + params: Dict[str, Any], + effects: core.Effects, + source_info: source_info_util.SourceInfo + ) -> JaxprEqnRecipe: + # TODO(necula): move these checks to core.check_jaxpr, and call in more places + if primitive.call_primitive or primitive.map_primitive: + assert "call_jaxpr" in params + assert ("donated_invars" not in params or + len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) + if primitive.map_primitive: + assert ("in_axes" in params and + len(params["in_axes"]) == len(params["call_jaxpr"].invars)) + assert ("donated_invars" in params and + len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) + out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] + return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers), + out_avals, primitive, params, effects, source_info) + + +def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom], + recipe: JaxprEqnRecipe) -> core.JaxprEqn: + (_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src) = recipe + invars = [getvar(t) for t in in_tracers] + out_tracers = [t_ref() for t_ref in out_tracer_refs] + outvars = [DropVar(a) if t is None else getvar(t) # type: ignore + for a, t in zip(out_avals, out_tracers)] + return new_jaxpr_eqn(invars, outvars, prim, params, eff, src) + +def tracers_to_jaxpr( + in_tracers: Sequence[JaxprTracer], + out_tracers: Sequence[JaxprTracer] + ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]: + """Constructs Jaxpr given tracers for inputs and outputs. + + Params: + in_tracers: the tracers that were created for the function inputs + out_tracers: the tracers that were output by the function. + + Returns: a triple of a `Jaxpr`, a list of constant values corresponding to + the `constvars` in the returned Jaxpr, and a list of environment values. + The vars for the environment values have been prepended to the Jaxpr's + `invars`.
+ """ + gensym = core.gensym() + + t_to_var: Dict[TracerId, Var] = {} + consts: Dict[Var, Any] = {} + env: Dict[Var, JaxprTracer] = {} + constid_to_var: Dict[ConstId, Var] = {} # for deduplication + + def get_atom(t: JaxprTracer) -> Atom: + return t.recipe if type(t.recipe) is Literal else t_to_var[id(t)] + + def newvar(t: Optional[JaxprTracer]) -> Var: + assert t is not None + var = gensym(type_substitute(t.aval)) + var_ = t_to_var.setdefault(id(t), var) + assert var is var_ + return var + + def type_substitute(aval: AbstractValue) -> AbstractValue: + if isinstance(aval, DShapedArray): + # Replace any Tracers in aval.shape with Vars or Literal values + shape = [get_atom(d) if type(d) is JaxprTracer else d for d in aval.shape] + shape = [d.val if type(d) is Literal else d for d in shape] + aval = aval.update(shape=tuple(shape)) + return aval + + processed_eqn_ids = set() + eqns: List[core.JaxprEqn] = [] + for t in toposort([*in_tracers, *out_tracers]): + r = t.recipe + if isinstance(r, JaxprEqnRecipe): + # TODO broadcast_in_dim can create a new tracer, not present in parents + if r.eqn_id not in processed_eqn_ids: + in_atoms = map(get_atom, r.in_tracers) + outvars = [DropVar(type_substitute(a)) if rf() is None else newvar(rf()) + for a, rf in zip(r.out_avals, r.out_tracer_refs)] + eqns.append(new_jaxpr_eqn(in_atoms, outvars, r.primitive, r.params, + r.effects, r.source_info)) + processed_eqn_ids.add(r.eqn_id) + elif isinstance(r, LambdaBinding): + if not any(t is in_tracer for in_tracer in in_tracers): + raise core.escaped_tracer_error(t, f"Tracer not in input tracers: {t}") + newvar(t) + elif isinstance(r, ConstVar): + var = constid_to_var.get(id(r.val)) + if var is None: + var = constid_to_var[id(r.val)] = newvar(t) + consts[var] = r.val + t_to_var[id(t)] = var + elif isinstance(r, FreeVar): + env[newvar(t)] = r.val # type: ignore + elif isinstance(r, Literal): + pass + elif r is None: + assert False + else: + raise TypeError(r) + + env_vars, env_vals = unzip2(env.items()) + invars = [*env_vars, *map(get_atom, in_tracers)] + const_vars, const_vals = unzip2(consts.items()) + outvars = map(get_atom, out_tracers) # type: ignore[arg-type] + jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns) + jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type] + outvars, eqns, jaxpr_effects) + config.jax_enable_checks and core.check_jaxpr(jaxpr) + # del getvar # needed to avoid cyclic-reference closure, apparently! + return jaxpr, const_vals, env_vals + +@weakref_lru_cache +def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: + """Moves the constvars to the start of invars.""" + config.jax_enable_checks and core.check_jaxpr(jaxpr) + dbg = jaxpr.debug_info and jaxpr.debug_info._replace( + arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) + lifted_jaxpr = Jaxpr(constvars=(), + invars=jaxpr.constvars + jaxpr.invars, + outvars=jaxpr.outvars, eqns=jaxpr.eqns, + effects=jaxpr.effects, debug_info=dbg) + config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr) + return lifted_jaxpr + +@weakref_lru_cache +def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr: + """Move n invars to constvars. 
Like an inverse of convert_constvars_jaxpr.""" + if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects): + raise NotImplementedError + config.jax_enable_checks and core.check_jaxpr(jaxpr) + constvars, invars = split_list(jaxpr.invars, [n]) + dbg = jaxpr.debug_info and jaxpr.debug_info._replace( + arg_names=jaxpr.debug_info.arg_names[n:]) + lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars, + debug_info=dbg) + config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr) + return lifted_jaxpr + +def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr: + if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects): + raise NotImplementedError + config.jax_enable_checks and core.check_jaxpr(jaxpr) + env_vars, invars = split_list(jaxpr.invars, [num_env_vars]) + converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars, + invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, + effects=jaxpr.effects) + config.jax_enable_checks and core.check_jaxpr(converted_jaxpr) + return converted_jaxpr + + +def partial_eval_jaxpr_nounits( + jaxpr: ClosedJaxpr, unknowns: Sequence[bool], + instantiate: Union[bool, Sequence[bool]], + ) -> Tuple[ClosedJaxpr, ClosedJaxpr, List[bool], List[AbstractValue]]: + """Unzip a jaxpr in two by data dependence into 'known' and 'unknown' parts. + + That is, given a jaxpr and a sequence of booleans indicating which jaxpr + inputs (i.e. invars) are considered unknown, produce two jaxprs, a list of + booleans representing which of the original jaxpr's outputs are unknown (i.e. + have a data dependence on an unknown input), and a list of abstract values + representing residuals (part of the first jaxpr's output and the second + jaxpr's input). The two jaxprs result from partitioning the original jaxpr's + first-order primitive applications based on whether all the inputs to the + application are known (in which case the application is represented in the + 'known' jaxpr and its result is considered known) or whether any inputs to the + application are unknown (in which case the application is represented in the + 'unknown' jaxpr and its result is considered unknown). Higher-order primitives + are recursively unzipped in two. + + The `instantiate` argument can be used to ensure some outputs are lifted into + the 'unknown' jaxpr. + + For example, given an input jaxpr: + + { lambda ; a:f32[] b:f32[]. let + c:f32[] = cos a + d:f32[] = sin a + e:f32[] = neg d + f:f32[] = mul e b + in (c, f) } + + then applying this function with `unknowns=[False, True]` and + `instantiate=False` produces as an output triple: + + # jaxpr_known + { lambda ; a:f32[]. let + b:f32[] = cos a + c:f32[] = sin a + d:f32[] = neg c + in (b, d) } + + # jaxpr_unknown + { lambda ; a:f32[] b:f32[]. let c:f32[] = mul b a in (c,) } + + # out_unknowns + [False, True] + + Notice in particular that the first output (jaxpr_known) contains all the + primitive applications which do not have a data dependence on an unknown + input. Also notice the input and output types: the input type of the first + jaxpr produced represents the type of the known inputs of the original jaxpr, + and the output type of the second jaxpr produced represents the type of the + unknown outputs of the original jaxpr. + + In the above example, the output of jaxpr_known named `d` is a _residual_ + output, and corresponds to the input named `a` in jaxpr_unknown.
In general, + jaxpr_known will produce extra outputs (at the end of its output list) + corresponding to intermediate values of the original jaxpr which must be + passed to jaxpr_unknown (as leading inputs). + """ + instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate) + +@weakref_lru_cache +def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate): + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) + + cell = [] + def fun(*known_vals_in): + known_vals_in = iter(known_vals_in) + unknown_avals = (a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk) + in_pvals = [PartialVal.unknown(next(unknown_avals)) if uk + else PartialVal.known(next(known_vals_in)) for uk in in_unknowns] + assert next(known_vals_in, None) is next(unknown_avals, None) is None + jaxpr_unknown_, out_pvals, residuals = trace_to_jaxpr_nounits( + f, in_pvals, instantiate=instantiate) + jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) + out_unknowns = [not pval.is_known() for pval in out_pvals] + res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals] + cell.append((out_unknowns, jaxpr_unknown, res_avals)) + known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] + return [*known_vals_out, *residuals] + + known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk] + jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals) + (out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking + + # check jaxpr_known and jaxpr_unknown in isolation + # TODO(mattjj): enable weak type checking here + if config.jax_enable_checks: + core.check_jaxpr(jaxpr_known) + core.check_jaxpr(jaxpr_unknown) + # check jaxpr_known has input type corresponding to known inputs of jaxpr + assert ([v.aval for v in jaxpr_known.invars] == + [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]) + # check jaxpr_known has out type corresponding to known outs of jaxpr plus res + assert ([v.aval.strip_weak_type() for v in jaxpr_known.outvars] == + [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) + if not uk] + [a.strip_weak_type() for a in res_avals]) + # check jaxpr_unknown has input type corresponding to res plus unknown inputs + assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.invars] == + [a.strip_weak_type() for a in res_avals] + + [a.strip_weak_type() for a, uk in zip(jaxpr.in_avals, in_unknowns) + if uk]) + # check jaxpr_unknown has output type corresponding to unknown outputs + assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.outvars] == + [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) + if uk]) + + closed_jaxpr_known = ClosedJaxpr(jaxpr_known, consts_known) + closed_jaxpr_unknown = ClosedJaxpr(jaxpr_unknown, ()) + return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals + + +def partial_eval_jaxpr_custom( + jaxpr: Jaxpr, + in_unknowns: Sequence[bool], + in_inst: Union[bool, Sequence[bool]], + ensure_out_unknowns: Union[bool, Sequence[bool]], + ensure_out_inst: Union[bool, Sequence[bool]], + saveable: Callable[..., bool], + ) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]: + if type(in_inst) is bool: + in_inst = (in_inst,) * len(jaxpr.invars) + if type(ensure_out_unknowns) is bool: + ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars) + if type(ensure_out_inst) is bool: + ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars) + return
_partial_eval_jaxpr_custom_cached( + jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns), + tuple(ensure_out_inst), saveable) + +@weakref_lru_cache +def _partial_eval_jaxpr_custom_cached( + jaxpr: Jaxpr, + in_unknowns: Tuple[bool, ...], + in_inst: Tuple[bool, ...], + ensure_out_unknowns: Tuple[bool, ...], + ensure_out_inst: Tuple[bool, ...], + saveable: Callable[..., bool], + ) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]: + env: Dict[Var, Tuple[bool, bool]] = {} + residuals: OrderedSet[Var] = OrderedSet() + + def read(x: Atom) -> Tuple[bool, bool]: + if type(x) is Var: + return env[x] + return (False, True) + + def write(unk: bool, inst: bool, v: Var) -> None: + assert (unk, inst) != (True, False) + env[v] = (unk, inst) + + def ensure_instantiated(inst: bool, x: Atom) -> Atom: + if type(x) is Var and not inst: + residuals.add(x) + return x + + known_eqns, staged_eqns = [], [] + map(write, in_unknowns, in_inst, jaxpr.invars) + map(partial(write, False, True), jaxpr.constvars) + for eqn in jaxpr.eqns: + unks_in, inst_in = unzip2(map(read, eqn.invars)) + rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive) + if rule: + eqn1, eqn2, unks_out, inst_out, res = rule(saveable, unks_in, inst_in, eqn) + eqn1 and known_eqns.append(eqn1); eqn2 and staged_eqns.append(eqn2) # type: ignore + residuals.update(res) + map(write, unks_out, inst_out, eqn.outvars) + elif any(unks_in): + inputs = map(ensure_instantiated, inst_in, eqn.invars) + staged_eqns.append(eqn.replace(invars=inputs)) + map(partial(write, True, True), eqn.outvars) + else: + known_eqns.append(eqn) + # If it's an effectful primitive, we always run it and avoid staging it. + if eqn.effects or saveable( + eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params): + map(partial(write, False, False), eqn.outvars) + else: + inputs = map(ensure_instantiated, inst_in, eqn.invars) + staged_eqns.append(eqn.replace(invars=inputs)) + map(partial(write, False, True), eqn.outvars) + unzipped = unzip2(map(read, jaxpr.outvars)) + out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1]) + assert all(type(v) is Var for v in residuals), residuals + + for x, inst, ensure_inst in zip(jaxpr.outvars, out_inst, ensure_out_inst): + if ensure_inst: ensure_instantiated(inst, x) + out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns) + out_inst = map(op.or_, out_inst, ensure_out_inst) + + ins_known, _ = partition_list(in_unknowns, jaxpr.invars) + outs_known, _ = partition_list(out_unknowns, jaxpr.outvars) + known_outvars = [*outs_known, *residuals] + known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known, known_outvars, + known_eqns) + jaxpr_known = Jaxpr(jaxpr.constvars, ins_known, known_outvars, + known_eqns, known_effects) + config.jax_enable_checks and core.check_jaxpr(jaxpr_known) + + _, ins_staged = partition_list(in_inst, jaxpr.invars) + _, outs_staged = partition_list(out_inst, jaxpr.outvars) + staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns)) + staged_invars = [*residuals, *ins_staged] + staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars, + outs_staged, staged_eqns) + jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars, + outs_staged, staged_eqns, staged_effects) + config.jax_enable_checks and core.check_jaxpr(jaxpr_staged) + + return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals) + +# A primitive rule for policy-driven partial evaluation returns a 5-tuple +# with the components representing, respectively: +# * the JaxprEqn for the 'known' side (or
None if there is no known component), +# * the JaxprEqn for the 'unknown' side (or None), +# * a list of booleans indicating which of the original outputs are unknown, +# * a list of booleans indicating which of the original outputs are +# instantiated (i.e. available) in the 'unknown' side, +# * a list of Var instances representing residuals to be added (i.e. to be +# plumbed as outputs of the 'known' side jaxpr and added as input binders to +# the 'unknown' jaxpr). +PartialEvalCustomResult = Tuple[Optional[JaxprEqn], Optional[JaxprEqn], + Sequence[bool], Sequence[bool], List[Var]] +PartialEvalCustomRule = Callable[ + [Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn], + PartialEvalCustomResult] +partial_eval_jaxpr_custom_rules: Dict[Primitive, PartialEvalCustomRule] = {} + +def partial_eval_jaxpr_custom_rule_not_implemented( + name: str, saveable: Callable[..., bool], unks_in: Sequence[bool], + inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult: + msg = (f'custom-policy remat rule not implemented for {name}, ' + 'open a feature request at https://github.com/google/jax/issues!') + raise NotImplementedError(msg) + + +ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool], + Sequence[bool], int, dict, dict], + Tuple[dict, dict]] +ResAvalUpdater = Callable[[Dict[str, Any], AbstractValue], AbstractValue] +def _default_res_aval_updater( + params: Dict[str, Any], aval: AbstractValue) -> AbstractValue: + return aval + +@contextmanager +def trivial_ctx(_): yield + +def call_partial_eval_custom_rule( + jaxpr_param_name: str, params_updater: ParamsUpdater, + saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool], + eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, + ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx, + ) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]: + jaxpr = eqn.params[jaxpr_param_name] + with ctx(eqn.params): + jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ + partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) + ins_known, _ = partition_list(unks_in, eqn.invars) + out_binders_known, _ = partition_list(unks_out, eqn.outvars) + _, ins_staged = partition_list(inst_in, eqn.invars) + _, out_binders_staged = partition_list(inst_out, eqn.outvars) + newvar = core.gensym([jaxpr_known, jaxpr_staged]) + params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} + params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} + params_known, params_staged = params_updater( + unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, + params_staged) + residuals = [newvar(res_aval(params_known, var.aval)) + for var in jaxpr_staged.invars[:num_res]] + eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], + eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info) + eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, + eqn.primitive, params_staged, + jaxpr_staged.effects, eqn.source_info) + assert len(eqn_staged.invars) == len(jaxpr_staged.invars) + new_inst = [x for x, inst in zip(eqn.invars, inst_in) + if type(x) is Var and not inst] + return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals + +def closed_call_partial_eval_custom_rule( + jaxpr_param_name: str, params_updater: ParamsUpdater, + saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool], + eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, + ) -> 
Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]: + # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule. + closed_jaxpr = eqn.params[jaxpr_param_name] + jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \ + partial_eval_jaxpr_custom(closed_jaxpr.jaxpr, unks_in, inst_in, + False, False, saveable) + # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching + jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts) + jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts) + ins_known, _ = partition_list(unks_in, eqn.invars) + out_binders_known, _ = partition_list(unks_out, eqn.outvars) + _, ins_staged = partition_list(inst_in, eqn.invars) + _, out_binders_staged = partition_list(inst_out, eqn.outvars) + newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr]) + params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} + params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} + params_known, params_staged = params_updater( + unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, + params_staged) + residuals = [newvar(res_aval(params_known, a)) + for a in jaxpr_staged.in_avals[:num_res]] + eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], + eqn.primitive, params_known, jaxpr_known.effects, + eqn.source_info) + eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, + eqn.primitive, params_staged, jaxpr_staged.effects, + eqn.source_info) + assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals) + new_inst = [x for x, inst in zip(eqn.invars, inst_in) + if type(x) is Var and not inst] + return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals + +partial_eval_jaxpr_custom_rules[core.call_p] = \ + partial(call_partial_eval_custom_rule, 'call_jaxpr', + lambda _, __, ___, ____, _____, x, y: (x, y)) +partial_eval_jaxpr_custom_rules[core.closed_call_p] = \ + partial(closed_call_partial_eval_custom_rule, 'call_jaxpr', + lambda _, __, ___, ____, _____, x, y: (x, y)) + + +def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]: + # Compute which inputs are just forwarded to outputs. 
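+  # Each entry of the returned list gives, for the corresponding output, the
+  # index of the invar it forwards (following chains through forwarding_rules),
+  # or None for Literals and outputs that are not pure forwards of any input.
+  # As an illustrative (hypothetical) case: a jaxpr whose sole output is its
+  # second input unchanged would yield [1] here.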
+ fwds: Dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars)) + for eqn in jaxpr.eqns: + if eqn.primitive in forwarding_rules: + eqn = eqn.replace(invars=[a if type(a) is Literal else fwds.get(a, a) # type: ignore + for a in eqn.invars]) + fwd_vars, _ = forwarding_rules[eqn.primitive](eqn) + for v_orig, v_new in zip(eqn.outvars, fwd_vars): + if v_new is not None: + fwds[v_orig] = v_new + idxs: Dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)} + return [None if type(v) is Literal else idxs.get(fwds.get(v)) # type: ignore + for v in jaxpr.outvars] + + +def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool], + instantiate: Union[bool, Sequence[bool]] = False, + ) -> Tuple[Jaxpr, List[bool]]: + if type(instantiate) is bool: + instantiate = (instantiate,) * len(jaxpr.invars) + return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate)) + + +def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], + instantiate: Union[bool, Sequence[bool]] = False, + ) -> Tuple[Jaxpr, List[bool], List[bool]]: + jaxpr_ = convert_constvars_jaxpr(jaxpr) + new_jaxpr_, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs) + used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)]) + new_jaxpr = convert_invars_to_constvars(new_jaxpr_, sum(used_consts)) + return new_jaxpr, used_consts, used_inputs + + +@weakref_lru_cache +def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...], + instantiate: Tuple[bool, ...] + ) -> Tuple[Jaxpr, List[bool]]: + env: Dict[Var, bool] = {} + + def read(v: Var) -> bool: + return env.get(v, False) + + def write(x: Atom, b: bool) -> None: + if type(x) is Var: + env[x] = read(x) or b + + def has_effects(e: JaxprEqn) -> bool: + return bool(e.effects) or core.primitive_uses_outfeed(e.primitive, e.params) + + new_eqns = [] + map(write, jaxpr.outvars, used_outputs) + for eqn in jaxpr.eqns[::-1]: + used_outs = map(read, eqn.outvars) + if not any(used_outs) and not has_effects(eqn): + used_ins = [False] * len(eqn.invars) + else: + rule = dce_rules.get(eqn.primitive, _default_dce_rule) + used_ins, new_eqn = rule(used_outs, eqn) + if new_eqn is not None: + new_eqns.append(new_eqn) + map(write, eqn.invars, used_ins) + used_inputs = map(read, jaxpr.invars) + used_inputs = map(op.or_, instantiate, used_inputs) + + invars = [v for v, b in zip(jaxpr.invars, used_inputs) if b] + outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b] + eqns = new_eqns[::-1] + jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns) + + dbg = jaxpr.debug_info and core.JaxprDebugInfo( + jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, + tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b), + tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b)) + new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) + config.jax_enable_checks and core.check_jaxpr(new_jaxpr) + + return new_jaxpr, used_inputs + +DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], Optional[JaxprEqn]]] + +def _default_dce_rule( + used_outs: List[bool], eqn: JaxprEqn + ) -> Tuple[List[bool], JaxprEqn]: + return [True] * len(eqn.invars), eqn + +dce_rules: Dict[Primitive, DCERule] = {} + + +def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn + ) -> Tuple[List[bool], Optional[JaxprEqn]]: + new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) + new_params = dict(eqn.params, call_jaxpr=new_jaxpr) + update_params = call_param_updaters.get(eqn.primitive) + if update_params: + new_params 
= update_params(new_params, used_inputs, 0) + if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects: + return used_inputs, None + else: + new_eqn = new_jaxpr_eqn( + [v for v, used in zip(eqn.invars, used_inputs) if used], + [v for v, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) + return used_inputs, new_eqn +dce_rules[core.call_p] = dce_jaxpr_call_rule + + +def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn + ) -> Tuple[List[bool], JaxprEqn]: + # TODO(mattjj): de-duplicate with above rule? + jaxpr_ = eqn.params['call_jaxpr'] + jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts + new_jaxpr, used_inputs = dce_jaxpr(jaxpr, used_outputs) + new_params = dict(eqn.params, call_jaxpr=core.ClosedJaxpr(new_jaxpr, consts)) + new_eqn = new_jaxpr_eqn( + [v for v, used in zip(eqn.invars, used_inputs) if used], + [v for v, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) + return used_inputs, new_eqn +dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule + +@weakref_lru_cache +def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr: + return ClosedJaxpr(jaxpr, ()) + +def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] + ) -> ClosedJaxpr: + """Reorder `invars` by moving those indicated in `to_move` to the front.""" + return _move_binders_to_front(closed_jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Tuple[bool, ...] + ) -> ClosedJaxpr: + assert len(closed_jaxpr.in_avals) == len(to_move) + new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) + new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, + closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, + closed_jaxpr.jaxpr.effects) + new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) + return new_closed_jaxpr + +def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: + return ([elt for elt, move in zip(lst, to_move) if move] + + [elt for elt, move in zip(lst, to_move) if not move]) + +def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] + ) -> ClosedJaxpr: + """Reorder `invars` by moving those indicated in `to_move` to the back.""" + return move_binders_to_front(closed_jaxpr, map(op.not_, to_move)) + +class DynamicJaxprTracer(core.Tracer): + __slots__ = ['aval', '_debug_info'] + + def __init__(self, trace, aval, line_info=None): + self._trace = trace + self._line_info = line_info + # Needed for UnexpectedTracerError. + self._debug_info = self._trace.frame.debug_info + self.aval = aval + + def full_lower(self): + return self + + def _contents(self): + return () + + def _origin_msg(self): + if not self._trace.main.jaxpr_stack: # type: ignore + # If this Tracer has been leaked the jaxpr stack may no longer be + # available. So we can't print as much origin information. + return ("\nThis DynamicJaxprTracer was created on line " + f"{source_info_util.summarize(self._line_info)}") + else: + invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self) + dbg = self._debug_info + if dbg is None: + return "" + + origin = ("The error occurred while tracing the function " + f"{dbg.func_src_info or ''} for {dbg.traced_for}. 
") + arg_info = arg_info_all(dbg) + if invar_pos and arg_info: + arg_info = [arg_info[i] for i in invar_pos] + arg_names = [f'{name}{keystr(path)}' for name, path in arg_info] + if len(arg_names) == 1: + arg_info_str = f"the argument {arg_names[0]}" + elif len(arg_names) == 2: + arg_info_str = f"the arguments {arg_names[0]} and {arg_names[1]}" + else: + *rest, last = arg_names + arg_info_str = f"the arguments {', '.join(rest)}, and {last}" + origin += ("This concrete value was not available in Python because it " + f"depends on the value{'s' if len(invar_pos) > 1 else ''} " + f"of {arg_info_str}.") + elif progenitor_eqns: + msts = [" operation " + f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n" + f" from line {source_info_util.summarize(eqn.source_info)}" + for eqn in progenitor_eqns[:5]] # show at most 5 + origin += ("This value became a tracer due to JAX operations on these lines:" + "\n\n" + "\n\n".join(msts)) + if len(progenitor_eqns) > 5: + origin += "\n\n(Additional originating lines are not shown.)" + return "\n" + origin + + def _assert_live(self) -> None: + if not self._trace.main.jaxpr_stack: # type: ignore + raise core.escaped_tracer_error(self, None) + + def get_referent(self): + frame = self._trace.frame + val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) + return self if val is None else get_referent(val) +api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval") + +def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: + del outvars + jaxpr_effects = set() + all_vars = [*constvars, *invars] + for eqn in eqns: + for eff in eqn.effects: + if isinstance(eff, effects.JaxprInputEffect): + invar = eqn.invars[eff.input_index] + if invar not in all_vars: + raise ValueError( + "`JaxprInputEffect` does not have corresponding input.") + eff = eff.replace(input_index=all_vars.index(invar)) + jaxpr_effects.add(eff) + return jaxpr_effects + + +class JaxprStackFrame: + gensym: Callable[[AbstractValue], Var] + tracer_to_var: Dict[TracerId, Var] + constid_to_tracer: Dict[ConstId, Tracer] + constvar_to_val: Dict[Var, Any] + tracers: List[DynamicJaxprTracer] # hold onto strong refs for all tracers + eqns: List[JaxprEqn] + invars: List[Var] + effects: core.Effects + debug_info: Optional[DebugInfo] + + def __init__(self): + self.gensym = core.gensym() + self.tracer_to_var = {} + self.constid_to_tracer = {} + self.constvar_to_val = {} + self.tracers = [] # circ refs, frame->tracer->trace->main->frame, + self.eqns = [] # cleared when we pop frame from main + self.invars = [] + self.effects = set() + self.debug_info = None + + def add_eqn(self, eqn: core.JaxprEqn): + self.eqns.append(eqn) + + def to_jaxpr(self, out_tracers): + # It's not necessary, but we keep the tracer-to-var mapping injective: + assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + outvars = [self.tracer_to_var[id(t)] for t in out_tracers] + constvars, constvals = unzip2(self.constvar_to_val.items()) + jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars, + self.eqns) + jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects) + jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) + jaxpr, constvals = _inline_literals(jaxpr, constvals) + return jaxpr, constvals + + def to_jaxpr2(self, out_tracers): + # It's not necessary, but we keep the tracer-to-var mapping injective: + assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + constvars, constvals 
= unzip2(self.constvar_to_val.items()) + expl_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] + jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, + self.eqns) + jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns, + jaxpr_effects) + # We can't run check_jaxpr until after we normalize. + jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) + jaxpr, constvals = _inline_literals(jaxpr, constvals) + jaxpr, out_type = _add_implicit_outputs(jaxpr) + config.jax_enable_checks and core.check_jaxpr(jaxpr) + return jaxpr, out_type, constvals + + def newvar(self, aval): + if isinstance(aval, DShapedArray): + # this aval may have tracers in it, so we replace those with variables + new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d + for d in aval.shape] + aval = aval.update(shape=tuple(new_shape)) + return self.gensym(aval) + + def find_progenitors(self, tracer): + var = self.tracer_to_var.get(id(tracer)) + if not var: + return None, None + active_vars = {var} + for eqn in self.eqns[::-1]: + produced = set(eqn.outvars) & active_vars + if produced: + active_vars.difference_update(produced) + active_vars.update(eqn.invars) + invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] + constvars = active_vars & set(self.constvar_to_val) + const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars] + return invar_positions, const_eqns + +def _const_folding_and_forwarding(jaxpr, constvals): + consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) + var_subs: Dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined + new_eqns = [] + for eqn in jaxpr.eqns: + # always apply invar substitutions + eqn = eqn.replace(invars=[var_subs.get(v, v) for v in eqn.invars]) + # if any inputs are constants and we have a constant-folding rule, apply it + has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) + for eff in eqn.effects) + if (eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars) + and not has_input_effect): + consts_in = [consts.get(v) for v in eqn.invars] + consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) + assert (new_eqn is None) == all(c is not None for c in consts_out) + for v, c in zip(eqn.outvars, consts_out): + if c is not None: consts[v] = c + if new_eqn is None: continue + else: eqn = new_eqn + # if the application trivially maps some inputs to outputs, simplify + if eqn.primitive in forwarding_rules and not has_input_effect: + fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn) + assert (new_eqn is None) == all(v is not None for v in fwd_vars) + for v_orig, v_new in zip(eqn.outvars, fwd_vars): + if v_new is not None: var_subs[v_orig] = v_new + if new_eqn is None: continue + else: eqn = new_eqn + new_eqns.append(eqn) + new_constvars, new_constvals = unzip2(consts.items()) + new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars] + jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars, + new_eqns) + new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, + jaxpr_effects, jaxpr.debug_info) + return new_jaxpr, new_constvals + +ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn], + Tuple[List[Optional[Any]], Optional[JaxprEqn]]] +const_fold_rules: Dict[Primitive, ConstFoldRule] = {} + +ForwardingRule = Callable[[JaxprEqn], + Tuple[List[Optional[Var]], Optional[JaxprEqn]]] +forwarding_rules: Dict[Primitive, ForwardingRule] = {} + +def _inline_literals(jaxpr, constvals): + # This 
function also prunes unused constants and inserts `dropvar` symbols. + input_effects = {eff for eff in jaxpr.effects + if isinstance(eff, effects.JaxprInputEffect)} + # Don't inline any literal with an input effect + has_input_effect = [any(eff.input_index == i for eff in input_effects) + for i in range(len(constvals))] + lits = {v: Literal(c, v.aval) for v, c, e in zip(jaxpr.constvars, constvals, + has_input_effect) + if type(c) in core.literalable_types and not np.shape(c) and not e} + lit: Callable[[Var], Optional[Literal]] = lits.get + newname: Callable[[AbstractValue], Var] = core.gensym() + newvars: Dict[Var, Var] = {} + newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) + var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) + dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) + + def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: + if isinstance(aval, DShapedArray): + return [d for d in aval.shape if isinstance(d, Var)] + return [] + + used = {v for eqn in jaxpr.eqns for invar in eqn.invars + for v in it.chain([invar], vars_in_shape(invar.aval))} + used |= {v for outvar in jaxpr.outvars + for v in it.chain([outvar], vars_in_shape(outvar.aval))} + new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] + new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) + if v in used and not lit(v)] + new_invars = [var(v) for v in jaxpr.invars] + new_eqns = [] + for eqn in jaxpr.eqns: + invars = [lit(v) or var(v) for v in eqn.invars] + outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] + new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) + new_outvars = [lit(v) or var(v) for v in jaxpr.outvars] + jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, + new_eqns) + new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, + jaxpr_effects, jaxpr.debug_info) + return new_jaxpr, new_constvals + +class DynamicJaxprTrace(core.Trace): + __slots__ = [] # type: ignore + + @property + def frame(self): + return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error + + def new_arg(self, aval): + tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + self.frame.tracers.append(tracer) + self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) + self.frame.invars.append(var) + return tracer + + def new_const(self, c): + # TODO(mattjj): for ints, or hashable consts, don't rely on id + tracer = self.frame.constid_to_tracer.get(id(c)) + if tracer is None: + aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c)) + aval = self._lift_tracers_in_aval(aval) + tracer = self._new_const(aval, c) + return tracer + + pure = lift = new_const + + def _new_const(self, aval, c): + tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + self.frame.tracers.append(tracer) + self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) + self.frame.constid_to_tracer[id(c)] = tracer + self.frame.constvar_to_val[var] = c + return tracer + + def sublift(self, t): + # When lifting closed-over tracers corresponding to this same trace, the + # variable to lift could have tracers (representing axis size variables) in + # its shape. We must lift those too! 
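+    # Sketch of the logic below: look up `t` by object id in case it was
+    # already introduced as a constant; otherwise raise its aval to shaped
+    # form, lift any Tracers appearing in that aval's shape via
+    # _lift_tracers_in_aval, and introduce a fresh constant tracer.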
+ tracer = self.frame.constid_to_tracer.get(id(t)) + if tracer is None: + aval = raise_to_shaped(get_aval(t), weak_type=dtypes.is_weakly_typed(t)) + aval = self._lift_tracers_in_aval(aval) + tracer = self._new_const(aval, t) + return tracer + + def _lift_tracers_in_aval(self, aval): + if (not isinstance(aval, DShapedArray) or + not any(isinstance(d, Tracer) for d in aval.shape)): + return aval + shape = [self.full_raise(d) if isinstance(d, Tracer) else d + for d in aval.shape] + return aval.update(shape=tuple(shape)) + + def getvar(self, tracer): + var = self.frame.tracer_to_var.get(id(tracer)) + if var is None: + raise core.escaped_tracer_error(tracer) + return var + + def makevar(self, tracer): + var = self.frame.tracer_to_var.get(id(tracer)) + assert var is None, "a jaxpr variable must be created only once per tracer" + self.frame.tracers.append(tracer) + var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) + return var + + def instantiate_const(self, val): + if (isinstance(val, Tracer) and val._trace.main is self.main + and val._trace.sublevel == self.sublevel): + return val + else: + return self.new_const(val) + + def process_primitive(self, primitive, tracers, params): + if primitive in custom_staging_rules: + return custom_staging_rules[primitive](self, *tracers, **params) + return self.default_process_primitive(primitive, tracers, params) + + def default_process_primitive(self, primitive, tracers, params): + avals = [t.aval for t in tracers] + out_avals, effects = primitive.abstract_eval(*avals, **params) + out_avals = [out_avals] if not primitive.multiple_results else out_avals + source_info = source_info_util.current() + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] + invars = map(self.getvar, tracers) + outvars = map(self.makevar, out_tracers) + eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info) + self.frame.add_eqn(eqn) + return out_tracers if primitive.multiple_results else out_tracers.pop() + + def process_call(self, call_primitive, f, explicit_tracers, params): + if f.in_type is None: + f = lu.annotate(f, tuple((raise_to_shaped(t.aval), True) + for t in explicit_tracers)) + implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) + in_tracers = [*implicit_tracers, *explicit_tracers] + # TODO(mattjj): check in_tracers are consistent with f.in_type annotation + with core.new_sublevel(): + # TODO(lenamartens): Make call_primitive name -> API function name mapping. + # (currently this will display eg. 
'xla_call' instead of `jit`) + dbg = debug_info_final(f, call_primitive.name) + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg) + if params.get('inline', False): + return core.eval_jaxpr(jaxpr, consts, *in_tracers) + source_info = source_info_util.current() + out_tracers = [] + for aval, _ in out_type: + if type(aval) is DShapedArray: + shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else + out_tracers[d.val] if type(d) is OutDBIdx else + d for d in aval.shape] + aval = aval.update(shape=tuple(get_referent(d) for d in shape)) + out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) + invars = map(self.getvar, in_tracers) + constvars = map(self.getvar, map(self.instantiate_const, consts)) + outvars = map(self.makevar, out_tracers) + new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) + update_params = call_param_updaters.get(call_primitive) + if update_params: + new_params = update_params(new_params, [True] * len(explicit_tracers), + len(consts) + len(implicit_tracers)) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, + new_params, new_params['call_jaxpr'].effects, + source_info) + self.frame.add_eqn(eqn) + return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] + + def post_process_call(self, call_primitive, out_tracers, params): + assert False # unreachable + + def process_map(self, map_primitive, f, tracers, params): + in_avals = [t.aval for t in tracers] + axis_name, axis_size = params['axis_name'], params['axis_size'] + reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) + if in_axis is not None else a + for a, in_axis in zip(in_avals, params['in_axes'])] + with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore + with core.new_sublevel(): + jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic( + f, self.main, reduced_in_avals, + debug_info=debug_info_final(f, map_primitive.name)) + ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) + if ordered_effects: + raise ValueError("Ordered effects not supported for " + f"map primitives: {ordered_effects}") + out_axes = params['out_axes_thunk']() + out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a) + if out_axis is not None else a + for a, out_axis in zip(reduced_out_avals, out_axes)] + source_info = source_info_util.current() + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] + invars = map(self.getvar, tracers) + constvars = map(self.getvar, map(self.instantiate_const, consts)) + outvars = map(self.makevar, out_tracers) + new_in_axes = (None,) * len(consts) + params['in_axes'] + new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, + call_jaxpr=convert_constvars_jaxpr(jaxpr)) + del new_params['out_axes_thunk'] + update_params = call_param_updaters.get(map_primitive) + if update_params: + new_params = update_params(new_params, [True] * len(tracers), len(consts)) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, + new_params, jaxpr.effects, source_info) + self.frame.add_eqn(eqn) + return out_tracers + + def post_process_map(self, map_primitive, out_tracers, params): + assert False # unreachable + + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + in_avals = [t.aval for t in tracers] + with core.new_sublevel(): + fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) + closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + 
main_ = ref(self.main) + jvp_jaxpr_thunk = _memoize( + lambda: trace_to_subjaxpr_dynamic(jvp, main_(), 2 * in_avals)[::2]) + out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + invars = map(self.getvar, tracers) + constvars = map(self.getvar, map(self.instantiate_const, consts)) + outvars = map(self.makevar, out_tracers) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, + dict(call_jaxpr=closed_fun_jaxpr, + jvp_jaxpr_thunk=jvp_jaxpr_thunk, + num_consts=len(consts), + symbolic_zeros=symbolic_zeros), + fun_jaxpr.effects, + source_info_util.current()) + self.frame.add_eqn(eqn) + return out_tracers + + def post_process_custom_jvp_call(self, out_tracers, _): + assert False # unreachable + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): + in_avals = [t.aval for t in tracers] + with core.new_sublevel(): + fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) + closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + main_ = ref(self.main) + fwd_jaxpr_thunk = _memoize( + lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2]) + out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + invars = map(self.getvar, tracers) + constvars = map(self.getvar, map(self.instantiate_const, consts)) + outvars = map(self.makevar, out_tracers) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, + dict(fun_jaxpr=closed_fun_jaxpr, + fwd_jaxpr_thunk=fwd_jaxpr_thunk, + num_consts=len(consts), + bwd=bwd, out_trees=out_trees), + fun_jaxpr.effects, + source_info_util.current()) + self.frame.add_eqn(eqn) + return out_tracers + + def post_process_custom_vjp_call(self, out_tracers, _): + assert False # unreachable + + def process_custom_transpose(self, prim, call, tracers, *, + transpose, out_types, + lin_tree, res_tree, out_tree): + tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) + + in_avals_p = [t.aval for t in tracers] + in_avals_t = [*[t.aval for t in tracers_res], *out_types] + + with core.new_sublevel(): + call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic( + call, self.main, in_avals_p) + closed_call_jaxpr = core.ClosedJaxpr( + convert_constvars_jaxpr(call_jaxpr), ()) + + transpose_flat, in_tree2 = flatten_fun_nokwargs( + lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree))) + + main_ = ref(self.main) + # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts + transpose_jaxpr_thunk = _memoize( + lambda: trace_to_subjaxpr_dynamic( + transpose_flat, main_(), in_avals_t)[::2]) + + out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + invars = map(self.getvar, tracers) + constvars = map(self.getvar, map(self.instantiate_const, call_consts)) + outvars = map(self.makevar, out_tracers) + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, + dict(call_jaxpr=closed_call_jaxpr, + transpose_jaxpr_thunk=transpose_jaxpr_thunk, + out_types=out_types, res_tree=res_tree, + lin_tree=lin_tree, out_tree=out_tree), + closed_call_jaxpr.effects, + source_info_util.current()) + self.frame.add_eqn(eqn) + return out_tracers + + +custom_staging_rules: Dict[Primitive, Callable] = {} + +def _memoize(thunk): + cell = [] + saved_state = [core.thread_local_state.trace_state.copy()] + def memoized(): + if not cell: + prev_state = core.thread_local_state.trace_state + core.thread_local_state.trace_state = saved_state.pop() + try: + cell.append(thunk()) + finally: + core.thread_local_state.trace_state = prev_state + return cell[0] + 
return memoized + +# TODO(mattjj): remove this DebugInfo and helper functions, replace with +# api_util.py versions + +class DebugInfo(NamedTuple): + func_src_info: Optional[str] # f'{fun.__name__} at {filename}:{lineno}' + signature: Optional[inspect.Signature] # inspect.signature(fun) + in_tree: Optional[PyTreeDef] # caller/constructor might not have this info + out_tree: Optional[Callable[[], PyTreeDef]] # lazy, not avail at trace time + has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args + traced_for: str # "jit", "scan", "make_jaxpr", etc + +def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], + out_tree_thunk: Optional[Callable[[], PyTreeDef]], + has_kwargs: bool, traced_for: str) -> DebugInfo: + try: sig = inspect.signature(fn) + except (ValueError, TypeError): sig = None + src_info = fun_sourceinfo(fn) + return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs, + traced_for) + +def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo: + "Make a DebugInfo from data available to final-style primitives like pmap." + in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False) + return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for) + +def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]: + ba = None if dbg.in_tree is None else sig_info(dbg) + if ba is None: return None + return [(name, key_path) for name, dummy_arg in ba.arguments.items() + for key_path, _ in generate_key_paths(dummy_arg)] + +def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]: + if dbg.in_tree is None or dbg.signature is None: return None + try: + dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves) + except: + return None + args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {}) + try: + return dbg.signature.bind(*args, **kwargs) + except (TypeError, ValueError): + return None + +def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]: + if dbg.out_tree is None: return None + try: + num_leaves = dbg.out_tree().num_leaves + dummy_result = tree_unflatten(dbg.out_tree(), [False] * num_leaves) + except: + return None + else: + return [path for path, _ in generate_key_paths(dummy_result)] + +@profiler.annotate_function +def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, + in_avals: Sequence[AbstractValue], + debug_info: Optional[DebugInfo] = None, + *, + keep_inputs: Optional[List[bool]] = None): + with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore + main.jaxpr_stack = () # type: ignore + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( + fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) + del main, fun + return jaxpr, out_avals, consts + +def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, + in_avals: Sequence[AbstractValue], *, + keep_inputs: Optional[Sequence[bool]] = None, + debug_info: Optional[DebugInfo] = None): + keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs + + frame = JaxprStackFrame() + frame.debug_info = debug_info + with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): + trace = DynamicJaxprTrace(main, core.cur_sublevel()) + in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + ans = fun.call_wrapped(*in_tracers_) + out_tracers = map(trace.full_raise, ans) + jaxpr, consts = frame.to_jaxpr(out_tracers) + del fun, main, trace, frame, in_tracers, out_tracers, ans + 
config.jax_enable_checks and core.check_jaxpr(jaxpr) + return jaxpr, [v.aval for v in jaxpr.outvars], consts + + +@profiler.annotate_function +def trace_to_jaxpr_dynamic2( + fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None + ) -> Tuple[Jaxpr, OutputType, List[Any]]: + with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore + main.jaxpr_stack = () # type: ignore + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) + del main, fun + return jaxpr, out_type, consts + +def trace_to_subjaxpr_dynamic2( + fun: lu.WrappedFun, main: core.MainTrace, + debug_info: Optional[DebugInfo] = None +) -> Tuple[Jaxpr, OutputType, List[Any]]: + in_avals, keep_inputs = unzip2(fun.in_type) + frame = JaxprStackFrame() + frame.debug_info = debug_info + with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): + trace = DynamicJaxprTrace(main, core.cur_sublevel()) + in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + ans = fun.call_wrapped(*in_tracers_) + out_tracers = map(trace.full_raise, ans) + jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers) + del fun, main, trace, frame, in_tracers, out_tracers, ans + return jaxpr, out_type, consts + + +@contextmanager +def extend_jaxpr_stack(main, frame): + main.jaxpr_stack = main.jaxpr_stack + (frame,) + try: + yield + finally: + assert frame is main.jaxpr_stack[-1] + main.jaxpr_stack = main.jaxpr_stack[:-1] + +@profiler.annotate_function +def trace_to_jaxpr_final(fun: lu.WrappedFun, + in_avals: Sequence[AbstractValue], + debug_info: Optional[DebugInfo] = None, + keep_inputs: Optional[Sequence[bool]] = None): + with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore + main.jaxpr_stack = () # type: ignore + with core.new_sublevel(): + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( + fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) + del fun, main + return jaxpr, out_avals, consts + +@profiler.annotate_function +def trace_to_jaxpr_final2( + fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None + ) -> Tuple[Jaxpr, OutputType, List[Any]]: + with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore + main.jaxpr_stack = () # type: ignore + with core.new_sublevel(): + jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) + del fun, main + return jaxpr, out_type, consts + + +AbstractedAxisName = Hashable +AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], + Tuple[AbstractedAxisName, ...]] +def infer_lambda_input_type( + axes_specs: Optional[Sequence[AbstractedAxesSpec]], + args: Sequence[Any] + ) -> InputType: + ndims = [getattr(get_aval(x), 'ndim', 0) for x in args] + partial_specs = _canonicalize_specs(ndims, axes_specs) + specs = _complete_specs(args, partial_specs) + idxs, implicit_types = _collect_implicit(args, specs) + implicit_sig = [(ty, False) for ty in implicit_types] + explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)] + input_type = (*implicit_sig, *explicit_sig) + lu._check_input_type(input_type) + return input_type + +def _spec_to_dict(spec: AbstractedAxesSpec) -> Dict[int, AbstractedAxisName]: + if isinstance(spec, tuple): + return {i: d for i, d in enumerate(spec) if d is not None} + else: + return spec + +def _canonicalize_specs( + ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]] + ) -> List[Dict[int, AbstractedAxisName]]: + if specs is None: + return [{}] * len(ndims) + else: + return 
[_spec_to_dict(s) for n, s in zip(ndims, specs)] + +def _complete_specs( + args: Sequence[Any], partial_specs: List[Dict[int, AbstractedAxisName]] + ) -> List[Dict[int, AbstractedAxisName]]: + # The abstracted axes specification in `partial_specs` is partial in the sense + # that there could be additional axis abstraction represented in `args` due to + # Tracers existing in the shapes of elements of `args`. The purpose of this + # function is to produce a full specification, for each argument mapping any + # abstracted axis positions to a name, introducing new names as needed for + # Tracers in axis sizes which don't already correspond to abstracted axis + # names (with one new name per unique Tracer object id). + + # Identify each user-supplied name in partial_specs with a size. + sizes: Dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {} + for x, spec in zip(args, partial_specs): + for i, name in spec.items(): + d = sizes.setdefault(name, x.shape[i]) + if d is not x.shape[i] and d != x.shape[i]: raise TypeError + + # Introduce new names as needed for Tracers in shapes. + named_tracers: Dict[TracerId, AbstractedAxisName] = { + id(d): name for name, d in sizes.items() if isinstance(d, Tracer)} + specs: List[Dict[int, AbstractedAxisName]] = [] + for x, spec in zip(args, partial_specs): + if isinstance(get_aval(x), DShapedArray): + spec = dict(spec) + for i, d in enumerate(x.shape): + if isinstance(d, Tracer): + spec[i] = named_tracers.get(id(d), TracerAsName(d)) + specs.append(spec) + + # Assert that `specs` is now complete in the sense that there are no Tracers + # which don't correspond to an AbstractedAxisName. + assert all(not spec or not any(isinstance(d, Tracer) and i not in spec + for i, d in enumerate(x.shape)) + for x, spec in zip(args, specs)) + return specs + + +def _collect_implicit( + args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]] + ) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]: + # Given an explicit argument list and a specification of abstracted axes, we + # want to produce an InputType by identifying AbstractedAxisNames with DBIdxs + # and figuring out which AbstractedAxisNames correspond to implicit arguments. + + idxs: Dict[AbstractedAxisName, DBIdx] = {} + implicit_types: List[AbstractValue] = [] + explicit_tracers: Dict[TracerId, int] = {} + counter = it.count() + + # Add implicit arguments to idxs. + for explicit_idx, (x, spec) in enumerate(zip(args, specs)): + for i, name in spec.items(): + if name not in idxs and id(x.shape[i]) not in explicit_tracers: + idxs[name] = DBIdx(next(counter)) + implicit_types.append(raise_to_shaped(get_aval(x.shape[i]))) + if isinstance(x, Tracer): + explicit_tracers.setdefault(id(x), explicit_idx) # use the first + + # Now that we know the implicit args, add explicit args to idxs. + offset = len(implicit_types) + for x, spec in zip(args, specs): + for i, name in spec.items(): + if id(x.shape[i]) in explicit_tracers: + idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])])) + + return idxs, implicit_types + +def _arg_type( + idxs: Dict[AbstractedAxisName, DBIdx], x: Any, + spec: Dict[int, AbstractedAxisName] + ) -> AbstractValue: + # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames. 
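+  # As an illustrative (hypothetical) case: with spec {0: 'n'} and
+  # idxs {'n': DBIdx(0)}, an argument of aval f32[3, 4] produces
+  # DShapedArray((DBIdx(0), 4), float32, False).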
+ aval = get_aval(x) # aval.shape could contain Tracers + if not spec: return core.raise_to_shaped(aval) + shape: List[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d + for i, d in enumerate(aval.shape)] + assert not any(isinstance(d, Tracer) for d in shape) + return DShapedArray(tuple(shape), aval.dtype, False) + +def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]: + invars = [*jaxpr.constvars, *jaxpr.invars] + expl_outvars = jaxpr.outvars + + # First do a pass to collect implicit outputs, meaning variables which occur + # in explicit_outvars types but not in invars or to the left in outvars. + seen: Set[Var] = set(invars) + impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore + (seen.add(x) or type(x.aval) is DShapedArray) # type: ignore + for d in x.aval.shape if type(d) is Var and d not in seen] + outvars = [*impl_outvars, *expl_outvars] + + # Now assemble an OutputType by mapping vars in shapes to InDBIdx/OutDBIdx. + in_map: Dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate(invars)} + out_map: Dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars) + if type(x) is Var} + out_avals_ = (x.aval for x in outvars) + out_avals = [a.update(shape=tuple(in_map.get(d, out_map.get(d)) + if type(d) is Var else d for d in a.shape)) + if type(a) is DShapedArray else a for a in out_avals_] + kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars) + out_type = tuple(zip(out_avals, kept_outs)) + + new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, + jaxpr.effects, jaxpr.debug_info) + config.jax_enable_checks and core.check_jaxpr(new_jaxpr) + return new_jaxpr, out_type + + +class TracerAsName: + ref: Any + def __init__(self, tracer): + self.ref = core.get_referent(tracer) + def __eq__(self, other): + return isinstance(other, TracerAsName) and self.ref is other.ref + def __hash__(self): + return id(self.ref) + +def _extract_implicit_args( + trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]], + explicit_tracers: Sequence[DynamicJaxprTracer] + ) -> Sequence[DynamicJaxprTracer]: + # First, construct a list to represent the full argument list, leaving the + # implicit arguments as Nones for now. + explicit_tracers_ = iter(explicit_tracers) + tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type] + assert next(explicit_tracers_, None) is None + del explicit_tracers_ + + # Next, populate the implicit arguments using DBIdxs in in_type. + for i, (aval, explicit) in enumerate(in_type): + if not explicit or not isinstance(aval, DShapedArray): + continue # can't populate an implicit argument + tracer = tracers[i] + assert tracer is not None + for d1, d2 in zip(aval.shape, tracer.aval.shape): + if isinstance(d1, DBIdx): + if tracers[d1.val] is None: + tracers[d1.val] = trace.instantiate_const(d2) + assert tracers[d1.val] is trace.instantiate_const(d2) + assert all(t is not None for t in tracers) + return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore + +def _input_type_to_tracers( + new_arg: Callable[[AbstractValue], Tracer], + in_avals: Sequence[AbstractValue] + ) -> Sequence[Tracer]: + # Create input Tracers given input AbstractValues, each of which can contain + # de Bruijn indices which refer to positions in the input argument list. That + # is, each element `a` of `in_avals` can have DBIdx instances in its shape, + # which must refer to positions left of `a`'s.
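+  # As an illustrative (hypothetical) case: given in_avals of
+  # [ShapedArray((), int32), DShapedArray((DBIdx(0),), float32, False)], the
+  # tracer made for the scalar is substituted for DBIdx(0) in the second
+  # aval's shape before that aval's tracer is created.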
+ in_tracers: List[Tracer] = [] + + def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue: + if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape): + shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] # type: ignore + return a.update(shape=tuple(shape)) + return a + + for a in in_avals: + in_tracers.append(new_arg(_substitute_tracers_in_aval(a))) + return in_tracers + +def _substitute_vars_in_type( + consts: Dict[Var, Literal], env: Dict[Var, Var], a: AbstractValue + ) -> AbstractValue: + if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape): + shape = [consts[d].val if d in consts else env[d] # type: ignore + if isinstance(d, Var) else d for d in a.shape] + return a.update(shape=tuple(shape)) + else: + return a + + +class DimensionHandlerTracer(core.DimensionHandler): + """See core.DimensionHandler. + + Most methods are inherited. + """ + def is_constant(self, d: core.DimSize) -> bool: + assert isinstance(d, Tracer) + return False + + def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool: + return d1 is d2 + + def greater_equal(self, d1: core.DimSize, d2: core.DimSize): + raise core.InconclusiveDimensionOperation("TODO") + + def divide_shape_sizes(self, s1: core.Shape, s2: core.Shape) -> core.DimSize: + """Computes integer "i" such that i * size(s2) == size(s1). + + Raise InconclusiveDimensionOperation if there is no such integer for all + contexts. + """ + s1_size = functools.reduce(op.mul, s1, 1) + s2_size = functools.reduce(op.mul, s2, 1) + q, r = divmod(s1_size, s2_size) + # TODO(necula): must check that r == 0! + return q + + def stride(self, d: core.DimSize, window_size: core.DimSize, window_stride: core.DimSize) -> core.DimSize: + """Implements `(d - window_size) // window_stride + 1`""" + raise core.InconclusiveDimensionOperation("TODO") + + def as_value(self, d: core.DimSize): + """Turns a dimension size into a Jax value that we can compute with.""" + raise core.InconclusiveDimensionOperation("TODO") + +core._SPECIAL_DIMENSION_HANDLERS[DynamicJaxprTracer] = DimensionHandlerTracer() + +Const = Any +Val = Any + +def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const] + ) -> Tuple[Jaxpr, List[Const]]: + bounds = {v: v.aval.dtype.bound for v in jaxpr.invars + if isinstance(v.aval, core.UnshapedArray) and + type(v.aval.dtype) is core.bint and not v.aval.shape} + idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)} + + def substitute(aval: AbstractValue) -> AbstractValue: + if (isinstance(aval, core.UnshapedArray) and type(aval.dtype) is core.bint + and not aval.shape): + return ShapedArray((), dtypes._scalar_type_to_dtype(int)) + elif isinstance(aval, DShapedArray): + shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore + typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray + return typ(tuple(shape), aval.dtype, aval.weak_type) + else: + return aval + + in_avals = [substitute(v.aval) for v in jaxpr.invars] + eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts)) + padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals) + return padded_jaxpr, padded_consts + +class BoundedAxisSize(NamedTuple): + val: Union[int, DynamicJaxprTracer] + bound: int + +def _eval_jaxpr_padded( + jaxpr: Jaxpr, consts: List[Const], *args: DynamicJaxprTracer + ) -> List[Union[Const, DynamicJaxprTracer]]: + env: Dict[Var, Val] = {} + + def read(x): + return x.val if type(x) is Literal else env[x] + + def write(v, val) -> None: + env[v] = val + 
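+  # Evaluate the jaxpr one equation at a time, dispatching each primitive to
+  # its entry in padding_rules; _substitute_axis_sizes (defined below) rewrites
+  # bint-typed axis sizes into BoundedAxisSize values the rules can pad against.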
+ map(write, jaxpr.constvars, consts) + map(write, jaxpr.invars, args) + for eqn in jaxpr.eqns: + in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars] + out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars] + rule = padding_rules[eqn.primitive] + outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params) + map(write, eqn.outvars, outs) + return map(read, jaxpr.outvars) + +def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue: + if isinstance(aval, DShapedArray): + shp = [] + for d in aval.shape: + if isinstance(d, core.DArray): + assert not d.shape and type(d.dtype) is core.bint + shp.append(BoundedAxisSize(int(d._data), int(d.dtype.bound))) + elif (type(d) is core.Var and isinstance(d.aval, core.DShapedArray) and + type(d.aval.dtype) is core.bint): + assert not d.aval.shape + shp.append(BoundedAxisSize(env[d], d.aval.dtype.bound)) + else: + shp.append(env.get(d, d)) + return DShapedArray(tuple(shp), aval.dtype, aval.weak_type) + else: + return aval + +def _is_bint_axis_size(d: Union[int, core.DArray, core.Var]) -> bool: + if isinstance(d, core.DArray): + assert not d.shape # pytype: disable=attribute-error + return type(d.dtype) is core.bint # pytype: disable=attribute-error + elif isinstance(d, core.Var): + return (isinstance(d.aval, core.DShapedArray) and # pytype: disable=attribute-error + type(d.aval.dtype) is core.bint) # pytype: disable=attribute-error + return False + + +padding_rules: Dict[Primitive, Callable] = {} + +def def_trivial_padding(prim: Primitive) -> None: + if prim.multiple_results: + padding_rules[prim] = partial(_trivial_padding_rule_multi, prim) + else: + padding_rules[prim] = partial(_trivial_padding_rule, prim) + +def _trivial_padding_rule(prim, _, __, *args, **params): + return [prim.bind(*args, **params)] + +def _trivial_padding_rule_multi(prim, _, __, *args, **params): + return prim.bind(*args, **params) + +def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): + if call_jaxpr.constvars: raise NotImplementedError + padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ()) + if padded_consts: raise NotImplementedError + new_params = dict(params, call_jaxpr=padded_jaxpr) + subfuns, bind_params = prim.get_bind_params(new_params) + return prim.bind(*subfuns, *args, **bind_params) + + +# TODO(mattjj): the following are deprecated; update callers to _nounits version +# See https://github.com/google/jax/pull/9498 +@lu.transformation +def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]], + pvals: Sequence[PartialVal]): + assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals + trace = main.with_cur_sublevel() + in_tracers = map(trace.new_arg, pvals) + ans = yield in_tracers, {} + assert isinstance(ans, (list, tuple)), ( + f"Got unexpected return type when tracing function to jaxpr: {ans}") + assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( + f"Got unexpected return type when tracing function to jaxpr: {ans}") + instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate + out_tracers = map(trace.full_raise, map(core.full_lower, ans)) + out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers) + jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers) + out_pvals = [t.pval for t in out_tracers] + del trace, in_tracers, out_tracers + yield jaxpr, (out_pvals, consts, env) + +partial_eval_jaxpr: Callable + +def instantiate_const_at(trace: JaxprTrace, 
instantiate: bool, tracer): + if instantiate: + return trace.instantiate_const(trace.full_raise(tracer)) + else: + return tracer diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 81e9ef2a4..01c835df4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -48,7 +48,6 @@ import numpy as np import jax from jax.errors import JAXTypeError -from jax.interpreters import partial_eval as pe from jax.tree_util import tree_flatten, tree_map from jax._src import api_util @@ -71,6 +70,7 @@ from jax._src.config import flags from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.lib import xla_client as xc diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 00d66550a..63348535e 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -78,6 +78,7 @@ from jax.interpreters import xla from jax._src import ad_util from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -362,7 +363,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, approx_top_k_p = core.Primitive('approx_top_k') approx_top_k_p.multiple_results = True -approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p)) +approx_top_k_p.def_impl(partial(dispatch.apply_primitive, approx_top_k_p)) approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval) xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation) xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation, diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index f5d0c4490..ba2c56e7a 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -25,7 +25,7 @@ from jax._src import ad_util from jax._src import util from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3 from jax.api_util import flatten_fun_nokwargs -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe from jax.tree_util import tree_map, tree_unflatten map, unsafe_map = safe_map, map diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index ce748fc47..16b98bf9f 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -22,14 +22,10 @@ import operator from typing import Callable, Sequence, List, Tuple from jax.config import config -from jax.interpreters import ad -from jax.interpreters import batching -from jax._src.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu @@ -37,6 +33,11 @@ from jax._src import source_info_util from jax._src import util from jax._src import state from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla from 
jax._src.lax import lax from jax._src.traceback_util import api_boundary from jax._src.util import (safe_map, split_list, partition_list) @@ -803,7 +804,7 @@ def cond_bind(*args, branches, linear): cond_p = core.AxisPrimitive('cond') cond_p.multiple_results = True -cond_p.def_impl(partial(xla.apply_primitive, cond_p)) +cond_p.def_impl(partial(dispatch.apply_primitive, cond_p)) cond_p.def_effectful_abstract_eval(_cond_abstract_eval) cond_p.def_custom_bind(cond_bind) ad.primitive_jvps[cond_p] = _cond_jvp diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 15c5c1f4e..9a6b1b871 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -20,16 +20,16 @@ from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, import jax.numpy as jnp from jax import lax from jax.api_util import flatten_fun_nokwargs -from jax.interpreters import ad -from jax.interpreters import batching -from jax.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten, treedef_tuple, tree_map, tree_leaves, PyTreeDef) from jax._src import ad_util from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util @@ -299,7 +299,7 @@ def _for_impl_unrolled(body, nsteps, unroll, *args): return state mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True)) -for_p.def_impl(functools.partial(xla.apply_primitive, for_p)) +for_p.def_impl(functools.partial(dispatch.apply_primitive, for_p)) def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, jaxpr, nsteps, reverse, which_linear, unroll): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 8a9828c6e..7aff43cca 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -27,14 +27,15 @@ from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped from jax.interpreters import ad from jax.interpreters import batching from jax._src.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, tree_map, tree_flatten_with_path, keystr) from jax._src.tree_util import equality_errors from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api +from jax._src import dispatch from jax._src import dtypes from jax._src import effects from jax._src import source_info_util @@ -1032,7 +1033,7 @@ def scan_bind(*args, **params): scan_p = core.AxisPrimitive("scan") scan_p.multiple_results = True scan_p.def_custom_bind(scan_bind) -scan_p.def_impl(partial(xla.apply_primitive, scan_p)) +scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp ad.reducing_transposes[scan_p] = _scan_transpose @@ -1612,7 +1613,7 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, while_p = core.AxisPrimitive('while') while_p.multiple_results = True 
-while_p.def_impl(partial(xla.apply_primitive, while_p)) +while_p.def_impl(partial(dispatch.apply_primitive, while_p)) while_p.def_effectful_abstract_eval(_while_loop_abstract_eval) ad.primitive_jvps[while_p] = _while_loop_jvp pe.custom_partial_eval_rules[while_p] = _while_partial_eval diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index d02e86860..4099562e7 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -27,8 +27,6 @@ import numpy as np import jax from jax import tree_util -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla from jax.tree_util import tree_map from jax._src import ad_util @@ -51,7 +49,9 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla +from jax._src.interpreters import xla from jax._src.interpreters.batching import ConcatAxis from jax._src.lax import slicing from jax._src.lax.utils import ( @@ -2264,7 +2264,7 @@ def _convert_elt_type_pp_rule(eqn, context, settings): return core._pp_eqn(eqn.replace(params=params), context, settings) convert_element_type_p = Primitive('convert_element_type') -convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p)) +convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, @@ -3475,7 +3475,7 @@ def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions): reduce_p = core.Primitive('reduce') reduce_p.multiple_results = True -reduce_p.def_impl(partial(xla.apply_primitive, reduce_p)) +reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, _reduce_dtype_rule, _reduce_weak_type_rule, @@ -3869,7 +3869,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys sort_p = Primitive('sort') sort_p.multiple_results = True -sort_p.def_impl(partial(xla.apply_primitive, sort_p)) +sort_p.def_impl(partial(dispatch.apply_primitive, sort_p)) sort_p.def_abstract_eval(_sort_abstract_eval) ad.primitive_jvps[sort_p] = _sort_jvp batching.primitive_batchers[sort_p] = _sort_batch_rule @@ -3960,7 +3960,7 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k): top_k_p = Primitive('top_k') top_k_p.multiple_results = True -top_k_p.def_impl(partial(xla.apply_primitive, top_k_p)) +top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p)) top_k_p.def_abstract_eval(_top_k_abstract_eval) def _top_k_lower(ctx, operand, k): return chlo.TopKOp(operand, mlir.i64_attr(k)).results @@ -3993,7 +3993,7 @@ def create_token(_=None): return create_token_p.bind() create_token_p = Primitive("create_token") -create_token_p.def_impl(partial(xla.apply_primitive, create_token_p)) +create_token_p.def_impl(partial(dispatch.apply_primitive, create_token_p)) create_token_p.def_abstract_eval(lambda *_: abstract_token) def _create_token_lowering(ctx, *operands): @@ -4015,7 +4015,7 @@ def _after_all_abstract_eval(*operands): after_all_p = Primitive("after_all") -after_all_p.def_impl(partial(xla.apply_primitive, after_all_p)) +after_all_p.def_impl(partial(dispatch.apply_primitive, after_all_p)) 
after_all_p.def_abstract_eval(_after_all_abstract_eval) def _after_all_lowering(ctx, *operands): @@ -4060,7 +4060,7 @@ def _infeed_abstract_eval(token, *, shapes, partitions): infeed_p = Primitive("infeed") infeed_p.multiple_results = True -infeed_p.def_impl(partial(xla.apply_primitive, infeed_p)) +infeed_p.def_impl(partial(dispatch.apply_primitive, infeed_p)) infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval) mlir.lowerable_effects.add_type(InOutFeedEffect) @@ -4111,7 +4111,7 @@ def _outfeed_abstract_eval(token, *xs, partitions): return abstract_token, {outfeed_effect} outfeed_p = Primitive("outfeed") -outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p)) +outfeed_p.def_impl(partial(dispatch.apply_primitive, outfeed_p)) outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval) mlir.lowerable_effects.add_type(InOutFeedEffect) @@ -4153,7 +4153,7 @@ def _rng_uniform_abstract_eval(a, b, *, shape): weak_type=(a.weak_type and b.weak_type)) rng_uniform_p = Primitive("rng_uniform") -rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p)) +rng_uniform_p.def_impl(partial(dispatch.apply_primitive, rng_uniform_p)) rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval) def _rng_uniform_lowering(ctx, a, b, *, shape): @@ -4247,7 +4247,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm): rng_bit_generator_p = Primitive("rng_bit_generator") rng_bit_generator_p.multiple_results = True rng_bit_generator_p.def_impl( - partial(xla.apply_primitive, rng_bit_generator_p)) + partial(dispatch.apply_primitive, rng_bit_generator_p)) rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, @@ -4298,7 +4298,7 @@ def _copy_impl(prim, *args, **kwargs): if isinstance(a, jax.Array) and isinstance(a.sharding, PmapSharding): sharded_dim = _which_dim_sharded(a.sharding) return _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs) - return xla.apply_primitive(prim, *args, **kwargs) + return dispatch.apply_primitive(prim, *args, **kwargs) # The copy_p primitive exists for expressing making copies of runtime arrays. # For that reason we don't simplify it out of jaxprs (e.g. for jit invariance). 
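The hunks above and below all make the same mechanical substitution: a primitive's eager implementation is now registered against jax._src.dispatch rather than the xla interpreter module. A condensed sketch of the registration pattern being rewritten, using a hypothetical primitive foo_p:

    from functools import partial
    from jax._src import core, dispatch

    foo_p = core.Primitive('foo')
    # Eager (op-by-op) execution: compile and run the primitive on the fly.
    foo_p.def_impl(partial(dispatch.apply_primitive, foo_p))
    # Shape/dtype rule used while tracing.
    foo_p.def_abstract_eval(lambda aval: core.raise_to_shaped(aval))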
@@ -4354,7 +4354,7 @@ def _iota_abstract_eval(*, dtype, shape, dimension): return core.DShapedArray(shape, dtype, False) iota_p = Primitive('iota') -iota_p.def_impl(partial(xla.apply_primitive, iota_p)) +iota_p.def_impl(partial(dispatch.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3b95c4136..a7cd4e1a9 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -23,10 +23,10 @@ import numpy as np import jax from jax import lax -from jax.interpreters import xla from jax._src import ad_util from jax._src import api +from jax._src import dispatch from jax._src import dtypes from jax._src.core import ( Primitive, ShapedArray, raise_to_shaped, is_constant_shape) @@ -448,10 +448,12 @@ mlir.register_lowering( # Asymmetric eigendecomposition def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): - return ( - xla.apply_primitive(eig_p, operand, - compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors)) + return dispatch.apply_primitive( + eig_p, + operand, + compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + ) def eig_lower(*args, **kw): raise NotImplementedError( @@ -577,8 +579,8 @@ def eigh_jacobi(x: ArrayLike, *, lower: bool = True, return w, v def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues): - w, v = xla.apply_primitive(eigh_jacobi_p, operand, lower=lower, - sort_eigenvalues=sort_eigenvalues) + w, v = dispatch.apply_primitive(eigh_jacobi_p, operand, lower=lower, + sort_eigenvalues=sort_eigenvalues) return w, v def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues): @@ -634,8 +636,8 @@ mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule) def _eigh_impl(operand, *, lower, sort_eigenvalues): - v, w = xla.apply_primitive(eigh_p, operand, lower=lower, - sort_eigenvalues=sort_eigenvalues) + v, w = dispatch.apply_primitive(eigh_p, operand, lower=lower, + sort_eigenvalues=sort_eigenvalues) return v, w def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues): @@ -1016,7 +1018,7 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *, lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation') lu_pivots_to_permutation_p.multiple_results = False lu_pivots_to_permutation_p.def_impl( - partial(xla.apply_primitive, lu_pivots_to_permutation_p)) + partial(dispatch.apply_primitive, lu_pivots_to_permutation_p)) lu_pivots_to_permutation_p.def_abstract_eval( _lu_pivots_to_permutation_abstract_eval) batching.primitive_batchers[lu_pivots_to_permutation_p] = ( @@ -1111,7 +1113,7 @@ def _lu_python(x): return fn(x) def _lu_impl(operand): - lu, pivot, perm = xla.apply_primitive(lu_p, operand) + lu, pivot, perm = dispatch.apply_primitive(lu_p, operand) return lu, pivot, perm def _lu_abstract_eval(operand): @@ -1385,7 +1387,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a): geqrf_p = Primitive('geqrf') geqrf_p.multiple_results = True -geqrf_p.def_impl(partial(xla.apply_primitive, geqrf_p)) +geqrf_p.def_impl(partial(dispatch.apply_primitive, geqrf_p)) geqrf_p.def_abstract_eval(_geqrf_abstract_eval) batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule mlir.register_lowering(geqrf_p, _geqrf_lowering_rule) @@ -1474,7 +1476,7 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus): householder_product_p = 
Primitive('householder_product') -householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p)) +householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p)) householder_product_p.def_abstract_eval(_householder_product_abstract_eval) batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule mlir.register_lowering(householder_product_p, _householder_product_lowering_rule) @@ -1494,7 +1496,7 @@ mlir.register_lowering( def _qr_impl(operand, *, full_matrices): - q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices) + q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices) return q, r def _qr_abstract_eval(operand, *, full_matrices): @@ -1572,7 +1574,7 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering)); # Singular value decomposition def _svd_impl(operand, *, full_matrices, compute_uv): - return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices, + return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices, compute_uv=compute_uv) def _svd_abstract_eval(operand, *, full_matrices, compute_uv): @@ -1762,7 +1764,7 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t tridiagonal_solve_p = Primitive('tridiagonal_solve') tridiagonal_solve_p.multiple_results = False tridiagonal_solve_p.def_impl( - functools.partial(xla.apply_primitive, tridiagonal_solve_p)) + functools.partial(dispatch.apply_primitive, tridiagonal_solve_p)) tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b) # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve? @@ -1873,7 +1875,7 @@ def schur(x: ArrayLike, *, def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals, select_callable): - return xla.apply_primitive( + return dispatch.apply_primitive( schur_p, operand, compute_schur_vectors=compute_schur_vectors, @@ -2000,7 +2002,7 @@ def _hessenberg_abstract_eval(a): return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)] hessenberg_p = Primitive("hessenberg") -hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p)) +hessenberg_p.def_impl(partial(dispatch.apply_primitive, hessenberg_p)) hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval) hessenberg_p.multiple_results = True @@ -2098,7 +2100,7 @@ def _tridiagonal_abstract_eval(a, *, lower): ] tridiagonal_p = Primitive("tridiagonal") -tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p)) +tridiagonal_p.def_impl(partial(dispatch.apply_primitive, tridiagonal_p)) tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval) tridiagonal_p.multiple_results = True diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 1e8e32e49..75605fa29 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -20,7 +20,6 @@ import weakref import numpy as np import jax -from jax.interpreters import partial_eval as pe from jax._src import ad_util from jax._src import core @@ -30,6 +29,7 @@ from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax from jax._src.lax.utils import ( _argnum_weak_type, diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 5495da3d9..a193600fe 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -23,6 +23,7 @@ from typing import Callable from jax.interpreters import xla from jax._src import 
core +from jax._src import dispatch from jax._src import dtypes from jax._src.util import safe_zip from jax._src.lib import xla_client @@ -44,7 +45,7 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None, weak_type_rule = weak_type_rule or _standard_weak_type_rule named_shape_rule = named_shape_rule or standard_named_shape_rule prim = core.Primitive(name) - prim.def_impl(partial(xla.apply_primitive, prim)) + prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, weak_type_rule, named_shape_rule)) diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 7e11f950e..0334a652b 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -39,7 +39,7 @@ from jax.errors import JAXTypeError from jax._src.array import ArrayImpl from jax._src.sharding_impls import NamedSharding from jax._src.interpreters import mlir -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax.interpreters import batching @@ -965,7 +965,7 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap # This is DynamicJaxprTrace.process_map with some very minor modifications def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params): - from jax.interpreters.partial_eval import ( + from jax._src.interpreters.partial_eval import ( trace_to_subjaxpr_dynamic, DynamicJaxprTracer, convert_constvars_jaxpr, new_jaxpr_eqn) assert primitive is xmap_p diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 09613769d..b06afc934 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -28,9 +28,9 @@ import jax from jax._src import core from jax import stages from jax.errors import JAXTypeError -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters.pxla import PartitionSpec +from jax._src.interpreters import xla from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, treedef_tuple, broadcast_prefix, all_leaves) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index bfc2c534c..ca20dbf33 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -21,10 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union import numpy as np from jax import lax -from jax.interpreters import partial_eval as pe from jax._src import core from jax._src import linear_util as lu +from jax._src.interpreters import partial_eval as pe from jax._src.state.types import AbstractRef from jax._src.state.primitives import get_p, swap_p, addupdate_p from jax._src.util import safe_map, safe_zip, split_list diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 278bfd2e4..2a486c032 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -19,13 +19,13 @@ from typing import Any, List, Tuple, Union import numpy as np from jax import lax -from jax.interpreters import ad -from jax.interpreters import batching -from jax.interpreters import partial_eval as pe from jax._src import ad_util from jax._src import core from jax._src import pretty_printer as pp +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import partial_eval as pe from jax._src.typing import Array from jax._src.state.types import 
(AbstractRef, ReadEffect, WriteEffect, AccumEffect) diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index ef6256255..169e2ebd4 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -19,11 +19,11 @@ from jax import tree_util from jax._src import linear_util as lu from jax.experimental import pjit from jax.errors import UnexpectedTracerError +from jax._src import mesh as mesh_lib from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir import ir -import jax.interpreters.pxla as pxla from jax._src.interpreters import mlir -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe from jax._src import custom_api_util from jax._src.lib import xla_client as xc from jax._src.api_util import flatten_fun_nokwargs @@ -370,7 +370,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, propagate_user_sharding, partition, infer_sharding_from_operands, static_args): - mesh = pxla.thread_resources.env.physical_mesh + mesh = mesh_lib.thread_resources.env.physical_mesh axis_context = ctx.module_context.axis_context if isinstance(axis_context, mlir.ShardingContext): diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 6e3aaa9be..64a1a88cc 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -511,9 +511,9 @@ from jax import custom_derivatives from jax._src import dtypes from jax import lax from jax.experimental import pjit -from jax.interpreters import ad, batching, pxla -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import ad, batching, pxla from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src import ad_checkpoint from jax._src import dispatch diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index c35f4b09f..bc50f767e 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -61,7 +61,6 @@ import numpy as np from jax import lax import jax.numpy as jnp from jax.experimental import pjit -from jax.interpreters import partial_eval as pe from jax.tree_util import (register_pytree_node, tree_structure, treedef_is_leaf, tree_flatten, tree_unflatten,) @@ -70,6 +69,7 @@ from jax._src import core from jax._src import dispatch from jax._src import linear_util as lu from jax._src.api_util import shaped_abstractify +from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.util import unzip2, weakref_lru_cache diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index e2eb80854..e51ec5fdb 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -43,10 +43,10 @@ from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, merge_lists) from jax.api_util import flatten_fun_nokwargs, shaped_abstractify -from jax.interpreters import batching +from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla from jax._src.interpreters import pxla from jax.interpreters import ad from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, diff --git a/jax/experimental/sparse/bcoo.py 
b/jax/experimental/sparse/bcoo.py index 3abfb06ae..039ca7dc0 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -36,7 +36,6 @@ from jax.experimental.sparse.util import ( SparseEfficiencyError, SparseEfficiencyWarning, Shape, SparseInfo) from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p -from jax.interpreters import partial_eval as pe from jax._src.interpreters import mlir import jax.numpy as jnp from jax.util import safe_zip, unzip2, split_list @@ -45,6 +44,7 @@ from jax._src import core from jax._src import dispatch from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.interpreters import partial_eval as pe from jax._src.lax.lax import ( _const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers) from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index d79455e9a..92a6ce03f 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -62,9 +62,9 @@ from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_spar import jax.numpy as jnp from jax._src.api_util import flatten_fun_nokwargs from jax._src.lib import pytree -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla -from jax.interpreters import pxla +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import xla +from jax._src.interpreters import pxla from jax.tree_util import tree_flatten, tree_map, tree_unflatten from jax.util import safe_map, safe_zip, split_list from jax._src.config import config diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 7170e2d33..6cb3d9ae4 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -11,2496 +11,91 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
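Everything below is the old module body being deleted; as the hunks in the files above show, internal callers switch to the relocated module. The import-path change, in brief:

    # before: public alias, now discouraged for internal use
    from jax.interpreters import partial_eval as pe
    # after: the module's new canonical location
    from jax._src.interpreters import partial_eval as pe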
-from __future__ import annotations -from collections import namedtuple -from contextlib import contextmanager, AbstractContextManager -import functools -from functools import partial -import inspect -import itertools as it -import operator as op -from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple, - List, Union, Hashable, Set) -from weakref import ref - -import numpy as np - -from jax._src import linear_util as lu -from jax.config import config -from jax._src import api_util -from jax._src import core -from jax._src import effects -from jax._src import dtypes -from jax._src import profiler -from jax._src import source_info_util -from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, - fun_sourceinfo) -from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, - AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - ConcreteArray, Var, DropVar, raise_to_shaped, Atom, - JaxprEqn, Primitive, ShapedArray, DShapedArray, - mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, - InputType, OutputType, get_referent) -from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten, - KeyPath, generate_key_paths, keystr) -from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, - merge_lists, partition_list, OrderedSet, - as_hashable_function, weakref_lru_cache) - - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip -def identity(x): return x - -TracerId = int -AvalId = int -ConstId = int - -def _update_annotation_known( - f: lu.WrappedFun, - orig_type: Optional[InputType], - in_knowns: List[bool] - ) -> lu.WrappedFun: - if orig_type is None: return f - # orig_type might contain DBIdx, but we're tossing out some args so we have to - # re-index. moreover some of the implicit args may not be needed anymore. - # so we basically just re-infer the lambda input type - if (all(e for _, e in orig_type) and - not any(type(d) is DBIdx for a, _ in orig_type for d in a.shape - if type(a) is DShapedArray)): - new_type = [ty for ty, known in zip(orig_type, in_knowns) if known] - return lu.annotate(f, tuple(new_type)) - - # Replace DBIdx with names, prune down to explicit only. - class Name: - def __init__(self, a): self.a = a - names = [Name(a) for a, _ in orig_type] - avals = [a.update(shape=tuple(names[d.val] if type(d) is DBIdx else d # type: ignore - for d in a.shape)) - if type(a) is DShapedArray else a for a, e in orig_type if e] - avals = [a for a, known in zip(avals, in_knowns) if known] - # Figure out the implicit part: names which aren't explicit and known. - expl_names = [o for o, (_, e) in zip(names, orig_type) if e] - expl_names = [o for o, k in zip(expl_names, in_knowns) if k] - expl_names_ = set(expl_names) - impl_names = {d for a in avals if type(a) is DShapedArray for d in a.shape - if type(d) is Name and d not in expl_names_} - impl_part = [(n.a, False) for n in impl_names] # type: ignore - # Figure out the explicit part: known explicit avals, replacing names w/ dbidx - name_map = {n: DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))} - expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape)) - if type(a) is DShapedArray else a, True) for a in avals] - return lu.annotate(f, (*impl_part, *expl_part)) - -class PartialVal(tuple): - """Partial value: either a known value or an unknown (abstract) value. 
-
-  Represented as a pair `(aval_opt, const)` of one of two kinds:
-  * `(None, <const>)` indicates a known value, where the constant is either a
-    Tracer or satisfies `core.valid_jaxtype(const)`;
-  * `(<aval>, None)` indicates an unknown value characterized by an
-    abstract value.
-  """
-  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
-    pv, const = xs
-    if config.jax_enable_checks:
-      # type checks
-      assert isinstance(pv, (AbstractValue, type(None))), xs
-      assert (const is None or isinstance(const, core.Tracer) or
-              core.valid_jaxtype(const)), const
-      # invariant checks
-      assert (pv is None) ^ (const is None)
-    return tuple.__new__(cls, xs)
-
-  @classmethod
-  def known(cls, const: core.Value) -> PartialVal:
-    return PartialVal((None, const))
-
-  @classmethod
-  def unknown(cls, aval: AbstractValue) -> PartialVal:
-    return PartialVal((aval, None))
-
-  def is_known(self) -> bool:
-    return self[0] is None
-
-  def get_known(self) -> Optional[core.Value]:
-    """Get the known value, if known, else None."""
-    return self[1] if self[0] is None else None
-
-  def get_aval(self) -> AbstractValue:
-    """Get AbstractValue directly (if unknown) or from the constant (known)."""
-    known = self.get_known()
-    if known is not None:
-      return get_aval(known)
-    else:
-      return self[0]
-
-
-class JaxprTrace(Trace['JaxprTracer']):
-
-  def __init__(self, *args, name_stack: source_info_util.NameStack):
-    super().__init__(*args)
-    self.name_stack = name_stack
-
-  def pure(self, val: Any) -> JaxprTracer:
-    return self.new_const(val)
-
-  def lift(self, val: Tracer) -> JaxprTracer:
-    return self.new_const(val)
-
-  def sublift(self, val: JaxprTracer) -> JaxprTracer:
-    return JaxprTracer(self, val.pval, FreeVar(val))
-
-  def new_const(self, val) -> JaxprTracer:
-    if isinstance(val, Tracer) and val._trace.level == self.level:
-      raise Exception
-    return JaxprTracer(self, PartialVal.known(val), None)
-
-  def new_instantiated_literal(self, val) -> JaxprTracer:
-    aval = get_aval(val)
-    return JaxprTracer(self, PartialVal.unknown(aval),
-                       Literal(val, raise_to_shaped(aval)))
-
-  def new_instantiated_const(self, val) -> JaxprTracer:
-    aval = get_aval(val)
-    if isinstance(aval, DShapedArray):
-      shape = [self.new_instantiated_const(d)
-               if isinstance(d, Tracer) and d._trace.level < self.level else d
-               for d in aval.shape]
-      aval = aval.update(shape=tuple(shape))
-    return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(val))
-
-  def new_arg(self, pval: PartialVal) -> JaxprTracer:
-    const = pval.get_known()
-    # XXX: Think twice before changing this constant argument pruning!
-    # This has really important consequences for partial_eval_jaxpr.
-    # Most importantly, this guarantees that the unknown jaxpr never uses
-    # known inputs (if it needs them, then they get passed through residuals).
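For reference before new_arg's body continues, here is the PartialVal API defined above in action; a minimal sketch using jax.core avals, with arbitrarily chosen values:

    import numpy as np
    from jax._src import core
    from jax._src.interpreters import partial_eval as pe

    known = pe.PartialVal.known(np.float32(2.0))
    unknown = pe.PartialVal.unknown(core.ShapedArray((), np.dtype('float32')))

    assert known.is_known() and known.get_known() == 2.0
    assert not unknown.is_known() and unknown.get_known() is None
    # get_aval works for both: from the constant if known, directly if unknown.
    print(known.get_aval(), unknown.get_aval())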
- if const is None: - aval = pval.get_aval() - if type(aval) is DShapedArray: - shape = [self.new_instantiated_const(d) - if isinstance(d, Tracer) and d._trace.level < self.level else d - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding()) - else: - return self.new_const(const) - - def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: - const = tracer.pval.get_known() - if const is None: - return tracer - else: - if type(const) in core.literalable_types and np.shape(const) == (): - return self.new_instantiated_literal(const) - else: - return self.new_instantiated_const(const) - - def instantiate_const_abstracted(self, tracer) -> JaxprTracer: - const = tracer.pval.get_known() - if const is None: - return tracer - else: - aval = raise_to_shaped(get_aval(const), np.isscalar(const)) - return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) - - def process_primitive(self, primitive, tracers, params): - if primitive in custom_partial_eval_rules: - return custom_partial_eval_rules[primitive](self, *tracers, **params) - else: - return self.default_process_primitive(primitive, tracers, params) - - def default_process_primitive(self, primitive, tracers, params): - # By default, if all the input tracers are known, then bind the primitive - # and consider all outputs known. Otherwise, stage the application into the - # jaxpr and consider all outputs unknown. - consts = [t.pval.get_known() for t in tracers] - if all(c is not None for c in consts): - return primitive.bind(*consts, **params) - tracers = map(self.instantiate_const, tracers) - avals = [t.aval for t in tracers] - out_aval, effects = primitive.abstract_eval(*avals, **params) - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - if primitive.multiple_results: - out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) - for aval in out_aval] - eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, source) - for t in out_tracers: t.recipe = eqn - return out_tracers - else: - out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None) - out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, - params, effects, source) - return out_tracer - - def process_call(self, primitive, f, tracers, params): - rule = call_partial_eval_rules.get(primitive) - if rule: - return rule(self, primitive, f, tracers, params) - - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) - # TODO(mattjj): check in_avals are consistent with f.in_type - - # We want to partially evaluate this call into two calls: one evaluated now - # taking known values (in_consts) as inputs and producing known values - # (out_consts) as outputs, and the other staged out as an eqn into the jaxpr - # being built. The latter takes as input residuals (res) produced as outputs - # of the first call, shared closed-over values (env), and explicit arguments - # which were unknown to the first call (corresponding to in_avals). - - # Wrap f to perform the partial evaluation and plumb out aux data. 
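Concretely, the split described in the comment above behaves like this hand-written pair of functions for a toy f with x known and y unknown (illustrative only; partial evaluation derives the pair automatically):

    def f(x, y):
      return x * y + x

    # Evaluated now, on the known input: produces known outputs (none here)
    # plus residuals the staged computation will need.
    def f_known(x):
      res = (x,)
      return (), res

    # Staged out as an eqn in the jaxpr: consumes residuals plus the
    # originally-unknown inputs.
    def f_staged(res, y):
      (x,) = res
      return x * y + x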
- if not config.jax_dynamic_shapes: - f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) - f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), - tuple(in_avals)) - else: - if f.in_type is None: - f = lu.annotate(f, tuple((a, True) for a in in_avals)) - f_, aux = trace_to_subjaxpr_nounits_dyn(f, self.main, tuple(in_knowns), - f.in_type, False) - # Adjust parameters (e.g. donated_invars) for the call to be evaluated now. - const_params = update_params(params, in_knowns, 0) - - # Run the call, getting known out vals and aux data used for staged-out call - out = primitive.bind(_update_annotation_known(f_, f.in_type, in_knowns), - *in_consts, **const_params) - fwds, out_knowns, out_type, jaxpr, env = aux() - # Split apart known outputs from the original call and non-fwded residuals. - out_consts, non_fwd_res_ = split_list(out, [sum(out_knowns)]) - - # Form the complete list of residuals by forwarding some inputs. - if config.jax_dynamic_shapes: - # With dynamic shapes, we may need to forward implicit arguments. - in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns) - in_consts_full = [None] * len(f.in_type) - for idx, (aval, explicit) in enumerate(f.in_type): - if explicit and next(in_knowns_): - c = in_consts_full[idx] = next(in_consts_) - if aval.shape: - for d1, d2 in zip(aval.shape, c.shape): - if type(d1) is DBIdx: - in_consts_full[d1.val] = d2 - else: - in_consts_full = in_consts - non_fwd_res = iter(non_fwd_res_) - res = [next(non_fwd_res) if i is None else in_consts_full[i] for i in fwds] - sentinel = object() - assert next(non_fwd_res, sentinel) is sentinel - - # Create the input tracers for the staged-out (unknown-value) call. - res_tracers = map(self.instantiate_const, map(self.new_const, res)) - env_tracers = map(self.full_raise, env) - unknown_arg_tracers = [t for t in tracers if not t.is_known()] - # Adjust parameters (e.g. donated_invars) for the staged-out call's args. - num_new_args = len(res_tracers) + len(env_tracers) - staged_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) - staged_params = update_params(staged_params, map(op.not_, in_knowns), - num_new_args) - # The outputs of the staged-out call are Tracers with the new eqn as recipe. - if config.jax_dynamic_shapes: - # With dynamic shapes, we may need to substitute Tracers into avals. - out_tracers = [] - for aval, _ in out_type: - assert not isinstance(aval, ConcreteArray) - if type(aval) is DShapedArray: - shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val] - if type(d) is InDBIdx else d for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - out_tracers.append(JaxprTracer(self, PartialVal.unknown(aval), None)) - else: - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_type] - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers), - out_tracers, primitive, staged_params, jaxpr.effects, - source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - def process_map(self, primitive, f: lu.WrappedFun, tracers, params): - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) - - # This method is like process_call above, except: - # 1. 
we delete an axis from mapped-over input avals' shapes, and
-    #    analogously add an axis to mapped-over output avals' shapes;
-    # 2. we update the in_axes and out_axes/out_axes_thunk parameters to
-    #    reflect the inputs and outputs pruned from the unknown/known sides.
-
-    # Map (delete an axis from) unknown inputs' avals as dictated by in_axes.
-    unk_in_axes, const_in_axes = partition_list(in_knowns, params['in_axes'])
-    in_avals_mapped = [mapped_aval(params['axis_size'], ax, aval)
-                       for ax, aval in zip(unk_in_axes, in_avals)]
-
-    # Wrap f to perform partial evaluation and plumb out aux data.
-    f = trace_to_subjaxpr_nounits(f, self.main, False)
-    f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns),
-                                          tuple(in_avals_mapped))
-    # Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk)
-    const_params = update_params(params, in_knowns, 0)  # handles donated_invars
-    out_axes_thunk = params['out_axes_thunk']
-    @as_hashable_function(closure=out_axes_thunk)
-    def const_out_axes_thunk():
-      out_knowns, _, jaxpr, _ = aux()
-      _, out_axes = partition_list(out_knowns, out_axes_thunk())
-      return tuple(out_axes) + (0,) * len(jaxpr.constvars)  # res mapped axis 0
-    const_params = dict(const_params, in_axes=tuple(const_in_axes),
-                        out_axes_thunk=const_out_axes_thunk)
-
-    # Run the map, getting known out vals and aux data used for staged-out map.
-    out = primitive.bind(f, *in_consts, **const_params)
-    out_knowns, out_avals_mapped, jaxpr, env = aux()
-    # Split apart known outputs from the original call and residuals.
-    out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
-
-    # We can only check_jaxpr with the dynamic axis environment extended:
-    with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
-      call_jaxpr = convert_constvars_jaxpr(jaxpr)
-
-    # Compute staged and const out_axes, taking into account residuals.
-    out_axes = params['out_axes_thunk']()
-    staged_out_axes, _ = partition_list(out_knowns, out_axes)
-    staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,)
-
-    # Create the input tracers for the staged-out (unknown-value) call.
-    const_tracers = map(self.new_instantiated_const, res)
-    env_tracers = map(self.full_raise, env)
-    unknown_arg_tracers = [t for t in tracers if not t.is_known()]
-    # Adjust params for staged-out call on unknown values.
-    num_new_args = len(const_tracers) + len(env_tracers)
-    staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
-    staged_params = dict(staged_params, in_axes=staged_in_axes,
-                         out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)
-    del staged_params['out_axes_thunk']
-    # The outputs of the staged-out call are Tracers with the new eqn as recipe.
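For intuition on the mapped_aval/unmapped_aval pair used in this method: mapping deletes the mapped axis from an aval, and unmapping reinstates it. A small sketch, assuming axis size 8, mapped axis 0, and an arbitrary axis name 'i':

    import numpy as np
    from jax._src import core

    aval = core.ShapedArray((8, 3), np.dtype('float32'))
    per_elt = core.mapped_aval(8, 0, aval)              # f32[3]
    restored = core.unmapped_aval(8, 'i', 0, per_elt)   # f32[8,3]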
- out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a) - for ax, a in zip(staged_out_axes, out_avals_mapped)] - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_avals] - eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type] - out_tracers, primitive, staged_params, - jaxpr.effects, - source_info_util.current()) - for t in out_tracers: t.recipe = eqn - - return merge_lists(out_knowns, out_tracers, out_consts) - - def post_process_call(self, primitive, out_tracers, params): - unknown_out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) - out_pvals = [t.pval for t in out_tracers] - out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - out = [*out_consts, *res] - main = self.main - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) - const_tracers = map(trace.new_instantiated_const, res) - in_tracers = (*const_tracers, *map(trace.full_raise, env)) - out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) - for a in out_avals] - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - new_params = update_params(params, [], len(in_tracers)) - new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, - jaxpr.effects, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - return out, todo - - def post_process_map(self, primitive, out_tracers, params): - unknown_out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) - out_pvals = [t.pval for t in out_tracers] - out_knowns, out_avals_mapped, out_consts = partition_pvals(out_pvals) - out = [*out_consts, *res] - main = self.main - - with core.extend_axis_env(params['axis_name'], params['axis_size'], None): - call_jaxpr = convert_constvars_jaxpr(jaxpr) - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) - const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.full_raise, env) - - staged_out_axes = tuple(out_axes_unknown) # set by out_axes_transform - staged_in_axes = (0,) * len(res) + (None,) * len(env) - - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - staged_params = update_params(params, [], len(res) + len(env)) - staged_params = dict(staged_params, in_axes=staged_in_axes, - out_axes=tuple(staged_out_axes), - call_jaxpr=call_jaxpr) - - out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], d, a) - for d, a in zip(staged_out_axes, out_avals_mapped)] - out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) - for a in out_avals] - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, - primitive, staged_params, jaxpr.effects, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - def out_axes_transform(out_axes): - nonlocal out_axes_unknown - out_axes_unknown, out_axes_known = partition_list(out_knowns, out_axes) - return 
tuple(out_axes_known) + (0,) * len(jaxpr.constvars) - out_axes_unknown: Optional[list] = None - - return out, (todo, out_axes_transform) - - def _current_truncated_name_stack(self): - return source_info_util.current_name_stack()[len(self.name_stack):] - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - # We assume partial evaluation is only performed to build linear functions, - # and hence we don't need to keep the custom JVP rule around anymore. - del jvp, symbolic_zeros - assert not all(t.is_known() for t in tracers) - return fun.call_wrapped(*tracers) - - def post_process_custom_jvp_call(self, out_tracers, _): - # This path should only be reachable if we expose a partial eval API - # unrelated to autodiff, since we raise an error when differentiation with - # respect to values over which a custom_jvp function closes is detected. - raise NotImplementedError # TODO(mattjj) - - def process_custom_transpose(self, prim, call, tracers, **params): - res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves]) - assert all(t.is_known() for t in res_ts) - lin_all_known = all(t.is_known() for t in lin_ts) - if lin_all_known: - res_cvals = [t.pval[1] for t in res_ts] - lin_cvals = [t.pval[1] for t in lin_ts] - return prim.bind(call, *res_cvals, *lin_cvals, **params) - else: - out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) - for aval in params['out_types']] - in_tracers = map(self.instantiate_const, tracers) - new_params = dict(params, call=call) - eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params, - core.no_effects, source_info_util.current()) - for t in out_tracers: t.recipe = eqn - return out_tracers - - def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees): - # TODO(mattjj): after old remat is deleted, make this method trivial. - # Because we instantiate all tracers, in_knowns is all False. 
- tracers = map(self.instantiate_const_abstracted, tracers) - in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) - f = trace_to_subjaxpr_nounits(f, self.main, True) - f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees) - out_knowns, out_avals, jaxpr, env = aux() - out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - res_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.full_raise, env) - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_avals] - closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) - - @_memoize - def fwd_jaxpr_thunk(): - fwd_ = trace_to_subjaxpr_nounits(fwd, self.main, True) - fwd_, aux = partial_eval_wrapper_nounits( - fwd_, tuple(in_knowns), tuple(in_avals)) - with core.new_sublevel(): - out_flat = fwd_.call_wrapped() - out_knowns, out_avals, jaxpr, env = aux() - _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) - return converted_jaxpr, (*res, *env) - - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers), - out_tracers, prim.initial_style, - dict(fun_jaxpr=closed_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_thunk, - num_consts=len(res) + len(env), - bwd=bwd, out_trees=out_trees), - jaxpr.effects, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - def post_process_custom_vjp_call(self, out_tracers, _): - # This path should only be reachable if we expose a partial eval API - # unrelated to autodiff, since we raise an error when differentiation with - # respect to values over which a custom_vjp function closes is detected. - raise NotImplementedError # TODO(mattjj) - -def partition_pvals( - pvals: List[PartialVal] - ) -> Tuple[List[bool], List[AbstractValue], List[Any]]: - knowns = [pval.is_known() for pval in pvals ] - avals = [pval.get_aval() for pval in pvals if not pval.is_known()] - consts = [pval.get_known() for pval in pvals if pval.is_known()] - return knowns, avals, consts - -@lu.transformation_with_aux -def partial_eval_wrapper_nounits( - in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], - *in_consts: Any): - in_avals_, in_consts_ = iter(in_avals), iter(in_consts) - in_pvals = [PartialVal.known(next(in_consts_)) if known else - PartialVal.unknown(next(in_avals_)) for known in in_knowns] - sentinel = object() - assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel - jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {} - out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) - -@lu.transformation_with_aux -def trace_to_subjaxpr_nounits_dyn( - main: core.MainTrace, in_knowns: Sequence[bool], in_type: InputType, - instantiate: Union[bool, Sequence[bool]], - *in_consts: Any): - trace = main.with_cur_sublevel() - in_avals, which_explicit = unzip2(in_type) - - # To form input tracers from in_type, we need to first build ConstVar tracers - # for all axis sizes, so that we can then use those tracers in the shapes of - # avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly - # type agreement checking via get_referent. 
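A note on the DBIdx bookkeeping this function (and process_call above) relies on: within an InputType, a DShapedArray's shape may name other inputs by position via DBIdx. A sketch of the in_type for a hypothetical f(n, xs) where xs has shape (n,):

    import numpy as np
    from jax._src import core

    n_aval = core.ShapedArray((), np.dtype('int32'))
    xs_aval = core.DShapedArray((core.DBIdx(0),), np.dtype('float32'), False)
    # Each entry pairs an aval with an "explicit" flag; n can stay implicit.
    in_type = ((n_aval, False), (xs_aval, True))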
- in_consts_full: List[Optional[JaxprTracer]] = [None] * len(in_type) - in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns) - for idx, (aval, explicit) in enumerate(in_type): - if explicit and next(in_knowns_iter): - constval = next(in_consts_iter) - if isinstance(aval, DShapedArray): - for i, d in enumerate(aval.shape): - if isinstance(d, DBIdx): - if in_consts_full[d.val] is None: - in_consts_full[d.val] = \ - JaxprTracer(trace, PartialVal.unknown(in_avals[d.val]), - ConstVar(constval.shape[i])) - assert core.same_referent(constval.shape[i], in_consts_full[d.val]) - shape = [in_consts_full[d.val] if type(d) is DBIdx else d # type: ignore - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - in_consts_full[idx] = JaxprTracer(trace, PartialVal.unknown(aval), - ConstVar(constval)) - # Check that we covered all axis sizes with ConstVar tracers. - for idx, (aval, explicit) in enumerate(in_type): - if not explicit: assert in_consts_full[idx] is not None - if isinstance(aval, DShapedArray): - assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None # type: ignore - for d in aval.shape) - - # Next, build tracers for all unknown inputs, using the in_consts_full list - # for axis size tracers when necessary. - in_tracers = [] - in_knowns_iter = iter(in_knowns) - for aval, explicit in in_type: - if explicit and not next(in_knowns_iter): - if isinstance(aval, DShapedArray): - shape = [in_consts_full[d.val] if type(d) is DBIdx else d # type: ignore - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - tracer = JaxprTracer(trace, PartialVal.unknown(aval), LambdaBinding()) - in_tracers.append(tracer) - - # Merge in_consts and in_tracers and call wrapped fn with explicit arguments. - in_args = merge_lists(in_knowns, in_tracers, in_consts) - ans = yield in_args, {} - - # Instantiate outputs and build jaxpr. - if isinstance(instantiate, bool): - instantiate = [instantiate] * len(ans) - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t - for inst, t in zip(instantiate, out_tracers)] - - # Collect known outputs. - out_knowns: List[bool] = [t.is_known() for t in out_tracers] - out_consts: List[Any] = [t.pval.get_known() for t in out_tracers - if t.is_known()] - - # Build the jaxpr. - out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr(in_tracers, out_tracers) - out_avals = [v.aval for v in jaxpr.outvars] - idx_map = {v: InDBIdx(i) - for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))} - out_type = [(a.update(shape=tuple(idx_map.get(d, d) for d in a.shape)) # type: ignore - if type(a) is DShapedArray else a, True) for a in out_avals] - - # Which residuals are just forwarded inputs? Check obj id, then prune. 
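The identity-based forwarding check mentioned above (and used again in trace_to_subjaxpr_nounits_fwd further down) reduces to this plain-Python pattern:

    a, b, c, d = object(), object(), object(), object()
    in_consts = [a, b, c]      # known inputs
    res = [b, d]               # residuals produced while tracing
    id_map = {id(x): i for i, x in enumerate(in_consts)}
    fwds = [id_map.get(id(r)) for r in res]   # [1, None]: res[0] is input 1
    pruned = [r for r, f in zip(res, fwds) if f is None]   # keep only d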
- id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore - if c is not None} - fwds: List[Optional[int]] = [id_map.get(id(c)) for c in res] - res = tuple([c for c, fwd in zip(res, fwds) if fwd is None]) - - del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \ - in_tracers, in_args, ans, out_tracers, out_avals - yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env) - - -custom_partial_eval_rules: Dict[Primitive, Callable] = {} -call_partial_eval_rules: Dict[Primitive, Callable] = {} -call_param_updaters: Dict[Primitive, Callable] = {} - -def _closed_call_param_updater(params, _, __): - jaxpr = params.get('call_jaxpr') - if jaxpr is None: return params - assert type(jaxpr) is core.Jaxpr - return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ())) -call_param_updaters[core.closed_call_p] = _closed_call_param_updater - -def abstract_eval_fun(fun, *avals, debug_info=None, **params): - _, avals_out, _ = trace_to_jaxpr_dynamic( - lu.wrap_init(fun, params), avals, debug_info) - assert all(isinstance(aval, AbstractValue) for aval in avals_out) - return avals_out - - -JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar', - 'ConstVar', Literal] - -class JaxprTracer(Tracer): - __slots__ = ['pval', 'recipe'] - - def __init__(self, trace: JaxprTrace, pval: PartialVal, - recipe: Optional[JaxprTracerRecipe]): - assert isinstance(pval, PartialVal) - pv, const = pval - if isinstance(const, Tracer) and const._trace.level >= trace.level: - raise core.escaped_tracer_error( - const, f"Tracer from a higher level: {const} in trace {trace}") - if isinstance(pv, DShapedArray): - assert all(not isinstance(d, Tracer) or isinstance(d, JaxprTracer) and - d._trace.level == trace.level for d in pv.shape) - self._trace = trace - self.pval = pval - self.recipe = recipe - - def __repr__(self): - return f'Traced<{self.aval}:{self._trace}>' - - @property - def aval(self) -> AbstractValue: - return self.pval.get_aval() - - @property - def parents(self) -> Sequence[JaxprTracer]: - if isinstance(self.recipe, JaxprEqnRecipe): - # TODO broadcast_in_dim can create a new tracer... - return self.recipe.in_tracers - elif isinstance(self.aval, DShapedArray): - return [d for d in self.aval.shape if isinstance(d, JaxprTracer)] - else: - return [] - - def full_lower(self): - known = self.pval.get_known() - if known is not None: - return core.full_lower(known) - else: - return self - - def is_known(self): - return self.pval.is_known() - - def get_referent(self): - if self.pval.is_known(): - return get_referent(self.pval.get_known()) - elif isinstance(self.recipe, (FreeVar, ConstVar, Literal)): - return get_referent(self.recipe.val) # pytype: disable=attribute-error - else: - return self - - -@profiler.annotate_function -def trace_to_jaxpr( - fun: lu.WrappedFun, pvals: Sequence[PartialVal], - instantiate: Union[bool, Sequence[bool]] = False, - ) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]: - """ - Partially evaluate a function, building a jaxpr for un-evaluated computation. - - Args: - fun: lu.WrappedFun representing the function to be partially evaluated. The - function must be flattened, in the sense of accepting jaxpr type arguments - and returning a flat list of jaxpr type outputs. - pvals: sequence of PartialVals of length equal to the number of inputs to - `fun` indicating which inputs are known or unknown. 
- instantiate: optional bool or sequence of bools of length equal to the - number of outputs of `fun` indicating which outputs should be forced to be - treated as unknown and hence instantiated in the jaxpr. If a single bool, - the value is applied to all outputs. Default False. - - Returns: - A triple where the first element is a jaxpr representing the computation - which depends on unknown inputs; the second element is a list of PartialVals - of length equal to the length of the output of `fun` representing which - outputs are known and unknown (along with their values and abstract values, - respectively); the third element is a list of known residual values. The - returned jaxpr takes as inputs the known residual values followed by values - of the originally unknown inputs. - """ - current_name_stack = source_info_util.current_name_stack() - with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: - fun = trace_to_subjaxpr(fun, main, instantiate) - jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) - assert not env - del main, fun, env - - return jaxpr, out_pvals, consts - -@profiler.annotate_function -def trace_to_jaxpr_nounits( - fun: lu.WrappedFun, pvals: Sequence[PartialVal], - instantiate: Union[bool, Sequence[bool]] = False, - ) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]: - current_name_stack = source_info_util.current_name_stack() - with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: - fun = trace_to_subjaxpr_nounits(fun, main, instantiate) - jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) - assert not env - del main, fun, env - return jaxpr, out_pvals, consts - - -@lu.transformation -def trace_to_subjaxpr_nounits( - main: core.MainTrace, - instantiate: Union[bool, Sequence[bool]], - in_pvals: Sequence[PartialVal]): - assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - main, instantiate, in_pvals) - out_pvals = [t.pval for t in out_tracers] - del out_tracers - yield jaxpr, (out_pvals, out_consts, env) - -def _trace_to_subjaxpr_nounits(main, instantiate, in_pvals): - trace = main.with_cur_sublevel() - in_knowns = [pval.is_known() for pval in in_pvals] - in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] - in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] - in_args = merge_lists(in_knowns, in_tracers, in_consts) - ans = yield in_args, {} - assert isinstance(ans, (list, tuple)), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - if isinstance(instantiate, bool): - instantiate = [instantiate] * len(ans) - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t - for inst, t in zip(instantiate, out_tracers)] - out_tracers_ = [t for t in out_tracers if not t.is_known()] - jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_) - return out_tracers, jaxpr, out_consts, env - -# The below variant implements an optimization where residuals which are also -# inputs are indicated in auxiliary data rather than passed as outputs. -# TODO(mattjj): update all callers to use this version, delete other version. 
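-# A caller-side sketch (hypothetical variable names) of consuming the
-# auxiliary `fwds` data yielded below: the full residual list can be rebuilt
-# from the pruned constants plus the caller's known input constants, e.g.
-#
-#   pruned = iter(pruned_consts)
-#   residuals = [in_consts[f] if f is not None else next(pruned) for f in fwds]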
-@lu.transformation
-def trace_to_subjaxpr_nounits_fwd(
-    main: core.MainTrace,
-    instantiate: Union[bool, Sequence[bool]],
-    in_pvals: Sequence[PartialVal]):
-  assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
-  out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
-      main, instantiate, in_pvals)
-  out_pvals = [t.pval for t in out_tracers]
-
-  # Which out_consts (aka residuals) are just forwarded inputs? Check obj id.
-  in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
-  id_map = {id(c): i for i, c in enumerate(in_consts)}
-  fwds: List[Optional[int]] = [id_map.get(id(c)) for c in out_consts]
-  pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None]
-
-  del out_tracers
-  yield jaxpr, (fwds, out_pvals, pruned_consts, env)
-
-
-FreeVar = namedtuple('FreeVar', ['val'])
-ConstVar = namedtuple('ConstVar', ['val'])
-LambdaBinding = namedtuple('LambdaBinding', [])
-class JaxprEqnRecipe(NamedTuple):
-  eqn_id: Any
-  in_tracers: Sequence[JaxprTracer]
-  out_tracer_refs: Sequence[ref[JaxprTracer]]
-  out_avals: Sequence[core.AbstractValue]
-  primitive: Primitive
-  params: Dict[str, Any]
-  effects: core.Effects
-  source_info: source_info_util.SourceInfo
-
-def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
-                   out_tracers: Sequence[JaxprTracer],
-                   primitive: Primitive,
-                   params: Dict[str, Any],
-                   effects: core.Effects,
-                   source_info: source_info_util.SourceInfo
-                  ) -> JaxprEqnRecipe:
-  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
-  if primitive.call_primitive or primitive.map_primitive:
-    assert "call_jaxpr" in params
-    assert ("donated_invars" not in params or
-            len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
-  if primitive.map_primitive:
-    assert ("in_axes" in params and
-            len(params["in_axes"]) == len(params["call_jaxpr"].invars))
-    assert ("donated_invars" in params and
-            len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
-  out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers]
-  return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
-                        out_avals, primitive, params, effects, source_info)
-
-
-def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
-                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
-  (_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src) = recipe
-  invars = [getvar(t) for t in in_tracers]
-  out_tracers = [t_ref() for t_ref in out_tracer_refs]
-  outvars = [DropVar(a) if t is None else getvar(t)  # type: ignore
-             for a, t in zip(out_avals, out_tracers)]
-  return new_jaxpr_eqn(invars, outvars, prim, params, eff, src)
-
-def tracers_to_jaxpr(
-  in_tracers: Sequence[JaxprTracer],
-  out_tracers: Sequence[JaxprTracer]
-  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
-  """Constructs Jaxpr given tracers for inputs and outputs.
-
-  Params:
-    in_tracers: the tracers that were created for the function inputs
-    out_tracers: the tracers that were output by the function.
-
-  Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
-    the `constvars` in the returned Jaxpr, and a list of environment values.
-    The vars for the environment values have been prepended to the Jaxpr's
-    `invars`.
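-
-  As an illustrative sketch: for input tracers ``a, b`` and an output tracer
-  recording ``mul(sin(a), b)``, this returns roughly the jaxpr
-  ``{ lambda ; a b. let c = sin a; d = mul c b in (d,) }`` together with
-  empty constant and environment lists.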
- """ - gensym = core.gensym() - - t_to_var: Dict[TracerId, Var] = {} - consts: Dict[Var, Any] = {} - env: Dict[Var, JaxprTracer] = {} - constid_to_var: Dict[ConstId, Var] = {} # for deduplication - - def get_atom(t: JaxprTracer) -> Atom: - return t.recipe if type(t.recipe) is Literal else t_to_var[id(t)] - - def newvar(t: Optional[JaxprTracer]) -> Var: - assert t is not None - var = gensym(type_substitute(t.aval)) - var_ = t_to_var.setdefault(id(t), var) - assert var is var_ - return var - - def type_substitute(aval: AbstractValue) -> AbstractValue: - if isinstance(aval, DShapedArray): - # Replace any Tracers in aval.shape with Vars or Literal values - shape = [get_atom(d) if type(d) is JaxprTracer else d for d in aval.shape] - shape = [d.val if type(d) is Literal else d for d in shape] - aval = aval.update(shape=tuple(shape)) - return aval - - processed_eqn_ids = set() - eqns: List[core.JaxprEqn] = [] - for t in toposort([*in_tracers, *out_tracers]): - r = t.recipe - if isinstance(r, JaxprEqnRecipe): - # TODO broadcast_in_dim can create a new tracer, not present in parents - if r.eqn_id not in processed_eqn_ids: - in_atoms = map(get_atom, r.in_tracers) - outvars = [DropVar(type_substitute(a)) if rf() is None else newvar(rf()) - for a, rf in zip(r.out_avals, r.out_tracer_refs)] - eqns.append(new_jaxpr_eqn(in_atoms, outvars, r.primitive, r.params, - r.effects, r.source_info)) - processed_eqn_ids.add(r.eqn_id) - elif isinstance(r, LambdaBinding): - if not any(t is in_tracer for in_tracer in in_tracers): - raise core.escaped_tracer_error(t, f"Tracer not in input tracers: {t}") - newvar(t) - elif isinstance(r, ConstVar): - var = constid_to_var.get(id(r.val)) - if var is None: - var = constid_to_var[id(r.val)] = newvar(t) - consts[var] = r.val - t_to_var[id(t)] = var - elif isinstance(r, FreeVar): - env[newvar(t)] = r.val # type: ignore - elif isinstance(r, Literal): - pass - elif r is None: - assert False - else: - raise TypeError(r) - - env_vars, env_vals = unzip2(env.items()) - invars = [*env_vars, *map(get_atom, in_tracers)] - const_vars, const_vals = unzip2(consts.items()) - outvars = map(get_atom, out_tracers) # type: ignore[arg-type] - jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns) - jaxpr = Jaxpr(const_vars, invars, # type: ignore[list-item,arg-type] - outvars, eqns, jaxpr_effects) - config.jax_enable_checks and core.check_jaxpr(jaxpr) - # del getvar # needed to avoid cyclic-reference closure, apparently! - return jaxpr, const_vals, env_vals - -@weakref_lru_cache -def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: - """Moves the constvars to the start of invars.""" - config.jax_enable_checks and core.check_jaxpr(jaxpr) - dbg = jaxpr.debug_info and jaxpr.debug_info._replace( - arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) - lifted_jaxpr = Jaxpr(constvars=(), - invars=jaxpr.constvars + jaxpr.invars, - outvars=jaxpr.outvars, eqns=jaxpr.eqns, - effects=jaxpr.effects, debug_info=dbg) - config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr) - return lifted_jaxpr - -@weakref_lru_cache -def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr: - """Move n invars to constvars. 
-  Like an inverse of convert_constvars_jaxpr."""
-  if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
-    raise NotImplementedError
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
-  constvars, invars = split_list(jaxpr.invars, [n])
-  dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
-      arg_names=jaxpr.debug_info.arg_names[n:])
-  lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
-                               debug_info=dbg)
-  config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
-  return lifted_jaxpr
-
-def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
-  if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
-    raise NotImplementedError
-  config.jax_enable_checks and core.check_jaxpr(jaxpr)
-  env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
-  converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
-                          invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
-                          effects=jaxpr.effects)
-  config.jax_enable_checks and core.check_jaxpr(converted_jaxpr)
-  return converted_jaxpr
-
-
-def partial_eval_jaxpr_nounits(
-    jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
-    instantiate: Union[bool, Sequence[bool]],
-  ) -> Tuple[ClosedJaxpr, ClosedJaxpr, List[bool], List[AbstractValue]]:
-  """Unzip a jaxpr in two by data dependence into 'known' and 'unknown' parts.
-
-  That is, given a jaxpr and a sequence of booleans indicating which jaxpr
-  inputs (i.e. invars) are considered unknown, produce two jaxprs, a list of
-  booleans representing which of the original jaxpr's outputs are unknown (i.e.
-  have a data dependence on an unknown input), and a list of abstract values
-  representing residuals (part of the first jaxpr's output and the second
-  jaxpr's input). The two jaxprs result from partitioning the original jaxpr's
-  first-order primitive applications based on whether all the inputs to the
-  application are known (in which case the application is represented in the
-  'known' jaxpr and its result is considered known) or whether any inputs to the
-  application are unknown (in which case the application is represented in the
-  'unknown' jaxpr and its result is considered unknown). Higher-order primitives
-  are recursively unzipped in two.
-
-  The `instantiate` argument can be used to ensure some outputs are lifted into
-  the 'unknown' jaxpr.
-
-  For example, given an input jaxpr:
-
-    { lambda ; a:f32[] b:f32[]. let
-        c:f32[] = cos a
-        d:f32[] = sin a
-        e:f32[] = neg d
-        f:f32[] = mul e b
-      in (c, f) }
-
-  then applying this function with `unknowns=[False, True]` and
-  `instantiate=False` produces as an output triple:
-
-    # jaxpr_known
-    { lambda ; a:f32[]. let
-        b:f32[] = cos a
-        c:f32[] = sin a
-        d:f32[] = neg c
-      in (b, d) }
-
-    # jaxpr_unknown
-    { lambda ; a:f32[] b:f32[]. let c:f32[] = mul b a in (c,) }
-
-    # out_unknowns
-    [False, True]
-
-  Notice in particular that the first output (jaxpr_known) contains all the
-  primitive applications which do not have a data dependence on an unknown
-  input. Also notice the input and output types: the input type of the first
-  jaxpr produced represents the type of the known inputs of the original jaxpr,
-  and the output type of the second jaxpr produced represents the type of the
-  unknown outputs of the original jaxpr.
-
-  In the above example, the output of jaxpr_known named `d` is a _residual_
-  output, and corresponds to the input named `a` in jaxpr_unknown.
-  In general, jaxpr_known will produce extra outputs (at the end of its output
-  list) corresponding to intermediate values of the original jaxpr which must
-  be passed to jaxpr_unknown (as leading inputs).
-  """
-  instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
-  return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate)
-
-@weakref_lru_cache
-def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
-  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
-
-  cell = []
-  def fun(*known_vals_in):
-    known_vals_in = iter(known_vals_in)
-    unknown_avals = (a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk)
-    in_pvals = [PartialVal.unknown(next(unknown_avals)) if uk
-                else PartialVal.known(next(known_vals_in)) for uk in in_unknowns]
-    assert next(known_vals_in, None) is next(unknown_avals, None) is None
-    jaxpr_unknown_, out_pvals, residuals = trace_to_jaxpr_nounits(
-        f, in_pvals, instantiate=instantiate)
-    jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
-    out_unknowns = [not pval.is_known() for pval in out_pvals]
-    res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals]
-    cell.append((out_unknowns, jaxpr_unknown, res_avals))
-    known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()]
-    return [*known_vals_out, *residuals]
-
-  known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
-  jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
-  (out_unknowns, jaxpr_unknown, res_avals), = cell  # pytype: disable=bad-unpacking
-
-  # check jaxpr_known and jaxpr_unknown in isolation
-  # TODO(mattjj): enable weak type checking here
-  if config.jax_enable_checks:
-    core.check_jaxpr(jaxpr_known)
-    core.check_jaxpr(jaxpr_unknown)
-  # check jaxpr_known has input type corresponding to known inputs of jaxpr
-  assert ([v.aval for v in jaxpr_known.invars] ==
-          [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk])
-  # check jaxpr_known has out type corresponding to known outs of jaxpr plus res
-  assert ([v.aval.strip_weak_type() for v in jaxpr_known.outvars] ==
-          [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns)
-           if not uk] + [a.strip_weak_type() for a in res_avals])
-  # check jaxpr_unknown has input type corresponding to res plus unknown inputs
-  assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.invars] ==
-          [a.strip_weak_type() for a in res_avals] +
-          [a.strip_weak_type() for a, uk in zip(jaxpr.in_avals, in_unknowns)
-           if uk])
-  # check jaxpr_unknown has output type corresponding to unknown outputs
-  assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.outvars] ==
-          [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns)
-           if uk])
-
-  closed_jaxpr_known = ClosedJaxpr(jaxpr_known, consts_known)
-  closed_jaxpr_unknown = ClosedJaxpr(jaxpr_unknown, ())
-  return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals
-
-
-def partial_eval_jaxpr_custom(
-    jaxpr: Jaxpr,
-    in_unknowns: Sequence[bool],
-    in_inst: Union[bool, Sequence[bool]],
-    ensure_out_unknowns: Union[bool, Sequence[bool]],
-    ensure_out_inst: Union[bool, Sequence[bool]],
-    saveable: Callable[..., bool],
-  ) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
-  if type(in_inst) is bool:
-    in_inst = (in_inst,) * len(jaxpr.invars)
-  if type(ensure_out_unknowns) is bool:
-    ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
-  if type(ensure_out_inst) is bool:
-    ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
-  return _partial_eval_jaxpr_custom_cached(
-      jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
-      tuple(ensure_out_inst), saveable)
-
-@weakref_lru_cache
-def _partial_eval_jaxpr_custom_cached(
-    jaxpr: Jaxpr,
-    in_unknowns: Tuple[bool, ...],
-    in_inst: Tuple[bool, ...],
-    ensure_out_unknowns: Tuple[bool, ...],
-    ensure_out_inst: Tuple[bool, ...],
-    saveable: Callable[..., bool],
-  ) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
-  env: Dict[Var, Tuple[bool, bool]] = {}
-  residuals: OrderedSet[Var] = OrderedSet()
-
-  def read(x: Atom) -> Tuple[bool, bool]:
-    if type(x) is Var:
-      return env[x]
-    return (False, True)
-
-  def write(unk: bool, inst: bool, v: Var) -> None:
-    assert (unk, inst) != (True, False)
-    env[v] = (unk, inst)
-
-  def ensure_instantiated(inst: bool, x: Atom) -> Atom:
-    if type(x) is Var and not inst:
-      residuals.add(x)
-    return x
-
-  known_eqns, staged_eqns = [], []
-  map(write, in_unknowns, in_inst, jaxpr.invars)
-  map(partial(write, False, True), jaxpr.constvars)
-  for eqn in jaxpr.eqns:
-    unks_in, inst_in = unzip2(map(read, eqn.invars))
-    rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
-    if rule:
-      eqn1, eqn2, unks_out, inst_out, res = rule(saveable, unks_in, inst_in, eqn)
-      eqn1 and known_eqns.append(eqn1); eqn2 and staged_eqns.append(eqn2)  # type: ignore
-      residuals.update(res)
-      map(write, unks_out, inst_out, eqn.outvars)
-    elif any(unks_in):
-      inputs = map(ensure_instantiated, inst_in, eqn.invars)
-      staged_eqns.append(eqn.replace(invars=inputs))
-      map(partial(write, True, True), eqn.outvars)
-    else:
-      known_eqns.append(eqn)
-      # If it's an effectful primitive, we always run it and avoid staging it.
-      if eqn.effects or saveable(
-          eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params):
-        map(partial(write, False, False), eqn.outvars)
-      else:
-        inputs = map(ensure_instantiated, inst_in, eqn.invars)
-        staged_eqns.append(eqn.replace(invars=inputs))
-        map(partial(write, False, True), eqn.outvars)
-  unzipped = unzip2(map(read, jaxpr.outvars))
-  out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1])
-  assert all(type(v) is Var for v in residuals), residuals
-
-  for x, inst, ensure_inst in zip(jaxpr.outvars, out_inst, ensure_out_inst):
-    if ensure_inst: ensure_instantiated(inst, x)
-  out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns)
-  out_inst = map(op.or_, out_inst, ensure_out_inst)
-
-  ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
-  outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
-  known_outvars = [*outs_known, *residuals]
-  known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known, known_outvars,
-                                     known_eqns)
-  jaxpr_known = Jaxpr(jaxpr.constvars, ins_known, known_outvars,
-                      known_eqns, known_effects)
-  config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
-
-  _, ins_staged = partition_list(in_inst, jaxpr.invars)
-  _, outs_staged = partition_list(out_inst, jaxpr.outvars)
-  staged_invars = [*residuals, *ins_staged]
-  staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars,
-                                      outs_staged, staged_eqns)
-  jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
-                       outs_staged, staged_eqns, staged_effects)
-  config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
-
-  return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
-
-# A primitive rule for policy-driven partial evaluation returns a 5-tuple
-# with the components representing, respectively:
-#  * the JaxprEqn for the 'known' side (or
None if there is no known component), -# * the JaxprEqn for the 'unknown' side (or None), -# * a list of booleans indicating which of the original outputs are unknown, -# * a list of booleans indicating which of the original outputs are -# instantiated (i.e. available) in the 'unknown' side, -# * a list of Var instances representing residuals to be added (i.e. to be -# plumbed as outputs of the 'known' side jaxpr and added as input binders to -# the 'unknown' jaxpr). -PartialEvalCustomResult = Tuple[Optional[JaxprEqn], Optional[JaxprEqn], - Sequence[bool], Sequence[bool], List[Var]] -PartialEvalCustomRule = Callable[ - [Callable[..., bool], Sequence[bool], Sequence[bool], JaxprEqn], - PartialEvalCustomResult] -partial_eval_jaxpr_custom_rules: Dict[Primitive, PartialEvalCustomRule] = {} - -def partial_eval_jaxpr_custom_rule_not_implemented( - name: str, saveable: Callable[..., bool], unks_in: Sequence[bool], - inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult: - msg = (f'custom-policy remat rule not implemented for {name}, ' - 'open a feature request at https://github.com/google/jax/issues!') - raise NotImplementedError(msg) - - -ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool], - Sequence[bool], int, dict, dict], - Tuple[dict, dict]] -ResAvalUpdater = Callable[[Dict[str, Any], AbstractValue], AbstractValue] -def _default_res_aval_updater( - params: Dict[str, Any], aval: AbstractValue) -> AbstractValue: - return aval - -@contextmanager -def trivial_ctx(_): yield - -def call_partial_eval_custom_rule( - jaxpr_param_name: str, params_updater: ParamsUpdater, - saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool], - eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx, - ) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]: - jaxpr = eqn.params[jaxpr_param_name] - with ctx(eqn.params): - jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ - partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) - ins_known, _ = partition_list(unks_in, eqn.invars) - out_binders_known, _ = partition_list(unks_out, eqn.outvars) - _, ins_staged = partition_list(inst_in, eqn.invars) - _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym([jaxpr_known, jaxpr_staged]) - params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} - params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} - params_known, params_staged = params_updater( - unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, - params_staged) - residuals = [newvar(res_aval(params_known, var.aval)) - for var in jaxpr_staged.invars[:num_res]] - eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], - eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info) - eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, - eqn.primitive, params_staged, - jaxpr_staged.effects, eqn.source_info) - assert len(eqn_staged.invars) == len(jaxpr_staged.invars) - new_inst = [x for x, inst in zip(eqn.invars, inst_in) - if type(x) is Var and not inst] - return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals - -def closed_call_partial_eval_custom_rule( - jaxpr_param_name: str, params_updater: ParamsUpdater, - saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool], - eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ) -> 
Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]: - # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule. - closed_jaxpr = eqn.params[jaxpr_param_name] - jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \ - partial_eval_jaxpr_custom(closed_jaxpr.jaxpr, unks_in, inst_in, - False, False, saveable) - # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching - jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts) - jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts) - ins_known, _ = partition_list(unks_in, eqn.invars) - out_binders_known, _ = partition_list(unks_out, eqn.outvars) - _, ins_staged = partition_list(inst_in, eqn.invars) - _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym([jaxpr_known.jaxpr, jaxpr_staged.jaxpr]) - params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} - params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} - params_known, params_staged = params_updater( - unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, - params_staged) - residuals = [newvar(res_aval(params_known, a)) - for a in jaxpr_staged.in_avals[:num_res]] - eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], - eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info) - eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged, - eqn.primitive, params_staged, jaxpr_staged.effects, - eqn.source_info) - assert len(eqn_staged.invars) == len(jaxpr_staged.in_avals) - new_inst = [x for x, inst in zip(eqn.invars, inst_in) - if type(x) is Var and not inst] - return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals - -partial_eval_jaxpr_custom_rules[core.call_p] = \ - partial(call_partial_eval_custom_rule, 'call_jaxpr', - lambda _, __, ___, ____, _____, x, y: (x, y)) -partial_eval_jaxpr_custom_rules[core.closed_call_p] = \ - partial(closed_call_partial_eval_custom_rule, 'call_jaxpr', - lambda _, __, ___, ____, _____, x, y: (x, y)) - - -def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]: - # Compute which inputs are just forwarded to outputs. 
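-  # For example (illustrative): on { lambda ; a b. let c = add a b in (a, c) }
-  # this returns [0, None], since the first output is just invar 0 forwarded
-  # through while the second output is computed.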
- fwds: Dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars)) - for eqn in jaxpr.eqns: - if eqn.primitive in forwarding_rules: - eqn = eqn.replace(invars=[a if type(a) is Literal else fwds.get(a, a) # type: ignore - for a in eqn.invars]) - fwd_vars, _ = forwarding_rules[eqn.primitive](eqn) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): - if v_new is not None: - fwds[v_orig] = v_new - idxs: Dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)} - return [None if type(v) is Literal else idxs.get(fwds.get(v)) # type: ignore - for v in jaxpr.outvars] - - -def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool], - instantiate: Union[bool, Sequence[bool]] = False, - ) -> Tuple[Jaxpr, List[bool]]: - if type(instantiate) is bool: - instantiate = (instantiate,) * len(jaxpr.invars) - return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate)) - - -def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], - instantiate: Union[bool, Sequence[bool]] = False, - ) -> Tuple[Jaxpr, List[bool], List[bool]]: - jaxpr_ = convert_constvars_jaxpr(jaxpr) - new_jaxpr_, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs) - used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)]) - new_jaxpr = convert_invars_to_constvars(new_jaxpr_, sum(used_consts)) - return new_jaxpr, used_consts, used_inputs - - -@weakref_lru_cache -def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...], - instantiate: Tuple[bool, ...] - ) -> Tuple[Jaxpr, List[bool]]: - env: Dict[Var, bool] = {} - - def read(v: Var) -> bool: - return env.get(v, False) - - def write(x: Atom, b: bool) -> None: - if type(x) is Var: - env[x] = read(x) or b - - def has_effects(e: JaxprEqn) -> bool: - return bool(e.effects) or core.primitive_uses_outfeed(e.primitive, e.params) - - new_eqns = [] - map(write, jaxpr.outvars, used_outputs) - for eqn in jaxpr.eqns[::-1]: - used_outs = map(read, eqn.outvars) - if not any(used_outs) and not has_effects(eqn): - used_ins = [False] * len(eqn.invars) - else: - rule = dce_rules.get(eqn.primitive, _default_dce_rule) - used_ins, new_eqn = rule(used_outs, eqn) - if new_eqn is not None: - new_eqns.append(new_eqn) - map(write, eqn.invars, used_ins) - used_inputs = map(read, jaxpr.invars) - used_inputs = map(op.or_, instantiate, used_inputs) - - invars = [v for v, b in zip(jaxpr.invars, used_inputs) if b] - outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b] - eqns = new_eqns[::-1] - jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns) - - dbg = jaxpr.debug_info and core.JaxprDebugInfo( - jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, - tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b), - tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b)) - new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) - config.jax_enable_checks and core.check_jaxpr(new_jaxpr) - - return new_jaxpr, used_inputs - -DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], Optional[JaxprEqn]]] - -def _default_dce_rule( - used_outs: List[bool], eqn: JaxprEqn - ) -> Tuple[List[bool], JaxprEqn]: - return [True] * len(eqn.invars), eqn - -dce_rules: Dict[Primitive, DCERule] = {} - - -def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn - ) -> Tuple[List[bool], Optional[JaxprEqn]]: - new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) - new_params = dict(eqn.params, call_jaxpr=new_jaxpr) - update_params = call_param_updaters.get(eqn.primitive) - if update_params: - new_params 
= update_params(new_params, used_inputs, 0) - if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects: - return used_inputs, None - else: - new_eqn = new_jaxpr_eqn( - [v for v, used in zip(eqn.invars, used_inputs) if used], - [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) - return used_inputs, new_eqn -dce_rules[core.call_p] = dce_jaxpr_call_rule - - -def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn - ) -> Tuple[List[bool], JaxprEqn]: - # TODO(mattjj): de-duplicate with above rule? - jaxpr_ = eqn.params['call_jaxpr'] - jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts - new_jaxpr, used_inputs = dce_jaxpr(jaxpr, used_outputs) - new_params = dict(eqn.params, call_jaxpr=core.ClosedJaxpr(new_jaxpr, consts)) - new_eqn = new_jaxpr_eqn( - [v for v, used in zip(eqn.invars, used_inputs) if used], - [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) - return used_inputs, new_eqn -dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule - -@weakref_lru_cache -def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr: - return ClosedJaxpr(jaxpr, ()) - -def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] - ) -> ClosedJaxpr: - """Reorder `invars` by moving those indicated in `to_move` to the front.""" - return _move_binders_to_front(closed_jaxpr, tuple(to_move)) - -@weakref_lru_cache -def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Tuple[bool, ...] - ) -> ClosedJaxpr: - assert len(closed_jaxpr.in_avals) == len(to_move) - new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) - new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, - closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - closed_jaxpr.jaxpr.effects) - new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) - return new_closed_jaxpr - -def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: - return ([elt for elt, move in zip(lst, to_move) if move] + - [elt for elt, move in zip(lst, to_move) if not move]) - -def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] - ) -> ClosedJaxpr: - """Reorder `invars` by moving those indicated in `to_move` to the back.""" - return move_binders_to_front(closed_jaxpr, map(op.not_, to_move)) - -class DynamicJaxprTracer(core.Tracer): - __slots__ = ['aval', '_debug_info'] - - def __init__(self, trace, aval, line_info=None): - self._trace = trace - self._line_info = line_info - # Needed for UnexpectedTracerError. - self._debug_info = self._trace.frame.debug_info - self.aval = aval - - def full_lower(self): - return self - - def _contents(self): - return () - - def _origin_msg(self): - if not self._trace.main.jaxpr_stack: # type: ignore - # If this Tracer has been leaked the jaxpr stack may no longer be - # available. So we can't print as much origin information. - return ("\nThis DynamicJaxprTracer was created on line " - f"{source_info_util.summarize(self._line_info)}") - else: - invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self) - dbg = self._debug_info - if dbg is None: - return "" - - origin = ("The error occurred while tracing the function " - f"{dbg.func_src_info or ''} for {dbg.traced_for}. 
") - arg_info = arg_info_all(dbg) - if invar_pos and arg_info: - arg_info = [arg_info[i] for i in invar_pos] - arg_names = [f'{name}{keystr(path)}' for name, path in arg_info] - if len(arg_names) == 1: - arg_info_str = f"the argument {arg_names[0]}" - elif len(arg_names) == 2: - arg_info_str = f"the arguments {arg_names[0]} and {arg_names[1]}" - else: - *rest, last = arg_names - arg_info_str = f"the arguments {', '.join(rest)}, and {last}" - origin += ("This concrete value was not available in Python because it " - f"depends on the value{'s' if len(invar_pos) > 1 else ''} " - f"of {arg_info_str}.") - elif progenitor_eqns: - msts = [" operation " - f"{core.pp_eqn(eqn, core.JaxprPpContext(), core.JaxprPpSettings(print_shapes=True))}\n" - f" from line {source_info_util.summarize(eqn.source_info)}" - for eqn in progenitor_eqns[:5]] # show at most 5 - origin += ("This value became a tracer due to JAX operations on these lines:" - "\n\n" + "\n\n".join(msts)) - if len(progenitor_eqns) > 5: - origin += "\n\n(Additional originating lines are not shown.)" - return "\n" + origin - - def _assert_live(self) -> None: - if not self._trace.main.jaxpr_stack: # type: ignore - raise core.escaped_tracer_error(self, None) - - def get_referent(self): - frame = self._trace.frame - val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) - return self if val is None else get_referent(val) -api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval") - -def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: - del outvars - jaxpr_effects = set() - all_vars = [*constvars, *invars] - for eqn in eqns: - for eff in eqn.effects: - if isinstance(eff, effects.JaxprInputEffect): - invar = eqn.invars[eff.input_index] - if invar not in all_vars: - raise ValueError( - "`JaxprInputEffect` does not have corresponding input.") - eff = eff.replace(input_index=all_vars.index(invar)) - jaxpr_effects.add(eff) - return jaxpr_effects - - -class JaxprStackFrame: - gensym: Callable[[AbstractValue], Var] - tracer_to_var: Dict[TracerId, Var] - constid_to_tracer: Dict[ConstId, Tracer] - constvar_to_val: Dict[Var, Any] - tracers: List[DynamicJaxprTracer] # hold onto strong refs for all tracers - eqns: List[JaxprEqn] - invars: List[Var] - effects: core.Effects - debug_info: Optional[DebugInfo] - - def __init__(self): - self.gensym = core.gensym() - self.tracer_to_var = {} - self.constid_to_tracer = {} - self.constvar_to_val = {} - self.tracers = [] # circ refs, frame->tracer->trace->main->frame, - self.eqns = [] # cleared when we pop frame from main - self.invars = [] - self.effects = set() - self.debug_info = None - - def add_eqn(self, eqn: core.JaxprEqn): - self.eqns.append(eqn) - - def to_jaxpr(self, out_tracers): - # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) - outvars = [self.tracer_to_var[id(t)] for t in out_tracers] - constvars, constvals = unzip2(self.constvar_to_val.items()) - jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars, - self.eqns) - jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects) - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) - return jaxpr, constvals - - def to_jaxpr2(self, out_tracers): - # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) - constvars, constvals 
= unzip2(self.constvar_to_val.items()) - expl_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] - jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, - self.eqns) - jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns, - jaxpr_effects) - # We can't run check_jaxpr until after we normalize. - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) - jaxpr, out_type = _add_implicit_outputs(jaxpr) - config.jax_enable_checks and core.check_jaxpr(jaxpr) - return jaxpr, out_type, constvals - - def newvar(self, aval): - if isinstance(aval, DShapedArray): - # this aval may have tracers in it, so we replace those with variables - new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d - for d in aval.shape] - aval = aval.update(shape=tuple(new_shape)) - return self.gensym(aval) - - def find_progenitors(self, tracer): - var = self.tracer_to_var.get(id(tracer)) - if not var: - return None, None - active_vars = {var} - for eqn in self.eqns[::-1]: - produced = set(eqn.outvars) & active_vars - if produced: - active_vars.difference_update(produced) - active_vars.update(eqn.invars) - invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] - constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars] - return invar_positions, const_eqns - -def _const_folding_and_forwarding(jaxpr, constvals): - consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) - var_subs: Dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined - new_eqns = [] - for eqn in jaxpr.eqns: - # always apply invar substitutions - eqn = eqn.replace(invars=[var_subs.get(v, v) for v in eqn.invars]) - # if any inputs are constants and we have a constant-folding rule, apply it - has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) - for eff in eqn.effects) - if (eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars) - and not has_input_effect): - consts_in = [consts.get(v) for v in eqn.invars] - consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) - assert (new_eqn is None) == all(c is not None for c in consts_out) - for v, c in zip(eqn.outvars, consts_out): - if c is not None: consts[v] = c - if new_eqn is None: continue - else: eqn = new_eqn - # if the application trivially maps some inputs to outputs, simplify - if eqn.primitive in forwarding_rules and not has_input_effect: - fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn) - assert (new_eqn is None) == all(v is not None for v in fwd_vars) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): - if v_new is not None: var_subs[v_orig] = v_new - if new_eqn is None: continue - else: eqn = new_eqn - new_eqns.append(eqn) - new_constvars, new_constvals = unzip2(consts.items()) - new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) - return new_jaxpr, new_constvals - -ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn], - Tuple[List[Optional[Any]], Optional[JaxprEqn]]] -const_fold_rules: Dict[Primitive, ConstFoldRule] = {} - -ForwardingRule = Callable[[JaxprEqn], - Tuple[List[Optional[Var]], Optional[JaxprEqn]]] -forwarding_rules: Dict[Primitive, ForwardingRule] = {} - -def _inline_literals(jaxpr, constvals): - # This 
function also prunes unused constants and inserts `dropvar` symbols. - input_effects = {eff for eff in jaxpr.effects - if isinstance(eff, effects.JaxprInputEffect)} - # Don't inline any literal with an input effect - has_input_effect = [any(eff.input_index == i for eff in input_effects) - for i in range(len(constvals))] - lits = {v: Literal(c, v.aval) for v, c, e in zip(jaxpr.constvars, constvals, - has_input_effect) - if type(c) in core.literalable_types and not np.shape(c) and not e} - lit: Callable[[Var], Optional[Literal]] = lits.get - newname: Callable[[AbstractValue], Var] = core.gensym() - newvars: Dict[Var, Var] = {} - newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) - var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) - dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) - - def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: - if isinstance(aval, DShapedArray): - return [d for d in aval.shape if isinstance(d, Var)] - return [] - - used = {v for eqn in jaxpr.eqns for invar in eqn.invars - for v in it.chain([invar], vars_in_shape(invar.aval))} - used |= {v for outvar in jaxpr.outvars - for v in it.chain([outvar], vars_in_shape(outvar.aval))} - new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] - new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) - if v in used and not lit(v)] - new_invars = [var(v) for v in jaxpr.invars] - new_eqns = [] - for eqn in jaxpr.eqns: - invars = [lit(v) or var(v) for v in eqn.invars] - outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] - new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) - new_outvars = [lit(v) or var(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) - return new_jaxpr, new_constvals - -class DynamicJaxprTrace(core.Trace): - __slots__ = [] # type: ignore - - @property - def frame(self): - return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error - - def new_arg(self, aval): - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) - self.frame.tracers.append(tracer) - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) - self.frame.invars.append(var) - return tracer - - def new_const(self, c): - # TODO(mattjj): for ints, or hashable consts, don't rely on id - tracer = self.frame.constid_to_tracer.get(id(c)) - if tracer is None: - aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, c) - return tracer - - pure = lift = new_const - - def _new_const(self, aval, c): - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) - self.frame.tracers.append(tracer) - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) - self.frame.constid_to_tracer[id(c)] = tracer - self.frame.constvar_to_val[var] = c - return tracer - - def sublift(self, t): - # When lifting closed-over tracers corresponding to this same trace, the - # variable to lift could have tracers (representing axis size variables) in - # its shape. We must lift those too! 
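-    # Illustrative case: a closed-over value of type f32[n], where n is itself
-    # a tracer for an axis-size variable, requires lifting both the array
-    # value and the tracer n into this trace.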
- tracer = self.frame.constid_to_tracer.get(id(t)) - if tracer is None: - aval = raise_to_shaped(get_aval(t), weak_type=dtypes.is_weakly_typed(t)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, t) - return tracer - - def _lift_tracers_in_aval(self, aval): - if (not isinstance(aval, DShapedArray) or - not any(isinstance(d, Tracer) for d in aval.shape)): - return aval - shape = [self.full_raise(d) if isinstance(d, Tracer) else d - for d in aval.shape] - return aval.update(shape=tuple(shape)) - - def getvar(self, tracer): - var = self.frame.tracer_to_var.get(id(tracer)) - if var is None: - raise core.escaped_tracer_error(tracer) - return var - - def makevar(self, tracer): - var = self.frame.tracer_to_var.get(id(tracer)) - assert var is None, "a jaxpr variable must be created only once per tracer" - self.frame.tracers.append(tracer) - var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) - return var - - def instantiate_const(self, val): - if (isinstance(val, Tracer) and val._trace.main is self.main - and val._trace.sublevel == self.sublevel): - return val - else: - return self.new_const(val) - - def process_primitive(self, primitive, tracers, params): - if primitive in custom_staging_rules: - return custom_staging_rules[primitive](self, *tracers, **params) - return self.default_process_primitive(primitive, tracers, params) - - def default_process_primitive(self, primitive, tracers, params): - avals = [t.aval for t in tracers] - out_avals, effects = primitive.abstract_eval(*avals, **params) - out_avals = [out_avals] if not primitive.multiple_results else out_avals - source_info = source_info_util.current() - out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] - invars = map(self.getvar, tracers) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info) - self.frame.add_eqn(eqn) - return out_tracers if primitive.multiple_results else out_tracers.pop() - - def process_call(self, call_primitive, f, explicit_tracers, params): - if f.in_type is None: - f = lu.annotate(f, tuple((raise_to_shaped(t.aval), True) - for t in explicit_tracers)) - implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) - in_tracers = [*implicit_tracers, *explicit_tracers] - # TODO(mattjj): check in_tracers are consistent with f.in_type annotation - with core.new_sublevel(): - # TODO(lenamartens): Make call_primitive name -> API function name mapping. - # (currently this will display eg. 
'xla_call' instead of `jit`) - dbg = debug_info_final(f, call_primitive.name) - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg) - if params.get('inline', False): - return core.eval_jaxpr(jaxpr, consts, *in_tracers) - source_info = source_info_util.current() - out_tracers = [] - for aval, _ in out_type: - if type(aval) is DShapedArray: - shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else - out_tracers[d.val] if type(d) is OutDBIdx else - d for d in aval.shape] - aval = aval.update(shape=tuple(get_referent(d) for d in shape)) - out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) - invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) - outvars = map(self.makevar, out_tracers) - new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) - update_params = call_param_updaters.get(call_primitive) - if update_params: - new_params = update_params(new_params, [True] * len(explicit_tracers), - len(consts) + len(implicit_tracers)) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, - new_params, new_params['call_jaxpr'].effects, - source_info) - self.frame.add_eqn(eqn) - return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] - - def post_process_call(self, call_primitive, out_tracers, params): - assert False # unreachable - - def process_map(self, map_primitive, f, tracers, params): - in_avals = [t.aval for t in tracers] - axis_name, axis_size = params['axis_name'], params['axis_size'] - reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) - if in_axis is not None else a - for a, in_axis in zip(in_avals, params['in_axes'])] - with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore - with core.new_sublevel(): - jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic( - f, self.main, reduced_in_avals, - debug_info=debug_info_final(f, map_primitive.name)) - ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) - if ordered_effects: - raise ValueError("Ordered effects not supported for " - f"map primitives: {ordered_effects}") - out_axes = params['out_axes_thunk']() - out_avals = [core.unmapped_aval(axis_size, axis_name, out_axis, a) - if out_axis is not None else a - for a, out_axis in zip(reduced_out_avals, out_axes)] - source_info = source_info_util.current() - out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) - outvars = map(self.makevar, out_tracers) - new_in_axes = (None,) * len(consts) + params['in_axes'] - new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, - call_jaxpr=convert_constvars_jaxpr(jaxpr)) - del new_params['out_axes_thunk'] - update_params = call_param_updaters.get(map_primitive) - if update_params: - new_params = update_params(new_params, [True] * len(tracers), len(consts)) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, - new_params, jaxpr.effects, source_info) - self.frame.add_eqn(eqn) - return out_tracers - - def post_process_map(self, map_primitive, out_tracers, params): - assert False # unreachable - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - in_avals = [t.aval for t in tracers] - with core.new_sublevel(): - fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) - closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) - 
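-    # Hold only a weak reference to the main trace here, so the memoized jvp
-    # thunk below does not keep the trace alive through its closure.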
main_ = ref(self.main) - jvp_jaxpr_thunk = _memoize( - lambda: trace_to_subjaxpr_dynamic(jvp, main_(), 2 * in_avals)[::2]) - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, - dict(call_jaxpr=closed_fun_jaxpr, - jvp_jaxpr_thunk=jvp_jaxpr_thunk, - num_consts=len(consts), - symbolic_zeros=symbolic_zeros), - fun_jaxpr.effects, - source_info_util.current()) - self.frame.add_eqn(eqn) - return out_tracers - - def post_process_custom_jvp_call(self, out_tracers, _): - assert False # unreachable - - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): - in_avals = [t.aval for t in tracers] - with core.new_sublevel(): - fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) - closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) - main_ = ref(self.main) - fwd_jaxpr_thunk = _memoize( - lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2]) - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, - dict(fun_jaxpr=closed_fun_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_thunk, - num_consts=len(consts), - bwd=bwd, out_trees=out_trees), - fun_jaxpr.effects, - source_info_util.current()) - self.frame.add_eqn(eqn) - return out_tracers - - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable - - def process_custom_transpose(self, prim, call, tracers, - transpose, out_types, - lin_tree, res_tree, out_tree): - tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) - - in_avals_p = [t.aval for t in tracers] - in_avals_t = [*[t.aval for t in tracers_res], *out_types] - - with core.new_sublevel(): - call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic( - call, self.main, in_avals_p) - closed_call_jaxpr = core.ClosedJaxpr( - convert_constvars_jaxpr(call_jaxpr), ()) - - transpose_flat, in_tree2 = flatten_fun_nokwargs( - lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree))) - - main_ = ref(self.main) - # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts - transpose_jaxpr_thunk = _memoize( - lambda: trace_to_subjaxpr_dynamic( - transpose_flat, main_(), in_avals_t)[::2]) - - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] - invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, call_consts)) - outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, - dict(call_jaxpr=closed_call_jaxpr, - transpose_jaxpr_thunk=transpose_jaxpr_thunk, - out_types=out_types, res_tree=res_tree, - lin_tree=lin_tree, out_tree=out_tree), - closed_call_jaxpr.effects, - source_info_util.current()) - self.frame.add_eqn(eqn) - return out_tracers - - -custom_staging_rules: Dict[Primitive, Callable] = {} - -def _memoize(thunk): - cell = [] - saved_state = [core.thread_local_state.trace_state.copy()] - def memoized(): - if not cell: - prev_state = core.thread_local_state.trace_state - core.thread_local_state.trace_state = saved_state.pop() - try: - cell.append(thunk()) - finally: - core.thread_local_state.trace_state = prev_state - return cell[0] - 
return memoized - -# TODO(mattjj): remove this DebugInfo and helper functions, replace with -# api_util.py versions - -class DebugInfo(NamedTuple): - func_src_info: Optional[str] # f'{fun.__name__} at {filename}:{lineno}' - signature: Optional[inspect.Signature] # inspect.signature(fun) - in_tree: Optional[PyTreeDef] # caller/constructor might not have this info - out_tree: Optional[Callable[[], PyTreeDef]] # lazy, not avail at trace time - has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args - traced_for: str # "jit", "scan", "make_jaxpr", etc - -def debug_info(fn: Callable, in_tree: Optional[PyTreeDef], - out_tree_thunk: Optional[Callable[[], PyTreeDef]], - has_kwargs: bool, traced_for: str) -> DebugInfo: - try: sig = inspect.signature(fn) - except (ValueError, TypeError): sig = None - src_info = fun_sourceinfo(fn) - return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs, - traced_for) - -def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo: - "Make a DebugInfo from data available to final-style primitives like pmap." - in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False) - return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for) - -def arg_info_all(dbg: DebugInfo) -> Optional[List[Tuple[str, KeyPath]]]: - ba = None if dbg.in_tree is None else sig_info(dbg) - if ba is None: return None - return [(name, key_path) for name, dummy_arg in ba.arguments.items() - for key_path, _ in generate_key_paths(dummy_arg)] - -def sig_info(dbg: DebugInfo) -> Optional[inspect.BoundArguments]: - if dbg.in_tree is None or dbg.signature is None: return None - try: - dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves) - except: - return None - args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {}) - try: - return dbg.signature.bind(*args, **kwargs) - except (TypeError, ValueError): - return None - -def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]: - if dbg.out_tree is None: return None - try: - num_leaves = dbg.out_tree().num_leaves - dummy_result = tree_unflatten(dbg.out_tree(), [False] * num_leaves) - except: - return None - else: - return [path for path, _ in generate_key_paths(dummy_result)] - -@profiler.annotate_function -def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue], - debug_info: Optional[DebugInfo] = None, - *, - keep_inputs: Optional[List[bool]] = None): - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore - main.jaxpr_stack = () # type: ignore - jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) - del main, fun - return jaxpr, out_avals, consts - -def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, - in_avals: Sequence[AbstractValue], *, - keep_inputs: Optional[Sequence[bool]] = None, - debug_info: Optional[DebugInfo] = None): - keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - - frame = JaxprStackFrame() - frame.debug_info = debug_info - with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): - trace = DynamicJaxprTrace(main, core.cur_sublevel()) - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) - in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - ans = fun.call_wrapped(*in_tracers_) - out_tracers = map(trace.full_raise, ans) - jaxpr, consts = frame.to_jaxpr(out_tracers) - del fun, main, trace, frame, in_tracers, out_tracers, ans - 
config.jax_enable_checks and core.check_jaxpr(jaxpr) - return jaxpr, [v.aval for v in jaxpr.outvars], consts - - -@profiler.annotate_function -def trace_to_jaxpr_dynamic2( - fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None - ) -> Tuple[Jaxpr, OutputType, List[Any]]: - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore - main.jaxpr_stack = () # type: ignore - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) - del main, fun - return jaxpr, out_type, consts - -def trace_to_subjaxpr_dynamic2( - fun: lu.WrappedFun, main: core.MainTrace, - debug_info: Optional[DebugInfo] = None -) -> Tuple[Jaxpr, OutputType, List[Any]]: - in_avals, keep_inputs = unzip2(fun.in_type) - frame = JaxprStackFrame() - frame.debug_info = debug_info - with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): - trace = DynamicJaxprTrace(main, core.cur_sublevel()) - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) - in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - ans = fun.call_wrapped(*in_tracers_) - out_tracers = map(trace.full_raise, ans) - jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers) - del fun, main, trace, frame, in_tracers, out_tracers, ans - return jaxpr, out_type, consts - - -@contextmanager -def extend_jaxpr_stack(main, frame): - main.jaxpr_stack = main.jaxpr_stack + (frame,) - try: - yield - finally: - assert frame is main.jaxpr_stack[-1] - main.jaxpr_stack = main.jaxpr_stack[:-1] - -@profiler.annotate_function -def trace_to_jaxpr_final(fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue], - debug_info: Optional[DebugInfo] = None, - keep_inputs: Optional[Sequence[bool]] = None): - with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore - main.jaxpr_stack = () # type: ignore - with core.new_sublevel(): - jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) - del fun, main - return jaxpr, out_avals, consts - -@profiler.annotate_function -def trace_to_jaxpr_final2( - fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None - ) -> Tuple[Jaxpr, OutputType, List[Any]]: - with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore - main.jaxpr_stack = () # type: ignore - with core.new_sublevel(): - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) - del fun, main - return jaxpr, out_type, consts - - -AbstractedAxisName = Hashable -AbstractedAxesSpec = Union[Dict[int, AbstractedAxisName], - Tuple[AbstractedAxisName, ...]] -def infer_lambda_input_type( - axes_specs: Optional[Sequence[AbstractedAxesSpec]], - args: Sequence[Any] - ) -> InputType: - ndims = [getattr(get_aval(x), 'ndim', 0) for x in args] - partial_specs = _canonicalize_specs(ndims, axes_specs) - specs = _complete_specs(args, partial_specs) - idxs, implicit_types = _collect_implicit(args, specs) - implicit_sig = [(ty, False) for ty in implicit_types] - explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)] - input_type = (*implicit_sig, *explicit_sig) - lu._check_input_type(input_type) - return input_type - -def _spec_to_dict(spec: AbstractedAxesSpec) -> Dict[int, AbstractedAxisName]: - if isinstance(spec, tuple): - return {i: d for i, d in enumerate(spec) if d is not None} - else: - return spec - -def _canonicalize_specs( - ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]] - ) -> List[Dict[int, AbstractedAxisName]]: - if specs is None: - return [{}] * len(ndims) - else: - return 
[_spec_to_dict(s) for _, s in zip(ndims, specs)] - -def _complete_specs( - args: Sequence[Any], partial_specs: List[Dict[int, AbstractedAxisName]] - ) -> List[Dict[int, AbstractedAxisName]]: - # The abstracted axes specification in `partial_specs` is partial in the sense - # that there could be additional axis abstraction represented in `args` due to - # Tracers existing in the shapes of elements of `args`. The purpose of this - # function is to produce a full specification: for each argument, map every - # abstracted axis position to a name, introducing new names as needed for - # Tracers in axis sizes which don't already correspond to abstracted axis - # names (one new name per unique Tracer object id). - - # Identify each user-supplied name in partial_specs with a size. - sizes: Dict[AbstractedAxisName, Union[int, DynamicJaxprTracer]] = {} - for x, spec in zip(args, partial_specs): - for i, name in spec.items(): - d = sizes.setdefault(name, x.shape[i]) - if d is not x.shape[i] and d != x.shape[i]: - raise TypeError(f"abstracted axis name {name} bound to unequal " - f"sizes: {d} and {x.shape[i]}") - - # Introduce new names as needed for Tracers in shapes. - named_tracers: Dict[TracerId, AbstractedAxisName] = { - id(d): name for name, d in sizes.items() if isinstance(d, Tracer)} - specs: List[Dict[int, AbstractedAxisName]] = [] - for x, spec in zip(args, partial_specs): - if isinstance(get_aval(x), DShapedArray): - spec = dict(spec) - for i, d in enumerate(x.shape): - if isinstance(d, Tracer): - spec[i] = named_tracers.get(id(d), TracerAsName(d)) - specs.append(spec) - - # Assert that `specs` is now complete in the sense that there are no Tracers - # which don't correspond to an AbstractedAxisName. - assert all(not spec or not any(isinstance(d, Tracer) and i not in spec - for i, d in enumerate(x.shape)) - for x, spec in zip(args, specs)) - return specs - - -def _collect_implicit( - args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]] - ) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]: - # Given an explicit argument list and a specification of abstracted axes, we - # want to produce an InputType by identifying AbstractedAxisNames with DBIdxs - # and figuring out which AbstractedAxisNames correspond to implicit arguments. - - idxs: Dict[AbstractedAxisName, DBIdx] = {} - implicit_types: List[AbstractValue] = [] - explicit_tracers: Dict[TracerId, int] = {} - counter = it.count() - - # Add implicit arguments to idxs. - for explicit_idx, (x, spec) in enumerate(zip(args, specs)): - for i, name in spec.items(): - if name not in idxs and id(x.shape[i]) not in explicit_tracers: - idxs[name] = DBIdx(next(counter)) - implicit_types.append(raise_to_shaped(get_aval(x.shape[i]))) - if isinstance(x, Tracer): - explicit_tracers.setdefault(id(x), explicit_idx) # use the first - - # Now that we know the implicit args, add explicit args to idxs. - offset = len(implicit_types) - for x, spec in zip(args, specs): - for i, name in spec.items(): - if id(x.shape[i]) in explicit_tracers: - idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])])) - - return idxs, implicit_types - -def _arg_type( - idxs: Dict[AbstractedAxisName, DBIdx], x: Any, - spec: Dict[int, AbstractedAxisName] - ) -> AbstractValue: - # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames.
- aval = get_aval(x) # aval.shape could contain Tracers - if not spec: return core.raise_to_shaped(aval) - shape: List[Union[int, DBIdx]] = [idxs[spec[i]] if i in spec else d - for i, d in enumerate(aval.shape)] - assert not any(isinstance(d, Tracer) for d in shape) - return DShapedArray(tuple(shape), aval.dtype, False) - -def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]: - invars = [*jaxpr.constvars, *jaxpr.invars] - expl_outvars = jaxpr.outvars - - # First do a pass to collect implicit outputs, meaning variables which occur - # in the types of expl_outvars but not in invars or to the left in outvars. - seen: Set[Var] = set(invars) - impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore - (seen.add(x) or type(x.aval) is DShapedArray) # type: ignore - for d in x.aval.shape if type(d) is Var and d not in seen] - outvars = [*impl_outvars, *expl_outvars] - - # Now assemble an OutputType by mapping vars in shapes to InDBIdx/OutDBIdx. - in_map: Dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate(invars)} - out_map: Dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars) - if type(x) is Var} - out_avals_ = (x.aval for x in outvars) - out_avals = [a.update(shape=tuple(in_map.get(d, out_map.get(d)) - if type(d) is Var else d for d in a.shape)) - if type(a) is DShapedArray else a for a in out_avals_] - kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars) - out_type = tuple(zip(out_avals, kept_outs)) - - new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, - jaxpr.effects, jaxpr.debug_info) - config.jax_enable_checks and core.check_jaxpr(new_jaxpr) - return new_jaxpr, out_type - - -class TracerAsName: - ref: Any - def __init__(self, tracer): - self.ref = core.get_referent(tracer) - def __eq__(self, other): - return isinstance(other, TracerAsName) and self.ref is other.ref - def __hash__(self): - return id(self.ref) - -def _extract_implicit_args( - trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]], - explicit_tracers: Sequence[DynamicJaxprTracer] - ) -> Sequence[DynamicJaxprTracer]: - # First, construct a list to represent the full argument list, leaving the - # implicit arguments as Nones for now. - explicit_tracers_ = iter(explicit_tracers) - tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type] - assert next(explicit_tracers_, None) is None - del explicit_tracers_ - - # Next, populate the implicit arguments using DBIdxs in in_type. - for i, (aval, explicit) in enumerate(in_type): - if not explicit or not isinstance(aval, DShapedArray): - continue # can't populate an implicit argument - tracer = tracers[i] - assert tracer is not None - for d1, d2 in zip(aval.shape, tracer.aval.shape): - if isinstance(d1, DBIdx): - if tracers[d1.val] is None: - tracers[d1.val] = trace.instantiate_const(d2) - assert tracers[d1.val] is trace.instantiate_const(d2) - assert all(t is not None for t in tracers) - return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore - -def _input_type_to_tracers( - new_arg: Callable[[AbstractValue], Tracer], - in_avals: Sequence[AbstractValue] - ) -> Sequence[Tracer]: - # Create input Tracers given input AbstractValues, each of which can contain - # de Bruijn indices which refer to positions in the input argument list. That - # is, each element `a` of `in_avals` can have DBIdx instances in its shape, - # which must refer to positions to the left of `a`'s.
- in_tracers: List[Tracer] = [] - - def _substitute_tracers_in_aval(a: AbstractValue) -> AbstractValue: - if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape): - shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] # type: ignore - return a.update(shape=tuple(shape)) - return a - - for a in in_avals: - in_tracers.append(new_arg(_substitute_tracers_in_aval(a))) - return in_tracers - -def _substitute_vars_in_type( - consts: Dict[Var, Literal], env: Dict[Var, Var], a: AbstractValue - ) -> AbstractValue: - if isinstance(a, DShapedArray) and any(isinstance(d, Var) for d in a.shape): - shape = [consts[d].val if d in consts else env[d] # type: ignore - if isinstance(d, Var) else d for d in a.shape] - return a.update(shape=tuple(shape)) - else: - return a - - -class DimensionHandlerTracer(core.DimensionHandler): - """See core.DimensionHandler. - - Most methods are inherited. - """ - def is_constant(self, d: core.DimSize) -> bool: - assert isinstance(d, Tracer) - return False - - def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool: - return d1 is d2 - - def greater_equal(self, d1: core.DimSize, d2: core.DimSize): - raise core.InconclusiveDimensionOperation("TODO") - - def divide_shape_sizes(self, s1: core.Shape, s2: core.Shape) -> core.DimSize: - """Computes integer "i" such that i * size(s2) == size(s1). - - Raises InconclusiveDimensionOperation if there is no such integer for all - contexts. - """ - s1_size = functools.reduce(op.mul, s1, 1) - s2_size = functools.reduce(op.mul, s2, 1) - q, r = divmod(s1_size, s2_size) - if r != 0: - raise core.InconclusiveDimensionOperation( - f"Cannot divide evenly the sizes of shapes {s1} and {s2}") - return q - - def stride(self, d: core.DimSize, window_size: core.DimSize, window_stride: core.DimSize) -> core.DimSize: - """Implements `(d - window_size) // window_stride + 1`""" - raise core.InconclusiveDimensionOperation("TODO") - - def as_value(self, d: core.DimSize): - """Turns a dimension size into a JAX value that we can compute with.""" - raise core.InconclusiveDimensionOperation("TODO") - -core._SPECIAL_DIMENSION_HANDLERS[DynamicJaxprTracer] = DimensionHandlerTracer() - -Const = Any -Val = Any - -def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const] - ) -> Tuple[Jaxpr, List[Const]]: - bounds = {v: v.aval.dtype.bound for v in jaxpr.invars - if isinstance(v.aval, core.UnshapedArray) and - type(v.aval.dtype) is core.bint and not v.aval.shape} - idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)} - - def substitute(aval: AbstractValue) -> AbstractValue: - if (isinstance(aval, core.UnshapedArray) and type(aval.dtype) is core.bint - and not aval.shape): - return ShapedArray((), dtypes._scalar_type_to_dtype(int)) - elif isinstance(aval, DShapedArray): - shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore - typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray - return typ(tuple(shape), aval.dtype, aval.weak_type) - else: - return aval - - in_avals = [substitute(v.aval) for v in jaxpr.invars] - eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts)) - padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals) - return padded_jaxpr, padded_consts - -class BoundedAxisSize(NamedTuple): - val: Union[int, DynamicJaxprTracer] - bound: int - -def _eval_jaxpr_padded( - jaxpr: Jaxpr, consts: List[Const], *args: DynamicJaxprTracer - ) -> List[Union[Const, DynamicJaxprTracer]]: - env: Dict[Var, Val] = {} - - def read(x): - return x.val if type(x) is Literal else env[x] - - def write(v, val) -> None: - env[v] = val -
- map(write, jaxpr.constvars, consts) - map(write, jaxpr.invars, args) - for eqn in jaxpr.eqns: - in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars] - out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars] - rule = padding_rules[eqn.primitive] - outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params) - map(write, eqn.outvars, outs) - return map(read, jaxpr.outvars) - -def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue: - if isinstance(aval, DShapedArray): - shp = [] - for d in aval.shape: - if isinstance(d, core.DArray): - assert not d.shape and type(d.dtype) is core.bint - shp.append(BoundedAxisSize(int(d._data), int(d.dtype.bound))) - elif (type(d) is core.Var and isinstance(d.aval, core.DShapedArray) and - type(d.aval.dtype) is core.bint): - assert not d.aval.shape - shp.append(BoundedAxisSize(env[d], d.aval.dtype.bound)) - else: - shp.append(env.get(d, d)) - return DShapedArray(tuple(shp), aval.dtype, aval.weak_type) - else: - return aval - -def _is_bint_axis_size(d: Union[int, core.DArray, core.Var]) -> bool: - if isinstance(d, core.DArray): - assert not d.shape # pytype: disable=attribute-error - return type(d.dtype) is core.bint # pytype: disable=attribute-error - elif isinstance(d, core.Var): - return (isinstance(d.aval, core.DShapedArray) and # pytype: disable=attribute-error - type(d.aval.dtype) is core.bint) # pytype: disable=attribute-error - return False - - -padding_rules: Dict[Primitive, Callable] = {} - -def def_trivial_padding(prim: Primitive) -> None: - if prim.multiple_results: - padding_rules[prim] = partial(_trivial_padding_rule_multi, prim) - else: - padding_rules[prim] = partial(_trivial_padding_rule, prim) - -def _trivial_padding_rule(prim, _, __, *args, **params): - return [prim.bind(*args, **params)] - -def _trivial_padding_rule_multi(prim, _, __, *args, **params): - return prim.bind(*args, **params) - -def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): - if call_jaxpr.constvars: raise NotImplementedError - padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ()) - if padded_consts: raise NotImplementedError - new_params = dict(params, call_jaxpr=padded_jaxpr) - subfuns, bind_params = prim.get_bind_params(new_params) - return prim.bind(*subfuns, *args, **bind_params) - - -# TODO(mattjj): the following are deprecated; update callers to _nounits version -# See https://github.com/google/jax/pull/9498 -@lu.transformation -def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]], - pvals: Sequence[PartialVal]): - assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals - trace = main.with_cur_sublevel() - in_tracers = map(trace.new_arg, pvals) - ans = yield in_tracers, {} - assert isinstance(ans, (list, tuple)), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers) - jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers) - out_pvals = [t.pval for t in out_tracers] - del trace, in_tracers, out_tracers - yield jaxpr, (out_pvals, consts, env) - -partial_eval_jaxpr: Callable - -def instantiate_const_at(trace: JaxprTrace, 
instantiate: bool, tracer): - if instantiate: - return trace.instantiate_const(trace.full_raise(tracer)) - else: - return tracer +from jax._src.interpreters.partial_eval import ( + AbstractedAxesSpec as AbstractedAxesSpec, + AbstractedAxisName as AbstractedAxisName, + BoundedAxisSize as BoundedAxisSize, + Const as Const, + ConstFoldRule as ConstFoldRule, + ConstVar as ConstVar, + DCERule as DCERule, + DebugInfo as DebugInfo, + DimensionHandlerTracer as DimensionHandlerTracer, + DynamicJaxprTrace as DynamicJaxprTrace, + DynamicJaxprTracer as DynamicJaxprTracer, + ForwardingRule as ForwardingRule, + FreeVar as FreeVar, + JaxprEqnRecipe as JaxprEqnRecipe, + JaxprStackFrame as JaxprStackFrame, + JaxprTrace as JaxprTrace, + JaxprTracer as JaxprTracer, + JaxprTracerRecipe as JaxprTracerRecipe, + LambdaBinding as LambdaBinding, + ParamsUpdater as ParamsUpdater, + PartialEvalCustomResult as PartialEvalCustomResult, + PartialEvalCustomRule as PartialEvalCustomRule, + PartialVal as PartialVal, + ResAvalUpdater as ResAvalUpdater, + TracerAsName as TracerAsName, + TracerId as TracerId, + Val as Val, + abstract_eval_fun as abstract_eval_fun, + arg_info_all as arg_info_all, + call_padding_rule as call_padding_rule, + call_param_updaters as call_param_updaters, + call_partial_eval_custom_rule as call_partial_eval_custom_rule, + call_partial_eval_rules as call_partial_eval_rules, + close_jaxpr as close_jaxpr, + closed_call_partial_eval_custom_rule as closed_call_partial_eval_custom_rule, + config as config, + const_fold_rules as const_fold_rules, + convert_constvars_jaxpr as convert_constvars_jaxpr, + convert_envvars_to_constvars as convert_envvars_to_constvars, + convert_invars_to_constvars as convert_invars_to_constvars, + custom_partial_eval_rules as custom_partial_eval_rules, + custom_staging_rules as custom_staging_rules, + dce_jaxpr as dce_jaxpr, + dce_jaxpr_call_rule as dce_jaxpr_call_rule, + dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule, + dce_jaxpr_consts as dce_jaxpr_consts, + dce_rules as dce_rules, + debug_info as debug_info, + debug_info_final as debug_info_final, + def_trivial_padding as def_trivial_padding, + extend_jaxpr_stack as extend_jaxpr_stack, + forwarding_rules as forwarding_rules, + infer_lambda_input_type as infer_lambda_input_type, + instantiate_const_at as instantiate_const_at, + make_jaxpr_effects as make_jaxpr_effects, + move_binders_to_back as move_binders_to_back, + move_binders_to_front as move_binders_to_front, + new_eqn_recipe as new_eqn_recipe, + pad_jaxpr as pad_jaxpr, + padding_rules as padding_rules, + partial_eval_jaxpr_custom as partial_eval_jaxpr_custom, + partial_eval_jaxpr_custom_rule_not_implemented as partial_eval_jaxpr_custom_rule_not_implemented, + partial_eval_jaxpr_custom_rules as partial_eval_jaxpr_custom_rules, + partial_eval_jaxpr_nounits as partial_eval_jaxpr_nounits, + partial_eval_wrapper_nounits as partial_eval_wrapper_nounits, + partition_pvals as partition_pvals, + recipe_to_eqn as recipe_to_eqn, + result_info as result_info, + sig_info as sig_info, + trace_to_jaxpr as trace_to_jaxpr, + trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic, + trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2, + trace_to_jaxpr_final as trace_to_jaxpr_final, + trace_to_jaxpr_final2 as trace_to_jaxpr_final2, + trace_to_jaxpr_nounits as trace_to_jaxpr_nounits, + trace_to_subjaxpr as trace_to_subjaxpr, + trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic, + trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2, + trace_to_subjaxpr_nounits as 
trace_to_subjaxpr_nounits, + trace_to_subjaxpr_nounits_dyn as trace_to_subjaxpr_nounits_dyn, + trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, + tracers_to_jaxpr as tracers_to_jaxpr, + trivial_ctx as trivial_ctx, +) + +from jax._src.core import Jaxpr as Jaxpr diff --git a/setup.cfg b/setup.cfg index 2903790fc..34cdc9641 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,7 @@ per-file-ignores = jax/interpreters/ad.py:F401 jax/interpreters/batching.py:F401 jax/interpreters/mlir.py:F401 + jax/interpreters/partial_eval.py:F401 jax/interpreters/pxla.py:F401 jax/interpreters/xla.py:F401 jax/linear_util.py:F401 diff --git a/tests/api_test.py b/tests/api_test.py index 3bf6572bb..009095d6f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -55,7 +55,7 @@ from jax.interpreters import ad from jax._src.interpreters import mlir from jax.interpreters import xla from jax.interpreters import batching -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe from jax.sharding import PartitionSpec as P from jax._src import array from jax.experimental import pjit diff --git a/tests/core_test.py b/tests/core_test.py index b6fa34920..9419411d8 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -29,7 +29,6 @@ from jax import numpy as jnp from jax import jvp, linearize, vjp, jit, make_jaxpr from jax.api_util import flatten_fun_nokwargs from jax.config import config -from jax.interpreters import partial_eval as pe from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce, tree_leaves) @@ -38,6 +37,7 @@ from jax._src import linear_util as lu from jax._src import util from jax._src import test_util as jtu from jax._src.core import UnshapedArray, ShapedArray, DBIdx +from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 1bdb5edcd..f3831c301 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -26,8 +26,8 @@ from jax._src import linear_util as lu from jax.config import config from jax.experimental import maps from jax.experimental import pjit -from jax.interpreters import ad -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import ad +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src import ad_checkpoint from jax._src import dispatch diff --git a/tests/state_test.py b/tests/state_test.py index ef41053c8..f0f9f4c22 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -24,7 +24,7 @@ from jax._src import core from jax import lax from jax._src import linear_util as lu from jax.config import config -from jax.interpreters import partial_eval as pe +from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu from jax._src.util import tuple_insert import jax.numpy as jnp
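
For downstream code inside JAX, the migration this patch anticipates is a one-line import swap, as the test hunks above show: replace `from jax.interpreters import partial_eval as pe` with `from jax._src.interpreters import partial_eval as pe`. A minimal sketch of staging a function out to a jaxpr through the moved module — the toy function `f` and its input shape are illustrative only, not part of this patch; `trace_to_jaxpr_dynamic` and its `(jaxpr, out_avals, consts)` return shape are taken from the definitions moved above:

    import jax.numpy as jnp
    from jax._src import linear_util as lu
    from jax._src.api_util import shaped_abstractify
    from jax._src.interpreters import partial_eval as pe

    def f(x):
      return jnp.sin(x) * 2.0

    # Eagerly stage `f` out to a jaxpr given its input avals, roughly what
    # jit does internally when tracing.
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), [shaped_abstractify(jnp.ones(3))])
    print(jaxpr)

Because `jax/interpreters/partial_eval.py` now re-exports every public name from the new location (see the alias list above), the old spelling keeps working during the transition; only internal callers are expected to move to the `jax._src` path.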