diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py new file mode 100644 index 000000000..e0d21e322 --- /dev/null +++ b/jax/_src/lax/control_flow/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for the control flow primitives.""" +from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p, + cummin, cummin_p, cumprod, + cumprod_p, cumsum, cumsum_p, + cumred_tpu_impl, fori_loop, map, + scan, scan_bind, scan_p, + _scan_impl, while_loop, while_p) +from jax._src.lax.control_flow.conditionals import cond, cond_p, switch +from jax._src.lax.control_flow.remat_impl import (remat_impl, + optimization_barrier_p) +from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root, + _custom_linear_solve_impl, + linear_solve_p) + +from jax._src.lax.control_flow.common import allowed_effects +# Private utilities used elsewhere in JAX +# TODO(sharadmv): lift them into a more common place +from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr, + _initial_style_jaxpr, + _initial_style_jaxprs_with_common_consts, + _check_tree_and_avals) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py new file mode 100644 index 000000000..1855436ab --- /dev/null +++ b/jax/_src/lax/control_flow/common.py @@ -0,0 +1,140 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for the common control flow utilities.""" +import os +from functools import partial + +from typing import Callable, Optional, Sequence, Set + +from jax import core +from jax import linear_util as lu +from jax.api_util import flatten_fun_nokwargs +from jax.interpreters import partial_eval as pe +from jax._src import ad_util +from jax._src import util +from jax._src.util import cache, safe_map, unzip3 +from jax.tree_util import tree_map, tree_unflatten, tree_structure + +map, unsafe_map = safe_map, map + +allowed_effects: Set[core.Effect] = set() + + +def _abstractify(x): + return core.raise_to_shaped(core.get_aval(x)) + +def _typecheck_param(prim, param, name, msg_required, pred): + if not pred: + msg = (f'invalid {prim} param {name} of type {type(param).__name__}, ' + f'{msg_required} required:') + param_str = str(param) + sep = os.linesep if os.linesep in param_str else ' ' + msg = sep.join([msg, param_str]) + raise core.JaxprTypeError(msg) + +@cache() +def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, + primitive_name: Optional[str] = None): + wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + debug = pe.debug_info(fun, in_tree, False, primitive_name or "") + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug) + return jaxpr, consts, out_tree() + +@cache() +def _initial_style_jaxpr(fun: Callable, in_tree, in_avals, + primitive_name: Optional[str] = None): + jaxpr, consts, out_tree = _initial_style_open_jaxpr( + fun, in_tree, in_avals, primitive_name) + closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) + return closed_jaxpr, consts, out_tree + +@cache() +def _initial_style_jaxprs_with_common_consts( + funs: Sequence[Callable], in_tree, in_avals, primitive_name: str): + # When staging the branches of a conditional into jaxprs, constants are + # extracted from each branch and converted to jaxpr arguments. To use the + # staged jaxprs as the branches to a conditional *primitive*, we need for + # their (input) signatures to match. This function "joins" the staged jaxprs: + # for each one, it makes another that accepts *all* constants, but only uses + # those that it needs (dropping the rest). + + jaxprs, all_consts, all_out_trees = \ + unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name) + for fun in funs) + + newvar = core.gensym(jaxprs, suffix='_') + all_const_avals = [map(_abstractify, consts) for consts in all_consts] + unused_const_vars = [map(newvar, const_avals) + for const_avals in all_const_avals] + def pad_jaxpr_constvars(i, jaxpr): + prefix = util.concatenate(unused_const_vars[:i]) + suffix = util.concatenate(unused_const_vars[i + 1:]) + constvars = [*prefix, *jaxpr.constvars, *suffix] + return jaxpr.replace(constvars=constvars) + + consts = util.concatenate(all_consts) + jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)] + closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) + for jaxpr in jaxprs] + return closed_jaxprs, consts, all_out_trees + +def _check_tree_and_avals(what, tree1, avals1, tree2, avals2): + """Raises TypeError if (tree1, avals1) does not match (tree2, avals2). + + Corresponding `tree` and `avals` must match in the sense that the number of + leaves in `tree` must be equal to the length of `avals`. `what` will be + prepended to details of the mismatch in TypeError. + """ + if tree1 != tree2: + raise TypeError( + f"{what} must have same type structure, got {tree1} and {tree2}.") + if not all(map(core.typematch, avals1, avals2)): + diff = tree_map(_show_diff, tree_unflatten(tree1, avals1), + tree_unflatten(tree2, avals2)) + raise TypeError(f"{what} must have identical types, got\n{diff}.") + + +def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False): + if has_aux: + actual_tree_children = actual_tree.children() + + if len(actual_tree_children) == 2: + # select first child as result tree + actual_tree = tree_structure(actual_tree_children[0]) + else: + raise ValueError( + f"{func_name}() produced a pytree with structure " + f"{actual_tree}, but a pytree tuple with auxiliary " + f"output was expected because has_aux was set to True.") + + if actual_tree != expected_tree: + raise TypeError( + f"{func_name}() output pytree structure must match {expected_name}, " + f"got {actual_tree} and {expected_tree}.") + +def _prune_zeros(ts): + return [t for t in ts if type(t) is not ad_util.Zero] + +def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]): + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals) + return core.ClosedJaxpr(jaxpr, consts) + +def _show_diff(array1, array2): + if core.typematch(array1, array2): + return f"{array1}" + return f"DIFFERENT {array1} vs. {array2}" + +def _avals_short(avals): + to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))() + return ' '.join(map(to_str, avals)) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py new file mode 100644 index 000000000..ae5598c3d --- /dev/null +++ b/jax/_src/lax/control_flow/conditionals.py @@ -0,0 +1,675 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for conditional control flow primitives.""" +import collections +import functools +from functools import partial +import inspect +import itertools + +from typing import Callable, Sequence + +from jax import core +from jax import linear_util as lu +from jax.config import config +from jax.core import ConcreteArray, raise_to_shaped +from jax.interpreters import ad +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import xla +from jax.tree_util import tree_flatten, tree_unflatten +from jax._src import ad_util +from jax._src import dtypes +from jax._src import source_info_util +from jax._src import util +from jax._src.lax import lax +from jax._src.traceback_util import api_boundary +from jax._src.util import safe_map, extend_name_stack, split_list +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import mhlo +import numpy as np + +from jax._src.lax.control_flow.common import ( + _abstractify, + _avals_short, + _check_tree_and_avals, + _initial_style_jaxprs_with_common_consts, + _make_closed_jaxpr, + _prune_zeros, + _typecheck_param, + allowed_effects, + ) + +_map, unsafe_map = safe_map, map + + +# For backward compatibility with a previous switch/cond calling convention, +# we allow a single (pytree) `operand` argument to be passed by keyword. We use +# a sentinel object as its default value to indicate when it is _not_ passed. +_no_operand_sentinel = object() + +@api_boundary +def switch(index, branches: Sequence[Callable], *operands, + operand=_no_operand_sentinel): + """Apply exactly one of ``branches`` given by ``index``. + + If ``index`` is out of bounds, it is clamped to within bounds. + + Has the semantics of the following Python:: + + def switch(index, branches, *operands): + index = clamp(0, index, len(branches) - 1) + return branches[index](*operands) + + Args: + index: Integer scalar type, indicating which branch function to apply. + branches: Sequence of functions (A -> B) to be applied based on ``index``. + operands: Operands (A) input to whichever branch is applied. + + Returns: + Value (B) of ``branch(*operands)`` for the branch that was selected based + on ``index``. + """ + if not all(callable(branch) for branch in branches): + raise TypeError("lax.switch: branches argument should be a sequence of callables.") + if operand is not _no_operand_sentinel: + if operands: + raise TypeError("if 'operand' keyword is passed then no positional " + f"operands can be passed, got operand={operand} " + f"and positional operands {operands}") + operands = (operand,) + del operand + + if len(np.shape(index)) != 0: + raise TypeError( + f"Branch index must be scalar, " + f"got {index} of shape {np.shape(index)}.") + + try: + index_dtype = dtypes.result_type(index) + except TypeError as err: + msg = f"Index type must be an integer, got {index}." + raise TypeError(msg) from err + + if index_dtype.kind not in 'iu': + raise TypeError( + f"Index type must be an integer, got {index} as {index_dtype}") + + branches = tuple(branches) + + if len(branches) == 0: + raise ValueError("Empty branch sequence") + elif len(branches) == 1: + return branches[0](*operands) + + index = lax.convert_element_type(index, np.int32) + lo = np.array(0, np.int32) + hi = np.array(len(branches) - 1, np.int32) + index = lax.clamp(lo, index, hi) + + if (config.jax_disable_jit and + isinstance(core.get_aval(index), ConcreteArray)): + return branches[int(index)](*operands) + + ops, ops_tree = tree_flatten(operands) + ops_avals = tuple(_map(_abstractify, ops)) + + jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( + branches, ops_tree, ops_avals, primitive_name='switch') + for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): + _check_tree_and_avals(f"branch 0 and {i + 1} outputs", + out_trees[0], jaxprs[0].out_avals, + out_tree, jaxpr.out_avals) + joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) + disallowed_effects = joined_effects - allowed_effects + if disallowed_effects: + raise NotImplementedError( + f'Effects not supported in `switch`: {disallowed_effects}') + + linear = (False,) * (len(consts) + len(ops)) + out = cond_p.bind( + index, *consts, *ops, branches=tuple(jaxprs), linear=linear) + return tree_unflatten(out_trees[0], out) + + +def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, + operand=_no_operand_sentinel, linear=None): + """Conditionally apply ``true_fun`` or ``false_fun``. + + ``cond()`` has equivalent semantics to this Python implementation:: + + def cond(pred, true_fun, false_fun, *operands): + if pred: + return true_fun(*operands) + else: + return false_fun(*operands) + + ``pred`` must be a scalar type. + + Args: + pred: Boolean scalar type, indicating which branch function to apply. + true_fun: Function (A -> B), to be applied if ``pred`` is True. + false_fun: Function (A -> B), to be applied if ``pred`` is False. + operands: Operands (A) input to either branch depending on ``pred``. The + type can be a scalar, array, or any pytree (nested Python tuple/list/dict) + thereof. + + Returns: + Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``, + depending on the value of ``pred``. The type can be a scalar, array, or any + pytree (nested Python tuple/list/dict) thereof. + """ + if not (callable(true_fun) and callable(false_fun)): + raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") + if operand is not _no_operand_sentinel: + if operands: + raise TypeError("if 'operand' keyword is passed then no positional " + f"operands can be passed, got operand={operand} " + f"and positional operands {operands}") + operands = (operand,) + del operand + + if isinstance(pred, Sequence) or np.ndim(pred) != 0: + raise TypeError( + f"Pred must be a scalar, got {pred} of " + + (f"type {type(pred)}" if isinstance(pred, Sequence) + else f"shape {np.shape(pred)}.")) + + try: + pred_dtype = dtypes.result_type(pred) + except TypeError as err: + msg = ("Pred type must be either boolean or number, got {}.") + raise TypeError(msg.format(pred)) from err + + if pred_dtype.kind != 'b': + if pred_dtype.kind in 'iuf': + pred = pred != 0 + else: + msg = ("Pred type must be either boolean or number, got {}.") + raise TypeError(msg.format(pred_dtype)) + + if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray): + if pred: + return true_fun(*operands) + else: + return false_fun(*operands) + + ops, ops_tree = tree_flatten(operands) + if linear is None: + linear_ops = [False] * len(ops) + else: + linear_ops, ops_tree2 = tree_flatten(linear) + if ops_tree != ops_tree2: + raise TypeError('linear tree and operand tree mismatch') + ops_avals = tuple(_map(_abstractify, ops)) + + jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( + (true_fun, false_fun), ops_tree, ops_avals, 'cond') + true_jaxpr, false_jaxpr = jaxprs + out_tree, false_out_tree = out_trees + + _check_tree_and_avals("true_fun and false_fun output", + out_tree, true_jaxpr.out_avals, + false_out_tree, false_jaxpr.out_avals) + joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects) + disallowed_effects = joined_effects - allowed_effects + if disallowed_effects: + raise NotImplementedError( + f'Effects not supported in `cond`: {disallowed_effects}') + + index = lax.convert_element_type(pred, np.int32) + + linear = [False] * len(consts) + linear_ops + out = cond_p.bind( + index, *consts, *ops, + branches=(false_jaxpr, true_jaxpr), linear=tuple(linear)) + return tree_unflatten(out_tree, out) + +@api_boundary +@functools.wraps(_cond) +def cond(*args, **kwargs): + # detect an attempt to call the former, deprecated cond + try: + ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs) + except TypeError: + pass + else: + assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch + _, _, maybe_true_fun, _, maybe_false_fun = ba.args + if callable(maybe_true_fun) and callable(maybe_false_fun): + return _cond_with_per_branch_args(*ba.args) + + return _cond(*args, **kwargs) + +def _cond_with_per_branch_args(pred, + true_operand, true_fun: Callable, + false_operand, false_fun: Callable): + """Conditionally apply ``true_fun`` or ``false_fun``. + + Has equivalent semantics to this Python implementation:: + + def cond(pred, true_operand, true_fun, false_operand, false_fun): + if pred: + return true_fun(true_operand) + else: + return false_fun(false_operand) + + Pred has to be a scalar type, collection types (list, tuple) are not supported + """ + if not (callable(true_fun) and callable(false_fun)): + raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") + return _cond(pred, + lambda op: true_fun(op[0]), + lambda op: false_fun(op[1]), + (true_operand, false_operand)) + +def _cond_abstract_eval(*args, branches, **kwargs): + joined_effects = core.join_effects(*(b.effects for b in branches)) + disallowed_effects = joined_effects - allowed_effects + if disallowed_effects: + raise NotImplementedError( + f'Effects not supported in `cond`: {disallowed_effects}') + joined_effects = core.join_effects(*(b.effects for b in branches)) + return _map(raise_to_shaped, branches[0].out_avals), joined_effects + +def _bcast_select(pred, on_true, on_false): + if np.ndim(pred) != np.ndim(on_true): + idx = list(range(np.ndim(pred))) + pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx) + return lax.select(pred, on_true, on_false) + +def _bcast_select_n(pred, *cases): + if np.ndim(pred) != np.ndim(cases[0]): + idx = list(range(np.ndim(pred))) + pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) + return lax.select_n(pred, *cases) + +def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear): + index, *ops = args + index_dim, *op_dims = dims + + if index_dim is not batching.not_mapped: + # Convert to a lax.select. While we could get away with not broadcasting + # some operands yet, because all outputs must be broadcast together anyway + # for the select we broadcast the input operands for simplicity and leave + # optimizations to XLA. + # TODO(mattjj,frostig): assumes branches are side-effect-free, revise! + index, *ops = ( + batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)) + + in_batched = [True] * len(branches[0].in_avals) + out_batched = [True] * len(branches[0].out_avals) + + branches_batched = [ + batching.batch_jaxpr( + jaxpr, axis_size, in_batched, out_batched, axis_name, main_type)[0] + for jaxpr in branches] + + branch_outs = [] + for i, jaxpr in enumerate(branches_batched): + # Perform a select on the inputs for safety of reverse-mode autodiff; see + # https://github.com/google/jax/issues/1052 + predicate = lax.eq(index, lax._const(index, i)) + ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops] + branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_)) + out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)] + return out, [0 if b else None for b in out_batched] + else: + ops_bat = [d is not batching.not_mapped for d in op_dims] + ops = [batching.moveaxis(x, d, 0) if b else x + for b, x, d in zip(ops_bat, ops, op_dims)] + + branches_out_bat = [ + batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, main_type)[1] + for jaxpr in branches] + out_bat = [any(bat) for bat in zip(*branches_out_bat)] + branches_batched = tuple( + batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, main_type)[0] + for jaxpr in branches) + + out_dims = [0 if b else batching.not_mapped for b in out_bat] + out = cond_p.bind( + index, *ops, branches=branches_batched, linear=linear) + return out, out_dims + +def _cond_jvp(primals, tangents, branches, linear): + nonzeros = [type(t) is not ad_util.Zero for t in tangents] + + index_nz, *ops_nz = nonzeros + assert index_nz is False + + branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1] + for jaxpr in branches] + out_nz = [any(nz) for nz in zip(*branches_out_nz)] + + branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0] + for jaxpr in branches) + + index, *ops = primals + _, *ops_dot = tangents + ops_dot = _prune_zeros(ops_dot) + + ops_lin = tuple(linear) + linear_jvp = ops_lin + (True,) * len(ops_dot) + out = cond_p.bind( + index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp) + out_primals, out_tangents = split_list(out, [len(out_nz)]) + out_tangents_iter = iter(out_tangents) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + for p, nz in zip(out_primals, out_nz)] + return out_primals, out_tangents + +def _cond_partial_eval(trace, *tracers, branches, linear): + in_unknowns = [t.pval[0] is not None for t in tracers] + index_uk, *ops_uk = in_unknowns + + if index_uk: + # When the branch index is unknown, we stage out the whole cond. + # TODO(mattjj): remove this path when old remat is removed + params = dict(branches=branches, linear=linear) + return trace.default_process_primitive(cond_p, tracers, params) + + branches_out_uks = [] + for branch_jaxpr in branches: + _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits( + branch_jaxpr, ops_uk, instantiate=False) + branches_out_uks.append(out_uks) + out_uks = [any(uks) for uks in zip(*branches_out_uks)] + + branches_known, branches_unknown, branch_res_avals = [], [], [] + for branch_jaxpr in branches: + branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \ + pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks) + branches_known.append(branch_jaxpr_known) + branches_unknown.append(branch_jaxpr_unknown) + branch_res_avals.append(res_avals) + + all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals) + num_res = len(all_res_avals) + + num_known_outs = len(out_uks) - sum(out_uks) + branches_known = _join_cond_outputs( + branches_known, all_res_avals, res_avals_per_branch, num_known_outs) + branches_unknown = _join_cond_pe_staged_jaxpr_inputs( + branches_unknown, all_res_avals, res_avals_per_branch) + assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals)) + for j in branches_known[1:]) + + in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()] + linear_known = [l for l, uk in zip(linear, ops_uk) if not uk] + out_consts_res = cond_p.bind(*in_consts, branches=branches_known, + linear=tuple(linear_known)) + out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res]) + + index_tracer = trace.instantiate_const(tracers[0]) + ops_tracers = [trace.instantiate_const(t) + for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk] + res_tracers = _map(trace.new_instantiated_const, res) + out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) + for aval in branches_unknown[0].out_avals] + linear_unknown = ([False] * num_res + + [l for l, uk in zip(linear, in_unknowns[1:]) if uk]) + params = dict(branches=branches_unknown, linear=tuple(linear_unknown)) + name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] + source = source_info_util.current().replace(name_stack=name_stack) + eqn = pe.new_eqn_recipe( + [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, + core.no_effects, source) + for t in out_tracers: t.recipe = eqn + return util.merge_lists(out_uks, out_consts, out_tracers) + +# When partially evaluating conditionals, each branch produces residuals +# depending on the computation carried out by the branch, and a corresponding +# staged jaxpr that accepts those residuals as its first few inputs. The +# residual-producing branches are staged as jaxprs and bound right away in a +# conditional. The residual-consuming jaxprs are assembled together in a jaxpr +# conditional. The following helper functions ensure that both collections of +# jaxprs (those evaluated and those staged) are valid for joint use under their +# respective conditionals. +# +# In particular, the residuals derived from each original branch may have +# distinct types. Because the branches of conditionals must have identical type +# signatures, we join residuals together across branches into a common format. + +# In order to set up a type signature that all branches can conform to, it would +# suffice to concatenate all branches' residuals. But concatenation can result +# in redundant inputs and outputs, and might lead to memory allocation that +# scales unnecessarily with the branch count. This function finds common +# residual types across branches for reuse, so as to avoid redundant +# allocation. It returns a list L of types (avals) representing the collection +# of residuals merged according to type, and, for each branch, a lookup table to +# match its residuals to their positions/types in L. Example input/output: +# +# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]] +# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]] +# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]] +def _merge_branch_residuals(branch_res_avals): + def enumerate_equal(xs): + counts = {v: itertools.count() for v in set(xs)} + return [(x, next(counts[x])) for x in xs] + branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals) + all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals)) + indices = {v: i for i, v in enumerate(all_tagged_avals)} + branch_indices = [ + [indices[aval] for aval in avals] for avals in branch_res_tagged_avals] + all_avals = [x for x, _ in all_tagged_avals] + return all_avals, branch_indices + +# This function augments branch outputs to agree with the merged residual +# format: each branch is made to return zero-filled values in the places of +# residual outputs that it does not populate. +def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr, + num_non_res_outputs): + def augment_jaxpr(jaxpr, res_indices): + @lu.wrap_init + def f_aug(*args): + outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args) + outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs]) + aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals) + aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals)) + return outs + list(aug_residuals) + + return _make_closed_jaxpr(f_aug, jaxpr.in_avals) + + return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) + +# This function augments branch inputs to agree with the merged residual format: +# each branch is made to accept all residuals, even though it will ignore those +# that it does not read. +def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals, + res_aval_indices_per_jaxpr): + newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_') + all_res_vars = _map(newvar, all_res_avals) + + def augment_jaxpr(jaxpr, res_indices): + num_res = len(res_indices) + res_vars = jaxpr.jaxpr.invars[:num_res] + non_res_vars = jaxpr.jaxpr.invars[num_res:] + + aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars))) + aug_invars = aug_res_vars + non_res_vars + jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars, + jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns, + jaxpr.jaxpr.effects) + jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts) + return jaxpr_aug + + return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) + +def _ordered_unique(xs): + d = collections.OrderedDict((x, None) for x in xs) + return list(d.keys()) + +def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes): + res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) + primal_avals = _map(raise_to_shaped, primal_avals) + + @lu.wrap_init + def transposed(*args): + res, cts_out = split_list(args, [num_res]) + primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] + cts_in = ad.backward_pass( + jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out) + _, cts_in = split_list(cts_in, [num_res]) + return _map(ad.instantiate_zeros_aval, primal_avals, cts_in) + + return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals) + +def _cond_transpose(reduce_axes, cts, *args, branches, linear): + index, *ops = args + in_avals = _map(raise_to_shaped, branches[0].in_avals) + num_res = len(ops) - sum(linear) + + branches_trans = tuple( + _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches) + lin_in_avals = [raise_to_shaped(a, weak_type=False) + for a, l in zip(in_avals, linear) if l] + assert all(core.typematch(out_aval, lin_in_aval) + for jaxpr in branches_trans + for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) + + res = ops[:num_res] + cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts) + linear_trans = (False,) * num_res + (True,) * len(cts) + + out = cond_p.bind( + index, *res, *cts, branches=branches_trans, linear=linear_trans) + assert all(_map(core.typecheck, lin_in_avals, out)) + + out_iter = iter(out) + out = [next(out_iter) if l else None for l in linear] + assert next(out_iter, None) is None + return [None] + out + +def _cond_typecheck(*avals, branches, linear): + tc = partial(_typecheck_param, 'cond') + tc(branches, 'branches', 'tuple of ClosedJaxpr', + type(branches) is tuple and + all(type(x) is core.ClosedJaxpr for x in branches)) + tc(linear, 'linear', 'tuple of bool', + type(linear) is tuple and all(type(x) is bool for x in linear)) + + if len(branches) == 0: + raise core.JaxprTypeError('cond requires at least one branch function') + if len(linear) + 1 != len(avals): + raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for ' + f'{len(avals) - 1} non-predicate operands') + + jaxpr0 = branches[0] + jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals) + jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals) + joined_effects = core.join_effects(*(b.effects for b in branches)) + disallowed_effects = joined_effects - allowed_effects + if disallowed_effects: + raise NotImplementedError( + f'Effects not supported in `cond`: {disallowed_effects}') + + for i, jaxpr in enumerate(branches[1:]): + if len(jaxpr0.in_avals) != len(jaxpr.in_avals): + raise core.JaxprTypeError( + f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, ' + f'branch {i+1} takes {len(jaxpr.in_avals)}') + if len(jaxpr0.out_avals) != len(jaxpr.out_avals): + raise core.JaxprTypeError( + f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, ' + f'branch {i+1} outputs {len(jaxpr.out_avals)}') + if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)): + raise core.JaxprTypeError( + f'cond branches 0 and {i+1} have mismatching input types: ' + f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}') + if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)): + raise core.JaxprTypeError( + f'cond branches 0 and {i+1} have mismatching output types: ' + f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}') + + if len(avals) != 1 + len(jaxpr0.in_avals): + raise core.JaxprTypeError( + f'cond called with {len(avals) - 1} non-predicate operands, ' + f'but branches take {len(jaxpr0.in_avals)} inputs') + + index_aval, *op_avals = avals + if index_aval.dtype != np.int32: + raise core.JaxprTypeError( + f'cond called with index of type {index_aval.dtype} instead of int32') + if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)): + raise core.JaxprTypeError( + f'cond branches take input types {jaxpr0_in_avals_str}, ' + f'called with operands of type {_avals_short(op_avals)}') + if any(b.effects != branches[0].effects for b in branches[1:]): + raise core.JaxprTypeError( + f'cond branches must have matching effect types: ' + f'{[b.effects for b in branches]}') + joined_effects = core.join_effects(*(b.effects for b in branches)) + return None, joined_effects + +def cond_bind(*args, branches, linear): + if config.jax_enable_checks: + avals = _map(core.get_aval, args) + _cond_typecheck(*avals, branches=branches, linear=linear) + for jaxpr in branches: + core.check_jaxpr(jaxpr.jaxpr) + return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear) + +cond_p = core.AxisPrimitive('cond') +cond_p.multiple_results = True +cond_p.def_impl(partial(xla.apply_primitive, cond_p)) +cond_p.def_effectful_abstract_eval(_cond_abstract_eval) +cond_p.def_custom_bind(cond_bind) +ad.primitive_jvps[cond_p] = _cond_jvp +ad.reducing_transposes[cond_p] = _cond_transpose +pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval +batching.axis_primitive_batchers[cond_p] = _cond_batching_rule +xla.register_initial_style_primitive(cond_p) +core.custom_typechecks[cond_p] = _cond_typecheck +pe.partial_eval_jaxpr_custom_rules[cond_p] = \ + partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond') + +def _cond_lowering(ctx, index, *args, branches, linear): + del linear # Unused. + joined_effects = core.join_effects(*(branch.effects for branch in branches)) + ordered_effects = [eff for eff in joined_effects + if eff in core.ordered_effects] + num_tokens = len(ordered_effects) + tokens_in = ctx.tokens_in.subset(ordered_effects) + output_token_types = [mlir.token_type() for _ in ordered_effects] + output_types = [ + *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)] + flat_output_types = util.flatten(output_types) + + # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks + # have no arguments; the computation within the block uses implicit + # captures. + case_op = mhlo.CaseOp(flat_output_types, index=index, + num_branches=len(branches)) + name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond') + for i, jaxpr in enumerate(branches): + branch = case_op.regions[i].blocks.append() + with ir.InsertionPoint(branch): + sub_ctx = ctx.module_context.replace( + name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun')) + out_vals, tokens_out = mlir.jaxpr_subcomp( + sub_ctx, jaxpr.jaxpr, tokens_in, + _map(mlir.ir_constants, jaxpr.consts), + *_map(mlir.wrap_singleton_ir_values, args)) + out_tokens = [tokens_out.get(eff) for eff in ordered_effects] + out_vals = [*out_tokens, *out_vals] + mhlo.ReturnOp(util.flatten(out_vals)) + + tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types)) + tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens]) + ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens))) + return outputs + +mlir.register_lowering(cond_p, _cond_lowering) diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow/loops.py similarity index 59% rename from jax/_src/lax/control_flow.py rename to jax/_src/lax/control_flow/loops.py index 0ac94ccc7..4635f3e91 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1,4 +1,4 @@ -# Copyright 2019 Google LLC +# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,228 +11,896 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Control flow primitives. -""" - - -import collections -import functools +"""Module for the loop primitives.""" from functools import partial -import inspect import itertools import operator -import os -from typing import Any, Callable, Optional, Sequence, Set, Tuple, TypeVar, List -import numpy as np +from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar import jax -from jax._src import api from jax import core +from jax import linear_util as lu +from jax.config import config +from jax.core import ConcreteArray, ShapedArray, raise_to_shaped +from jax.interpreters import ad +from jax.interpreters import batching +from jax.interpreters import masking +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import xla +from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, + tree_map) from jax._src import ad_checkpoint +from jax._src import ad_util +from jax._src import api from jax._src import dtypes from jax._src import source_info_util from jax._src import util from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lax import windowed_reductions -from jax import linear_util as lu -from jax.core import ConcreteArray, ShapedArray, raise_to_shaped -from jax._src.api_util import flatten_fun_nokwargs -from jax.interpreters import ad -from jax.interpreters import partial_eval as pe -from jax.interpreters import mlir -from jax.interpreters import xla -from jax.interpreters import batching -from jax.interpreters import masking from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import mhlo from jax._src.traceback_util import api_boundary -from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, - split_list, cache, extend_name_stack, wrap_name) -from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, - treedef_children, treedef_tuple, tree_map, - tree_leaves, tree_structure) -from jax._src import ad_util -from jax.config import config +from jax._src.util import ( + cache, + extend_name_stack, + safe_map, + safe_zip, + split_list, + unzip2, + ) +import numpy as np + +from jax._src.lax.control_flow.common import ( + _abstractify, + _avals_short, + _check_tree_and_avals, + _initial_style_jaxpr, + _make_closed_jaxpr, + _prune_zeros, + _typecheck_param, + allowed_effects, + ) _map = safe_map zip = safe_zip -_reduce = functools.reduce T = TypeVar('T') Array = Any BooleanNumeric = Any # A bool, or a Boolean array. -allowed_effects: Set[core.Effect] = set() +### Helper functions -@cache() -def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: Optional[str] = None): - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - debug = pe.debug_info(fun, in_tree, False, primitive_name or "") - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug) - return jaxpr, consts, out_tree() +def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): + """Promote weakly-typed in_vals to be compatible with out_avals. -@cache() -def _initial_style_jaxpr(fun: Callable, in_tree, in_avals, - primitive_name: Optional[str] = None): - jaxpr, consts, out_tree = _initial_style_open_jaxpr( - fun, in_tree, in_avals, primitive_name) - closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) - return closed_jaxpr, consts, out_tree - -@cache() -def _initial_style_jaxprs_with_common_consts( - funs: Sequence[Callable], in_tree, in_avals, primitive_name: str): - # When staging the branches of a conditional into jaxprs, constants are - # extracted from each branch and converted to jaxpr arguments. To use the - # staged jaxprs as the branches to a conditional *primitive*, we need for - # their (input) signatures to match. This function "joins" the staged jaxprs: - # for each one, it makes another that accepts *all* constants, but only uses - # those that it needs (dropping the rest). - - jaxprs, all_consts, all_out_trees = \ - unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name) - for fun in funs) - - newvar = core.gensym(jaxprs, suffix='_') - all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts] - for consts in all_consts] - unused_const_vars = [[newvar(aval) for aval in const_avals] - for const_avals in all_const_avals] - - def pad_jaxpr_constvars(i, jaxpr): - prefix = util.concatenate(unused_const_vars[:i]) - suffix = util.concatenate(unused_const_vars[i + 1:]) - constvars = [*prefix, *jaxpr.constvars, *suffix] - return jaxpr.replace(constvars=constvars) - - consts = util.concatenate(all_consts) - jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)] - closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) - for jaxpr in jaxprs] - return closed_jaxprs, consts, all_out_trees - -def _abstractify(x): - return raise_to_shaped(core.get_aval(x)) - -def _typecheck_param(prim, param, name, msg_required, pred): - if not pred: - msg = (f'invalid {prim} param {name} of type {type(param).__name__}, ' - f'{msg_required} required:') - param_str = str(param) - sep = os.linesep if os.linesep in param_str else ' ' - msg = sep.join([msg, param_str]) - raise core.JaxprTypeError(msg) + Args: + in_vals : flattened list of input values. + in_avals : corresponding list of avals. + out_avals : list of target output avals. + Returns: + in_vals_new : flattened list of modified in_vals with no weak types. + changed : bool; true if in_vals required modification. + """ + if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals): + # Calling function is responsible for catching this. + return in_vals, False + weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals)) + if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)] + if not weak_mismatches: + return in_vals, False + for i in weak_mismatches: + new_dtype = dtypes.result_type(in_vals[i], out_avals[i]) + in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype) + return in_vals, True -### fori_loop and while_loop +### scan -def _fori_cond_fun(loop_carry): - i, upper, _ = loop_carry - return lax.lt(i, upper) - -@cache() -def _fori_body_fun(body_fun): - def while_body_fun(loop_carry): - i, upper, x = loop_carry - return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x) - return while_body_fun - -@cache() -def _fori_scan_body_fun(body_fun): - def scanned_fun(loop_carry, _): - i, x = loop_carry - return (i + 1, body_fun(i, x)), None - return scanned_fun +Carry = TypeVar('Carry') +X = TypeVar('X') +Y = TypeVar('Y') @api_boundary -def fori_loop(lower, upper, body_fun, init_val): - """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`. +def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1) -> Tuple[Carry, Y]: + """Scan a function over leading array axes while carrying along state. The `Haskell-like type signature`_ in brief is .. code-block:: haskell - fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a + scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b]) - The semantics of ``fori_loop`` are given by this Python implementation:: + where we use [t] here to denote the type t with an additional leading axis. + That is, if t is an array type then [t] represents the type with an additional + leading axis, and if t is a pytree (container) type with array leaves then [t] + represents the type with the same pytree structure and corresponding leaves + each with an additional leading axis. - def fori_loop(lower, upper, body_fun, init_val): - val = init_val - for i in range(lower, upper): - val = body_fun(i, val) - return val + When ``a`` is an array type or None, and ``b`` is an array type, the semantics + of ``scan`` are given roughly by this Python implementation:: - Unlike that Python version, ``fori_loop`` is implemented in terms of either a - call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the - trip count is static (meaning known at tracing time, perhaps because ``lower`` - and ``upper`` are Python integer literals) then the ``fori_loop`` is - implemented in terms of ``scan`` and reverse-mode autodiff is supported; - otherwise, a ``while_loop`` is used and reverse-mode autodiff is not - supported. See those functions' docstrings for more information. + def scan(f, init, xs, length=None): + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, np.stack(ys) - Also unlike the Python analogue, the loop-carried value ``val`` must hold a - fixed shape and dtype across all iterations (and not just be consistent up to - NumPy rank/shape broadcasting and dtype promotion rules, for example). In - other words, the type ``a`` in the type signature above represents an array - with a fixed shape and dtype (or a nested tuple/list/dict container data - structure with a fixed structure and arrays with fixed shape and dtype at the - leaves). + Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree + types, and so multiple arrays can be scanned over at once and produce multiple + output arrays. (None is actually an empty pytree.) + + Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to + a single XLA While HLO. That makes it useful for reducing compilation times + for jit-compiled functions, since native Python loop constructs in an ``@jit`` + function are unrolled, leading to large XLA computations. + + Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype + across all iterations (and not just be consistent up to NumPy rank/shape + broadcasting and dtype promotion rules, for example). In other words, the type + ``c`` in the type signature above represents an array with a fixed shape and + dtype (or a nested tuple/list/dict container data structure with a fixed + structure and arrays with fixed shape and dtype at the leaves). .. note:: - :py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with + :py:func:`scan` compiles ``f``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. Args: - lower: an integer representing the loop index lower bound (inclusive) - upper: an integer representing the loop index upper bound (exclusive) - body_fun: function of type ``(int, a) -> a``. - init_val: initial loop carry value of type ``a``. + f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning + that ``f`` accepts two arguments where the first is a value of the loop + carry and the second is a slice of ``xs`` along its leading axis, and that + ``f`` returns a pair where the first element represents a new value for + the loop carry and the second represents a slice of the output. + init: an initial loop carry value of type ``c``, which can be a scalar, + array, or any pytree (nested Python tuple/list/dict) thereof, representing + the initial loop carry value. This value must have the same structure as + the first element of the pair returned by ``f``. + xs: the value of type ``[a]`` over which to scan along the leading axis, + where ``[a]`` can be an array or any pytree (nested Python + tuple/list/dict) thereof with consistent leading axis sizes. + length: optional integer specifying the number of loop iterations, which + must agree with the sizes of leading axes of the arrays in ``xs`` (but can + be used to perform scans where no input ``xs`` are needed). + reverse: optional boolean specifying whether to run the scan iteration + forward (the default) or in reverse, equivalent to reversing the leading + axes of the arrays in both ``xs`` and in ``ys``. + unroll: optional positive int specifying, in the underlying operation of the + scan primitive, how many scan iterations to unroll within a single + iteration of a loop. Returns: - Loop value from the final iteration, of type ``a``. + A pair of type ``(c, [b])`` where the first element represents the final + loop carry value and the second element represents the stacked outputs of + the second output of ``f`` when scanned over the leading axis of the inputs. .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature """ - if not callable(body_fun): - raise TypeError("lax.fori_loop: body_fun argument should be callable.") - # TODO(phawkins): perhaps do more type checking here, better error messages. - lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower)) - upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper)) - if lower_dtype != upper_dtype: - msg = ("lower and upper arguments to fori_loop must have equal types, " - "got {} and {}") - raise TypeError(msg.format(lower_dtype.name, upper_dtype.name)) + if not callable(f): + raise TypeError("lax.scan: f argument should be a callable.") + xs_flat, xs_tree = tree_flatten(xs) - # If we can specialize on the trip count, call scan instead of a while_loop - # to enable efficient reverse-mode differentiation. - if (isinstance(core.get_aval(lower), ConcreteArray) and - isinstance(core.get_aval(upper), ConcreteArray)): - try: - lower_ = int(lower) - upper_ = int(upper) - except TypeError: - use_scan = False + try: + lengths = [x.shape[0] for x in xs_flat] + except AttributeError as err: + msg = "scan got value with no leading axis to scan over: {}." + raise ValueError( + msg.format(', '.join(str(x) for x in xs_flat + if not hasattr(x, 'shape')))) from err + + if length is not None: + length = int(length) + if not all(length == l for l in lengths): + msg = ("scan got `length` argument of {} which disagrees with " + "leading axis sizes {}.") + raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) + else: + unique_lengths = set(lengths) + if len(unique_lengths) > 1: + msg = "scan got values with different leading axis sizes: {}." + raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) + elif len(unique_lengths) == 0: + msg = "scan got no values to scan over and `length` not provided." + raise ValueError(msg) else: - use_scan = True + length, = unique_lengths + + if config.jax_disable_jit: + if length == 0: + raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.") + carry = init + ys = [] + maybe_reversed = reversed if reverse else lambda x: x + for i in maybe_reversed(range(length)): + xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat] + carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) + ys.append(y) + stack = lambda *ys: jax.numpy.stack(ys) + stacked_y = tree_map(stack, *maybe_reversed(ys)) + return carry, stacked_y + + x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat] + x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat] + x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes)) + + def _create_jaxpr(init): + init_flat, init_tree = tree_flatten(init) + in_flat, in_tree = tree_flatten((init, xs)) + + carry_avals = tuple(_map(_abstractify, init_flat)) + jaxpr, consts, out_tree = _initial_style_jaxpr( + f, in_tree, carry_avals + x_avals, "scan") + out_tree_children = out_tree.children() + if len(out_tree_children) != 2: + msg = "scan body output must be a pair, got {}." + raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) + carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves] + return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children + + # The carry input and output avals must match exactly. However, we want to account for + # the case when init contains weakly-typed values (e.g. Python scalars), with avals that + # may not match the output despite being compatible by virtue of their weak type. + # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if + # necessary, a second time with modified init values. + init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) + new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out) + if changed: + new_init = tree_unflatten(init_tree, new_init_flat) + init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init) + in_flat, jaxpr, consts, out_tree, out_tree_children = rest + + _check_tree_and_avals("scan carry output and input", + # Extract the subtree and avals for the first element of the return tuple + out_tree_children[0], carry_avals_out, + init_tree, carry_avals) + disallowed_effects = jaxpr.effects - allowed_effects + if disallowed_effects: + raise NotImplementedError( + f'Effects not supported in `scan`: {disallowed_effects}') + + out = scan_p.bind(*consts, *in_flat, + reverse=reverse, length=length, jaxpr=jaxpr, + num_consts=len(consts), num_carry=len(init_flat), + linear=(False,) * (len(consts) + len(in_flat)), + unroll=unroll) + return tree_unflatten(out_tree, out) + +def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear, + f_impl, x_avals, y_avals): + consts, init, xs = split_list(args, [num_consts, num_carry]) + + carry = init + ys = [] + + for i in range(length): + i_ = length - i - 1 if reverse else i + x = _map(partial(_index_array, i_), x_avals, xs) + out = f_impl(*consts, *carry, *x) + carry, y = split_list(out, [num_carry]) + ys.append(y) + + ys = list(reversed(ys)) if reverse else ys + ys = list(zip(*ys)) + ys = _map(_stack, y_avals, ys) + return (*carry, *ys) + +def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear, + f_impl, x_avals, y_avals): + consts, init, xs = split_list(args, [num_consts, num_carry]) + + def cond_fun(vals): + i, *_ = vals + return i < length + + def body_fun(vals): + [i], carry, ys = split_list(vals, [1, num_carry]) + i_ = length - i - 1 if reverse else i + x = _map(partial(_dynamic_index_array, i_), x_avals, xs) + out_flat = f_impl(*consts, *carry, *x) + carry_out, y_updates = split_list(out_flat, [num_carry]) + ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates) + return [i + 1] + carry_out + ys_out + + ys_init = _map(partial(_empty_array, length), y_avals) + if length == 0: + return init + ys_init else: - use_scan = False + init_val = [lax._const(length, 0)] + init + ys_init + _, *outs = while_loop(cond_fun, body_fun, init_val) + return outs - if use_scan: - if config.jax_disable_jit and upper_ == lower_: - # non-jit implementation of scan does not support length=0 - return init_val +def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry, + linear, block_length, f_impl, x_avals, y_avals): + consts, init, xs = split_list(args, [num_consts, num_carry]) - (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val), - None, length=upper_ - lower_) + num_blocks, rem = divmod(length, block_length) + assert rem == 0 + + partition = partial(_partition_leading, num_blocks, block_length) + xs_block = _map(partition, x_avals, xs) + + prepend_aval = partial(_prepend_dim_to_aval, block_length) + x_block_avals = _map(prepend_aval, x_avals) + y_block_avals = _map(prepend_aval, y_avals) + + f_impl_block = partial( + _scan_impl_unrolled, reverse=reverse, length=block_length, + num_consts=num_consts, num_carry=num_carry, linear=linear, + f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) + + outs = _scan_impl_loop( + *consts, *init, *xs_block, reverse=reverse, length=num_blocks, + num_consts=num_consts, num_carry=num_carry, linear=linear, + f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals) + + carry, ys_blocks = split_list(outs, [num_carry]) + combine = partial(_combine_leading, num_blocks, block_length) + ys = _map(combine, y_avals, ys_blocks) + return (*carry, *ys) + +def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, + unroll): + _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) + _, y_avals = split_list(jaxpr.out_avals, [num_carry]) + f_impl = core.jaxpr_as_fun(jaxpr) + + if unroll == 1: + return _scan_impl_loop( + *args, reverse=reverse, length=length, num_consts=num_consts, + num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, + y_avals=y_avals) + + consts, init, xs = split_list(args, [num_consts, num_carry]) + num_blocks, rem = divmod(length, unroll) + length_div = num_blocks * unroll + + if rem > 0: + if reverse: + split = partial(_split_leading_dim, rem) + xs_rem, xs = unzip2(_map(split, x_avals, xs)) + else: + split = partial(_split_leading_dim, length_div) + xs, xs_rem = unzip2(_map(split, x_avals, xs)) + + outs = _scan_impl_block_unrolled( + *consts, *init, *xs, reverse=reverse, length=length_div, + num_consts=num_consts, num_carry=num_carry, linear=linear, + block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) + + carry, ys = split_list(outs, [num_carry]) + + if rem > 0: + outs = _scan_impl_unrolled( + *consts, *carry, *xs_rem, reverse=reverse, length=rem, + num_consts=num_consts, num_carry=num_carry, linear=linear, + f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) + carry, ys_rem = split_list(outs, [num_carry]) + if reverse: + ys = _map(_concatenate, y_avals, ys_rem, ys) + else: + ys = _map(_concatenate, y_avals, ys, ys_rem) + + return (*carry, *ys) + +def _stack(aval, vals): + vals = [lax.expand_dims(x, (0,)) for x in vals] + return lax.concatenate(vals, 0) + +def _concatenate(aval, x1, x2): + return lax.concatenate([x1, x2], 0) + +def _split_leading_dim(i, aval, x): + assert x.ndim >= 1 + return (slicing.slice_in_dim(x, 0, i), + slicing.slice_in_dim(x, i, x.shape[0])) + +def _dynamic_index_array(i, aval, x): + return slicing.dynamic_index_in_dim(x, i, keepdims=False) + +def _index_array(i, aval, x): + return slicing.index_in_dim(x, i, keepdims=False) + +def _empty_array(sz, aval): + return lax.full((sz,) + aval.shape, 0, aval.dtype) + +def _update_array(i, aval, xs, x): + return slicing.dynamic_update_index_in_dim(xs, x, i, 0) + +def _partition_leading(sz0, sz1, aval, x): + assert x.ndim >= 1 + assert x.shape[0] == sz0 * sz1 + return lax.reshape(x, (sz0, sz1, *x.shape[1:])) + +def _combine_leading(sz0, sz1, aval, x): + assert x.ndim >= 2 + assert x.shape[0] == sz0 + assert x.shape[1] == sz1 + return lax.collapse(x, 0, 2) + +def _prepend_dim_to_aval(sz, aval): + return core.unmapped_aval(sz, core.no_axis_name, 0, aval) + +def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, + linear, unroll): + carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) + return carry_avals + ys_avals, jaxpr.effects + +def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, + linear, unroll): + num_xs = len(jaxpr.in_avals) - num_carry - num_consts + num_ys = len(jaxpr.out_avals) - num_carry + nonzeros = [type(t) is not ad_util.Zero for t in tangents] + const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry]) + + # Fixpoint computation of which carry are not ad.zero: either + # non-zero from init, or the carry out is non-zero. Each iteration promotes + # at least one carry to non-zero. We need at most len(carry) iterations, + # but we need one last iteration to prepare the jaxpr based on the final + # carry_nz. + carry_nz = init_nz + for _ in range(1 + len(carry_nz)): + nonzeros = const_nz + carry_nz + xs_nz + jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr( + jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys) + carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:] + if carry_nz_out == carry_nz: + break + else: + carry_nz = _map(operator.or_, carry_nz, carry_nz_out) else: - _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun), - (lower, upper, init_val)) - return result + assert False, "Fixpoint not reached" + tangents = [ad.instantiate_zeros(t) if nz else t + for t, nz in zip(tangents, nonzeros)] + + consts, init, xs = split_list(primals, [num_consts, num_carry]) + all_tangents = split_list(tangents, [num_consts, num_carry]) + consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents) + + jaxpr_jvp_rearranged = ad.rearrange_binders( + jaxpr_jvp, + [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)], + [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)]) + + consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) + jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot) + + init_linear + [True] * len(init_dot) + + xs_linear + [True] * len(xs_dot)) + + out_flat = scan_p.bind( + *(consts + consts_dot + init + init_dot + xs + xs_dot), + reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged, + num_consts=num_consts + len(consts_dot), + num_carry=num_carry + len(init_dot), + linear=jaxpr_jvp_linear, unroll=unroll) + + carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) + primals_out = carry + ys + tangents_out_iter = iter(carry_dot + ys_dot) + tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) + for p, nz in zip(primals_out, nonzeros_out)] + return primals_out, tangents_out + +def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, + jaxpr, linear, unroll): + num_ys = len(jaxpr.out_avals) - num_carry + unknowns = [not t.pval.is_known() for t in tracers] + const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) + + # Fixpoint computation of which carry elements are unknown. Each iteration + # promotes at least one carry to unknown. We need at most len(carry) + # iterations, but we need one last iteration to prepare the jaxpr based on the + # final carry_uk. + carry_uk = init_uk + for _ in range(1 + len(carry_uk)): + unknowns = const_uk + carry_uk + xs_uk + jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits( + jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) + carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) + if carry_uk_out == carry_uk: + break + else: + carry_uk = _map(operator.or_, carry_uk, carry_uk_out) + else: + assert False, "Fixpoint not reached" + num_res = len(res_avals) + del res_avals, carry_uk_out + + # Instantiate those inputs which must be treated as unknown from the fixpoint. + tracers = [trace.instantiate_const(t) if uk else t + for t, uk in zip(tracers, unknowns)] + + # The residual inputs and outputs of the jaxprs produced haven't yet been + # adapted to the scan calling convention; in particular, jaxpr_known has its + # residual outputs all at the end, meaning they're extensive outputs (which is + # fully general but may be wasteful for residuals which are loop-invariant) + # while jaxpr_unknown has its corresponding residual inputs at the front (just + # as a convention with partial_eval_jaxpr_nounits), making them constant + # inputs. To make them consistent, we move the residual inputs on + # jaxpr_unknown to the end, even though we may move some back in the sequel. + jaxpr_unknown = pe.move_binders_to_back( + jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) + + # At this point, all residuals are treated as extensive outputs of jaxpr_known + # (and extensive inputs to jaxpr_unknown). But residuals that are loop- + # invariant can be hoisted out of the scan, rather than letting them get + # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't + # want to broadcast the matrix!). So, outside the loop we perform a partial + # evaluation with known 'const' inputs (but all other inputs unknown). + const_pvals = [pe.PartialVal.known(t.pval.get_known()) + for t in tracers[:num_consts] if t.pval.is_known()] + other_pvals = [pe.PartialVal.unknown(aval) + for aval in jaxpr_known.in_avals[len(const_pvals):]] + with source_info_util.reset_name_stack(): + jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( + lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals, + instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) + jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) + # The above trace_to_jaxpr_nounits call computed loop-invariant residuals + # (known values in invar_pvals_out) and also computed loop-invariant values + # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the + # previous consts). We need to collect the computed inteisive residuals, and + # move corresponding intensive residual binders in jaxpr_unknown to the front. + res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] + intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] + jaxpr_unknown = pe.move_binders_to_front( + jaxpr_unknown, + [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) + del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals + # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and + # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. + + # As another optimization, for any extensive inputs that are just forwarded to + # extensive outputs, to avoid a copy (which would be looping over + # dynamic-update-slice) we'd rather forward the input tracer/value. That means + # pruning some outputs from jaxpr_known here, and updating `out_flat` below. + fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) + # Prune fwds_known to include only extensive input to extensive output. + fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and + in_idx is not None and + in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) + else None for out_idx, in_idx in enumerate(fwds_known)] + # Drop any extensive output we can instead get by forwarding an input. + # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint + jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts + jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known) + if i is None] + jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) + del jaxpr_known_ + # We use `fwds_known` below when forming the output of scanning jaxpr_known. + + # Run the known part of the scan (if it has any outputs or effects). + known_inputs = (list(jaxpr_known_consts) + + [t.pval.get_known() for t in tracers[num_consts:] + if t.pval.is_known()]) + if not jaxpr_known.out_avals and not jaxpr_known.effects: + out_known = [] + else: + linear_known = [False] * len(known_inputs) # conservative! + out_known = scan_p.bind( + *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, + num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), + linear=tuple(linear_known), unroll=unroll) + del linear_known + # Complete the known output by filling in forwarded values using fwds_known. + out_known_iter = iter(out_known) + out_known = [next(out_known_iter) if f is None + else _maybe_put(known_inputs[f]) for f in fwds_known] + assert next(out_known_iter, None) is None + del known_inputs, out_known_iter + + # Split known outputs from residuals. + out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) + assert len(intensive_res) + len(extensive_res) == num_res + + # Create input tracers for jaxpr_unknown bind. + unknown_inputs = [t for t in tracers if not t.pval.is_known()] + intensive_res = _map(trace.new_instantiated_const, intensive_res) + extensive_res = _map(trace.new_instantiated_const, extensive_res) + # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. + carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) + ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval) + for y_aval in y_avals] + out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) + for a in itertools.chain(carry_avals, ys_avals)] + del carry_avals, y_avals + # Create equation. + linear_unknown = tuple([False] * len(intensive_res) + + [l for l, uk in zip(linear, unknowns) if uk] + + [False] * len(extensive_res)) + name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] + source = source_info_util.current().replace(name_stack=name_stack) + assert len(out_tracers) == len(jaxpr_unknown.out_avals) + eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res], + out_tracers, scan_p, + dict(reverse=reverse, length=length, unroll=unroll, + jaxpr=jaxpr_unknown, linear=linear_unknown, + num_consts=len(intensive_res) + sum(const_uk), + num_carry=sum(carry_uk)), + jaxpr_unknown.effects, source) + for t in out_tracers: t.recipe = eqn + + # Merge known and unknown outputs into final result. + return util.merge_lists(out_uk, out_known, out_tracers) + +def _maybe_put(x): + if isinstance(x, np.ndarray): + return jax.device_put(x, jax.devices('cpu')[0]) + else: + return x + +def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts, + num_carry, jaxpr, linear, unroll): + # we've only implemented transposing scans with specific lin/nonlin patterns + consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) + num_ires = len(consts_lin) - sum(consts_lin) + num_eres = len(xs_lin) - sum(xs_lin) + if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires): + raise NotImplementedError + if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: + raise NotImplementedError + if not all(init_lin): + pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 + + consts, _, xs = split_list(args, [num_consts, num_carry]) + ires, _ = split_list(consts, [num_ires]) + _, eres = split_list(xs, [sum(xs_lin)]) + assert not any(ad.is_undefined_primal(r) for r in ires) + assert not any(ad.is_undefined_primal(r) for r in eres) + + carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) + ct_carry, ct_ys = split_list(cts, [num_carry]) + ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry) + ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys) + ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) + + # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) + # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) + jaxpr_trans = _transpose_scan_jaxpr( + num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes) + linear_trans = ([False] * num_ires + + [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + + [False] * num_eres) + + outs = scan_p.bind( + *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse, + length=length, jaxpr=jaxpr_trans, num_consts=num_ires, + num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans), + unroll=unroll) + ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) + return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres + +# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) +# -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) +def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes): + num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 + # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals + # if an axis isn't reduced + res1_avals, c_avals, a_avals, res2_avals = split_list( + jaxpr.in_avals, [num_res1, num_c, num_a]) + num_b = len(jaxpr.out_avals) + b_avals = list(jaxpr.out_avals) + + @lu.wrap_init + def transposed(*res1_cbar_bbar_res2): + res1, c_bar, b_bar, res2 = split_list( + res1_cbar_bbar_res2, [num_res1, num_c, num_b]) + primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + + [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) + cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, + primals, b_bar) + _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) + a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) + c_bar = _map(ad.instantiate_zeros_aval, c_avals, + _map(ad.add_tangents, c_bar, new_c_bar)) + return c_bar + a_bar + return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals) + + +def _scan_batching_rule(axis_size, axis_name, main_type, args, dims, reverse, length, + jaxpr, num_consts, num_carry, linear, unroll): + num_ys = len(jaxpr.out_avals) - num_carry + orig_batched = [d is not batching.not_mapped for d in dims] + const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry]) + + # Fixpoint computation of which carry are batched: either + # batched from init, or the carry out is batched. Each iteration promotes + # at least one carry to batched. We need at most len(carry) iterations, + # but we need one last iteration to prepare the jaxpr based on the final + # carry_batched. + carry_batched = init_batched + for _ in range(1 + len(carry_batched)): + batched = const_batched + carry_batched + xs_batched + jaxpr_batched, batched_out = batching.batch_jaxpr( + jaxpr, axis_size, batched, + instantiate=carry_batched + [False] * num_ys, + axis_name=axis_name, + main_type=main_type) + carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] + if carry_batched_out == carry_batched: + break + else: + carry_batched = _map(operator.or_, carry_batched, carry_batched_out) + else: + assert False, "Fixpoint not reached" + + consts, init, xs = split_list(args, [num_consts, num_carry]) + consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) + new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 + else x for x, d in zip(consts, consts_bdims)] + new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched + else batching.moveaxis(x, d, 0) if now_batched else x + for x, d, was_batched, now_batched in + zip(init, init_bdims, init_batched, carry_batched)] + new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1 + else x for x, d in zip(xs, xs_bdims)] + new_args = new_consts + new_init + new_xs + + outs = scan_p.bind( + *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched, + num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll) + carry_bdims = [0 if b else batching.not_mapped for b in carry_batched] + ys_bdims = [1 if b else batching.not_mapped for b in ys_batched] + return outs, carry_bdims + ys_bdims + +def _scan_masking_rule(padded_vals, logical_shapes, reverse, length, + jaxpr, num_consts, num_carry, linear, unroll): + dynamic_length, = masking.shape_as_value((length,)) + masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry) + consts, init, xs = split_list(padded_vals, [num_consts, num_carry]) + max_length, = {x.shape[0] for x in xs} + const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) + dynamic_length = lax.convert_element_type(dynamic_length, dtypes.int_) + out_vals = scan_p.bind(dynamic_length, *consts, dtypes.int_(0), *init, *xs, + reverse=reverse, length=max_length, jaxpr=masked_jaxpr, + num_consts=1 + num_consts, num_carry=1 + num_carry, + linear=tuple([False] + const_linear + [False] + init_linear + xs_linear), + unroll=unroll) + return out_vals[1:] + +def _masked_scan_jaxpr(jaxpr, num_consts, num_carry): + fun = core.jaxpr_as_fun(jaxpr) + + @lu.wrap_init + def masked(*args): + [dynamic_length], consts, [i], carry, xs = split_list( + args, [1, num_consts, 1, num_carry]) + out = fun(*(consts + carry + xs)) + new_carry, ys = split_list(out, [num_carry]) + new_carry = [lax.select(i < dynamic_length, new_c, c) + for new_c, c in zip(new_carry, carry)] + return [i + 1] + new_carry + ys + + aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_)) + const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) + return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals) + +def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): + padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)) + return scan_p.bind(*args, jaxpr=padded_jaxpr, **params) + +def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn + ) -> Tuple[List[bool], core.JaxprEqn]: + if not config.after_neurips: + return [True] * len(eqn.params['jaxpr'].in_avals), eqn + jaxpr = eqn.params['jaxpr'] + num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] + num_xs = len(jaxpr.in_avals) - num_consts - num_carry + used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry]) + for i in range(1 + num_carry): + used_outputs = used_carry_out + used_extensive_out + jaxpr_dce, used_inputs = pe.dce_jaxpr( + jaxpr.jaxpr, used_outputs, + instantiate=[False] * num_consts + used_carry_out + [False] * num_xs) + used_consts, used_carry_in, used_extensive_in = \ + split_list(used_inputs, [num_consts, num_carry]) + if list(used_carry_in) == list(used_carry_out): + break + else: + used_carry_out = _map(operator.or_, used_carry_out, used_carry_in) + else: + assert False, "Fixpoint not reached" + core.check_jaxpr(jaxpr.jaxpr) + + new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u] + new_params = dict(eqn.params, num_consts=sum(used_consts), + num_carry=sum(used_carry_in), linear=tuple(new_linear), + jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts)) + new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) + if used], + [v for v, used in zip(eqn.outvars, used_outputs) + if used], + eqn.primitive, new_params, eqn.effects, + eqn.source_info) + assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals ) + assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals) + return used_inputs, new_eqn + +def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry, + jaxpr, linear, unroll): + tc = partial(_typecheck_param, 'scan') + tc(reverse, 'reverse', 'bool', type(reverse) is bool) + tc(num_consts, 'num_consts', 'non-negative int', + type(num_consts) is int and num_consts >= 0) + tc(num_carry, 'num_carry', 'non-negative int', + type(num_carry) is int and num_carry >= 0) + tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr) + tc(linear, 'linear', 'tuple of bool', + type(linear) is tuple and all(type(x) is bool for x in linear)) + tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0) + + length_types = (int, masking.Poly) if bind_time else (int,) + tc(length, 'length', 'non-negative int', + type(length) in length_types and length >= 0) + + if len(linear) != len(avals): + raise core.JaxprTypeError( + f'scan param linear has length {len(linear)} for {len(avals)} operands') + + const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry]) + const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list( + jaxpr.in_avals, [num_consts, num_carry]) + carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry]) + x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) + + if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): + raise core.JaxprTypeError( + f'scan input carry input and output types mismatch: ' + f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}') + if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)): + raise core.JaxprTypeError( + f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n' + f'called with consts of type\n{_avals_short(const_avals)}') + if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)): + raise core.JaxprTypeError( + f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n' + f'called with initial carry of type\n{_avals_short(init_avals)}') + if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)): + raise core.JaxprTypeError( + f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n' + f'called with sequence of type\n{_avals_short(x_avals)}') + return None, jaxpr.effects + +def scan_bind(*args, **params): + if config.jax_enable_checks: + avals = _map(core.get_aval, args) + _scan_typecheck(True, *avals, **params) + core.check_jaxpr(params['jaxpr'].jaxpr) + return core.AxisPrimitive.bind(scan_p, *args, **params) + +scan_p = core.AxisPrimitive("scan") +scan_p.multiple_results = True +scan_p.def_custom_bind(scan_bind) +scan_p.def_impl(partial(xla.apply_primitive, scan_p)) +scan_p.def_effectful_abstract_eval(_scan_abstract_eval) +ad.primitive_jvps[scan_p] = _scan_jvp +ad.reducing_transposes[scan_p] = _scan_transpose +pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval +xla.register_initial_style_primitive(scan_p) +mlir.register_lowering(scan_p, + mlir.lower_fun(_scan_impl, multiple_results=True)) +batching.axis_primitive_batchers[scan_p] = _scan_batching_rule +masking.masking_rules[scan_p] = _scan_masking_rule +core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) +pe.partial_eval_jaxpr_custom_rules[scan_p] = \ + partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan') +pe.padding_rules[scan_p] = _scan_padding_rule +pe.dce_rules[scan_p] = _scan_dce_rule + +### while_loop @api_boundary def while_loop(cond_fun: Callable[[T], BooleanNumeric], @@ -718,1446 +1386,112 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, mlir.register_lowering(while_p, _while_lowering) +### fori_loop -### cond and switch +def _fori_cond_fun(loop_carry): + i, upper, _ = loop_carry + return lax.lt(i, upper) +@cache() +def _fori_body_fun(body_fun): + def while_body_fun(loop_carry): + i, upper, x = loop_carry + return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x) + return while_body_fun -# For backward compatibility with a previous switch/cond calling convention, -# we allow a single (pytree) `operand` argument to be passed by keyword. We use -# a sentinel object as its default value to indicate when it is _not_ passed. -_no_operand_sentinel = object() - +@cache() +def _fori_scan_body_fun(body_fun): + def scanned_fun(loop_carry, _): + i, x = loop_carry + return (i + 1, body_fun(i, x)), None + return scanned_fun @api_boundary -def switch(index, branches: Sequence[Callable], *operands, - operand=_no_operand_sentinel): - """Apply exactly one of ``branches`` given by ``index``. - - If ``index`` is out of bounds, it is clamped to within bounds. - - Has the semantics of the following Python:: - - def switch(index, branches, *operands): - index = clamp(0, index, len(branches) - 1) - return branches[index](*operands) - - Args: - index: Integer scalar type, indicating which branch function to apply. - branches: Sequence of functions (A -> B) to be applied based on ``index``. - operands: Operands (A) input to whichever branch is applied. - - Returns: - Value (B) of ``branch(*operands)`` for the branch that was selected based - on ``index``. - """ - if not all(callable(branch) for branch in branches): - raise TypeError("lax.switch: branches argument should be a sequence of callables.") - if operand is not _no_operand_sentinel: - if operands: - raise TypeError("if 'operand' keyword is passed then no positional " - f"operands can be passed, got operand={operand} " - f"and positional operands {operands}") - operands = (operand,) - del operand - - if len(np.shape(index)) != 0: - raise TypeError( - f"Branch index must be scalar, " - f"got {index} of shape {np.shape(index)}.") - - try: - index_dtype = dtypes.result_type(index) - except TypeError as err: - msg = f"Index type must be an integer, got {index}." - raise TypeError(msg) from err - - if index_dtype.kind not in 'iu': - raise TypeError( - f"Index type must be an integer, got {index} as {index_dtype}") - - branches = tuple(branches) - - if len(branches) == 0: - raise ValueError("Empty branch sequence") - elif len(branches) == 1: - return branches[0](*operands) - - index = lax.convert_element_type(index, np.int32) - lo = np.array(0, np.int32) - hi = np.array(len(branches) - 1, np.int32) - index = lax.clamp(lo, index, hi) - - if (config.jax_disable_jit and - isinstance(core.get_aval(index), ConcreteArray)): - return branches[int(index)](*operands) - - ops, ops_tree = tree_flatten(operands) - ops_avals = tuple(_map(_abstractify, ops)) - - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - branches, ops_tree, ops_avals, primitive_name='switch') - for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): - _check_tree_and_avals(f"branch 0 and {i + 1} outputs", - out_trees[0], jaxprs[0].out_avals, - out_tree, jaxpr.out_avals) - joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) - disallowed_effects = joined_effects - allowed_effects - if disallowed_effects: - raise NotImplementedError( - f'Effects not supported in `switch`: {disallowed_effects}') - - linear = (False,) * (len(consts) + len(ops)) - out = cond_p.bind( - index, *consts, *ops, branches=tuple(jaxprs), linear=linear) - return tree_unflatten(out_trees[0], out) - - -def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, - operand=_no_operand_sentinel, linear=None): - """Conditionally apply ``true_fun`` or ``false_fun``. - - ``cond()`` has equivalent semantics to this Python implementation:: - - def cond(pred, true_fun, false_fun, *operands): - if pred: - return true_fun(*operands) - else: - return false_fun(*operands) - - ``pred`` must be a scalar type. - - Args: - pred: Boolean scalar type, indicating which branch function to apply. - true_fun: Function (A -> B), to be applied if ``pred`` is True. - false_fun: Function (A -> B), to be applied if ``pred`` is False. - operands: Operands (A) input to either branch depending on ``pred``. The - type can be a scalar, array, or any pytree (nested Python tuple/list/dict) - thereof. - - Returns: - Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``, - depending on the value of ``pred``. The type can be a scalar, array, or any - pytree (nested Python tuple/list/dict) thereof. - """ - if not (callable(true_fun) and callable(false_fun)): - raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") - if operand is not _no_operand_sentinel: - if operands: - raise TypeError("if 'operand' keyword is passed then no positional " - f"operands can be passed, got operand={operand} " - f"and positional operands {operands}") - operands = (operand,) - del operand - - if isinstance(pred, Sequence) or np.ndim(pred) != 0: - raise TypeError( - f"Pred must be a scalar, got {pred} of " + - (f"type {type(pred)}" if isinstance(pred, Sequence) - else f"shape {np.shape(pred)}.")) - - try: - pred_dtype = dtypes.result_type(pred) - except TypeError as err: - msg = ("Pred type must be either boolean or number, got {}.") - raise TypeError(msg.format(pred)) from err - - if pred_dtype.kind != 'b': - if pred_dtype.kind in 'iuf': - pred = pred != 0 - else: - msg = ("Pred type must be either boolean or number, got {}.") - raise TypeError(msg.format(pred_dtype)) - - if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray): - if pred: - return true_fun(*operands) - else: - return false_fun(*operands) - - ops, ops_tree = tree_flatten(operands) - if linear is None: - linear_ops = [False] * len(ops) - else: - linear_ops, ops_tree2 = tree_flatten(linear) - if ops_tree != ops_tree2: - raise TypeError('linear tree and operand tree mismatch') - ops_avals = tuple(_map(_abstractify, ops)) - - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - (true_fun, false_fun), ops_tree, ops_avals, 'cond') - true_jaxpr, false_jaxpr = jaxprs - out_tree, false_out_tree = out_trees - - _check_tree_and_avals("true_fun and false_fun output", - out_tree, true_jaxpr.out_avals, - false_out_tree, false_jaxpr.out_avals) - joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects) - disallowed_effects = joined_effects - allowed_effects - if disallowed_effects: - raise NotImplementedError( - f'Effects not supported in `cond`: {disallowed_effects}') - - index = lax.convert_element_type(pred, np.int32) - - linear = [False] * len(consts) + linear_ops - out = cond_p.bind( - index, *consts, *ops, - branches=(false_jaxpr, true_jaxpr), linear=tuple(linear)) - return tree_unflatten(out_tree, out) - -@api_boundary -@functools.wraps(_cond) -def cond(*args, **kwargs): - # detect an attempt to call the former, deprecated cond - try: - ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs) - except TypeError: - pass - else: - assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch - _, _, maybe_true_fun, _, maybe_false_fun = ba.args - if callable(maybe_true_fun) and callable(maybe_false_fun): - return _cond_with_per_branch_args(*ba.args) - - return _cond(*args, **kwargs) - -def _cond_with_per_branch_args(pred, - true_operand, true_fun: Callable, - false_operand, false_fun: Callable): - """Conditionally apply ``true_fun`` or ``false_fun``. - - Has equivalent semantics to this Python implementation:: - - def cond(pred, true_operand, true_fun, false_operand, false_fun): - if pred: - return true_fun(true_operand) - else: - return false_fun(false_operand) - - Pred has to be a scalar type, collection types (list, tuple) are not supported - """ - if not (callable(true_fun) and callable(false_fun)): - raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") - return _cond(pred, - lambda op: true_fun(op[0]), - lambda op: false_fun(op[1]), - (true_operand, false_operand)) - -def _cond_abstract_eval(*args, branches, **kwargs): - joined_effects = core.join_effects(*(b.effects for b in branches)) - disallowed_effects = joined_effects - allowed_effects - if disallowed_effects: - raise NotImplementedError( - f'Effects not supported in `cond`: {disallowed_effects}') - joined_effects = core.join_effects(*(b.effects for b in branches)) - return _map(raise_to_shaped, branches[0].out_avals), joined_effects - -def _bcast_select(pred, on_true, on_false): - if np.ndim(pred) != np.ndim(on_true): - idx = list(range(np.ndim(pred))) - pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx) - return lax.select(pred, on_true, on_false) - -def _bcast_select_n(pred, *cases): - if np.ndim(pred) != np.ndim(cases[0]): - idx = list(range(np.ndim(pred))) - pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) - return lax.select_n(pred, *cases) - -def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear): - index, *ops = args - index_dim, *op_dims = dims - - if index_dim is not batching.not_mapped: - # Convert to a lax.select. While we could get away with not broadcasting - # some operands yet, because all outputs must be broadcast together anyway - # for the select we broadcast the input operands for simplicity and leave - # optimizations to XLA. - # TODO(mattjj,frostig): assumes branches are side-effect-free, revise! - index, *ops = ( - batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)) - - in_batched = [True] * len(branches[0].in_avals) - out_batched = [True] * len(branches[0].out_avals) - - branches_batched = [ - batching.batch_jaxpr( - jaxpr, axis_size, in_batched, out_batched, axis_name, main_type)[0] - for jaxpr in branches] - - branch_outs = [] - for i, jaxpr in enumerate(branches_batched): - # Perform a select on the inputs for safety of reverse-mode autodiff; see - # https://github.com/google/jax/issues/1052 - predicate = lax.eq(index, lax._const(index, i)) - ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops] - branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_)) - out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)] - return out, [0 if b else None for b in out_batched] - else: - ops_bat = [d is not batching.not_mapped for d in op_dims] - ops = [batching.moveaxis(x, d, 0) if b else x - for b, x, d in zip(ops_bat, ops, op_dims)] - - branches_out_bat = [ - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, main_type)[1] - for jaxpr in branches] - out_bat = [any(bat) for bat in zip(*branches_out_bat)] - branches_batched = tuple( - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, main_type)[0] - for jaxpr in branches) - - out_dims = [0 if b else batching.not_mapped for b in out_bat] - out = cond_p.bind( - index, *ops, branches=branches_batched, linear=linear) - return out, out_dims - -def _cond_jvp(primals, tangents, branches, linear): - nonzeros = [type(t) is not ad_util.Zero for t in tangents] - - index_nz, *ops_nz = nonzeros - assert index_nz is False - - branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1] - for jaxpr in branches] - out_nz = [any(nz) for nz in zip(*branches_out_nz)] - - branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0] - for jaxpr in branches) - - index, *ops = primals - _, *ops_dot = tangents - ops_dot = _prune_zeros(ops_dot) - - ops_lin = tuple(linear) - linear_jvp = ops_lin + (True,) * len(ops_dot) - out = cond_p.bind( - index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp) - out_primals, out_tangents = split_list(out, [len(out_nz)]) - out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) - for p, nz in zip(out_primals, out_nz)] - return out_primals, out_tangents - -def _cond_partial_eval(trace, *tracers, branches, linear): - in_unknowns = [t.pval[0] is not None for t in tracers] - index_uk, *ops_uk = in_unknowns - - if index_uk: - # When the branch index is unknown, we stage out the whole cond. - # TODO(mattjj): remove this path when old remat is removed - params = dict(branches=branches, linear=linear) - return trace.default_process_primitive(cond_p, tracers, params) - - branches_out_uks = [] - for branch_jaxpr in branches: - _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits( - branch_jaxpr, ops_uk, instantiate=False) - branches_out_uks.append(out_uks) - out_uks = [any(uks) for uks in zip(*branches_out_uks)] - - branches_known, branches_unknown, branch_res_avals = [], [], [] - for branch_jaxpr in branches: - branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \ - pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks) - branches_known.append(branch_jaxpr_known) - branches_unknown.append(branch_jaxpr_unknown) - branch_res_avals.append(res_avals) - - all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals) - num_res = len(all_res_avals) - - num_known_outs = len(out_uks) - sum(out_uks) - branches_known = _join_cond_outputs( - branches_known, all_res_avals, res_avals_per_branch, num_known_outs) - branches_unknown = _join_cond_pe_staged_jaxpr_inputs( - branches_unknown, all_res_avals, res_avals_per_branch) - assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals)) - for j in branches_known[1:]) - - in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()] - linear_known = [l for l, uk in zip(linear, ops_uk) if not uk] - out_consts_res = cond_p.bind(*in_consts, branches=branches_known, - linear=tuple(linear_known)) - out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res]) - - index_tracer = trace.instantiate_const(tracers[0]) - ops_tracers = [trace.instantiate_const(t) - for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk] - res_tracers = _map(trace.new_instantiated_const, res) - out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) - for aval in branches_unknown[0].out_avals] - linear_unknown = ([False] * num_res + - [l for l, uk in zip(linear, in_unknowns[1:]) if uk]) - params = dict(branches=branches_unknown, linear=tuple(linear_unknown)) - name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] - source = source_info_util.current().replace(name_stack=name_stack) - eqn = pe.new_eqn_recipe( - [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, - core.no_effects, source) - for t in out_tracers: t.recipe = eqn - return util.merge_lists(out_uks, out_consts, out_tracers) - -# When partially evaluating conditionals, each branch produces residuals -# depending on the computation carried out by the branch, and a corresponding -# staged jaxpr that accepts those residuals as its first few inputs. The -# residual-producing branches are staged as jaxprs and bound right away in a -# conditional. The residual-consuming jaxprs are assembled together in a jaxpr -# conditional. The following helper functions ensure that both collections of -# jaxprs (those evaluated and those staged) are valid for joint use under their -# respective conditionals. -# -# In particular, the residuals derived from each original branch may have -# distinct types. Because the branches of conditionals must have identical type -# signatures, we join residuals together across branches into a common format. - -# In order to set up a type signature that all branches can conform to, it would -# suffice to concatenate all branches' residuals. But concatenation can result -# in redundant inputs and outputs, and might lead to memory allocation that -# scales unnecessarily with the branch count. This function finds common -# residual types across branches for reuse, so as to avoid redundant -# allocation. It returns a list L of types (avals) representing the collection -# of residuals merged according to type, and, for each branch, a lookup table to -# match its residuals to their positions/types in L. Example input/output: -# -# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]] -# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]] -# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]] -def _merge_branch_residuals(branch_res_avals): - def enumerate_equal(xs): - counts = {v: itertools.count() for v in set(xs)} - return [(x, next(counts[x])) for x in xs] - branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals) - all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals)) - indices = {v: i for i, v in enumerate(all_tagged_avals)} - branch_indices = [ - [indices[aval] for aval in avals] for avals in branch_res_tagged_avals] - all_avals = [x for x, _ in all_tagged_avals] - return all_avals, branch_indices - -# This function augments branch outputs to agree with the merged residual -# format: each branch is made to return zero-filled values in the places of -# residual outputs that it does not populate. -def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr, - num_non_res_outputs): - def augment_jaxpr(jaxpr, res_indices): - @lu.wrap_init - def f_aug(*args): - outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args) - outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs]) - aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals) - aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals)) - return outs + list(aug_residuals) - - return _make_closed_jaxpr(f_aug, jaxpr.in_avals) - - return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) - -# This function augments branch inputs to agree with the merged residual format: -# each branch is made to accept all residuals, even though it will ignore those -# that it does not read. -def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals, - res_aval_indices_per_jaxpr): - newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_') - all_res_vars = _map(newvar, all_res_avals) - - def augment_jaxpr(jaxpr, res_indices): - num_res = len(res_indices) - res_vars = jaxpr.jaxpr.invars[:num_res] - non_res_vars = jaxpr.jaxpr.invars[num_res:] - - aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars))) - aug_invars = aug_res_vars + non_res_vars - jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars, - jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns, - jaxpr.jaxpr.effects) - jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts) - return jaxpr_aug - - return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr)) - -def _ordered_unique(xs): - d = collections.OrderedDict((x, None) for x in xs) - return list(d.keys()) - -def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes): - res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) - primal_avals = _map(raise_to_shaped, primal_avals) - - @lu.wrap_init - def transposed(*args): - res, cts_out = split_list(args, [num_res]) - primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] - cts_in = ad.backward_pass( - jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out) - _, cts_in = split_list(cts_in, [num_res]) - return _map(ad.instantiate_zeros_aval, primal_avals, cts_in) - - return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals) - -def _cond_transpose(reduce_axes, cts, *args, branches, linear): - index, *ops = args - in_avals = _map(raise_to_shaped, branches[0].in_avals) - num_res = len(ops) - sum(linear) - - branches_trans = tuple( - _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches) - lin_in_avals = [raise_to_shaped(a, weak_type=False) - for a, l in zip(in_avals, linear) if l] - assert all(core.typematch(out_aval, lin_in_aval) - for jaxpr in branches_trans - for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) - - res = ops[:num_res] - cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts) - linear_trans = (False,) * num_res + (True,) * len(cts) - - out = cond_p.bind( - index, *res, *cts, branches=branches_trans, linear=linear_trans) - assert all(_map(core.typecheck, lin_in_avals, out)) - - out_iter = iter(out) - out = [next(out_iter) if l else None for l in linear] - assert next(out_iter, None) is None - return [None] + out - -def _avals_short(avals): - to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))() - return ' '.join(_map(to_str, avals)) - -def _cond_typecheck(*avals, branches, linear): - tc = partial(_typecheck_param, 'cond') - tc(branches, 'branches', 'tuple of ClosedJaxpr', - type(branches) is tuple and - all(type(x) is core.ClosedJaxpr for x in branches)) - tc(linear, 'linear', 'tuple of bool', - type(linear) is tuple and all(type(x) is bool for x in linear)) - - if len(branches) == 0: - raise core.JaxprTypeError('cond requires at least one branch function') - if len(linear) + 1 != len(avals): - raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for ' - f'{len(avals) - 1} non-predicate operands') - - jaxpr0 = branches[0] - jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals) - jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals) - joined_effects = core.join_effects(*(b.effects for b in branches)) - disallowed_effects = joined_effects - allowed_effects - if disallowed_effects: - raise NotImplementedError( - f'Effects not supported in `cond`: {disallowed_effects}') - - for i, jaxpr in enumerate(branches[1:]): - if len(jaxpr0.in_avals) != len(jaxpr.in_avals): - raise core.JaxprTypeError( - f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, ' - f'branch {i+1} takes {len(jaxpr.in_avals)}') - if len(jaxpr0.out_avals) != len(jaxpr.out_avals): - raise core.JaxprTypeError( - f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, ' - f'branch {i+1} outputs {len(jaxpr.out_avals)}') - if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)): - raise core.JaxprTypeError( - f'cond branches 0 and {i+1} have mismatching input types: ' - f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}') - if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)): - raise core.JaxprTypeError( - f'cond branches 0 and {i+1} have mismatching output types: ' - f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}') - - if len(avals) != 1 + len(jaxpr0.in_avals): - raise core.JaxprTypeError( - f'cond called with {len(avals) - 1} non-predicate operands, ' - f'but branches take {len(jaxpr0.in_avals)} inputs') - - index_aval, *op_avals = avals - if index_aval.dtype != np.int32: - raise core.JaxprTypeError( - f'cond called with index of type {index_aval.dtype} instead of int32') - if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)): - raise core.JaxprTypeError( - f'cond branches take input types {jaxpr0_in_avals_str}, ' - f'called with operands of type {_avals_short(op_avals)}') - if any(b.effects != branches[0].effects for b in branches[1:]): - raise core.JaxprTypeError( - f'cond branches must have matching effect types: ' - f'{[b.effects for b in branches]}') - joined_effects = core.join_effects(*(b.effects for b in branches)) - return None, joined_effects - -def cond_bind(*args, branches, linear): - if config.jax_enable_checks: - avals = _map(core.get_aval, args) - _cond_typecheck(*avals, branches=branches, linear=linear) - for jaxpr in branches: - core.check_jaxpr(jaxpr.jaxpr) - return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear) - -cond_p = core.AxisPrimitive('cond') -cond_p.multiple_results = True -cond_p.def_impl(partial(xla.apply_primitive, cond_p)) -cond_p.def_effectful_abstract_eval(_cond_abstract_eval) -cond_p.def_custom_bind(cond_bind) -ad.primitive_jvps[cond_p] = _cond_jvp -ad.reducing_transposes[cond_p] = _cond_transpose -pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval -batching.axis_primitive_batchers[cond_p] = _cond_batching_rule -xla.register_initial_style_primitive(cond_p) -core.custom_typechecks[cond_p] = _cond_typecheck -pe.partial_eval_jaxpr_custom_rules[cond_p] = \ - partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond') - -def _cond_lowering(ctx, index, *args, branches, linear): - del linear # Unused. - joined_effects = core.join_effects(*(branch.effects for branch in branches)) - ordered_effects = [eff for eff in joined_effects - if eff in core.ordered_effects] - num_tokens = len(ordered_effects) - tokens_in = ctx.tokens_in.subset(ordered_effects) - output_token_types = [mlir.token_type() for _ in ordered_effects] - output_types = [ - *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)] - flat_output_types = util.flatten(output_types) - - # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks - # have no arguments; the computation within the block uses implicit - # captures. - case_op = mhlo.CaseOp(flat_output_types, index=index, - num_branches=len(branches)) - name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond') - for i, jaxpr in enumerate(branches): - branch = case_op.regions[i].blocks.append() - with ir.InsertionPoint(branch): - sub_ctx = ctx.module_context.replace( - name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun')) - out_vals, tokens_out = mlir.jaxpr_subcomp( - sub_ctx, jaxpr.jaxpr, tokens_in, - _map(mlir.ir_constants, jaxpr.consts), - *_map(mlir.wrap_singleton_ir_values, args)) - out_tokens = [tokens_out.get(eff) for eff in ordered_effects] - out_vals = [*out_tokens, *out_vals] - mhlo.ReturnOp(util.flatten(out_vals)) - - tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types)) - tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens]) - ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens))) - return outputs - -mlir.register_lowering(cond_p, _cond_lowering) - - - -### scan - -Carry = TypeVar('Carry') -X = TypeVar('X') -Y = TypeVar('Y') - -@api_boundary -def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], - init: Carry, - xs: X, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1) -> Tuple[Carry, Y]: - """Scan a function over leading array axes while carrying along state. +def fori_loop(lower, upper, body_fun, init_val): + """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`. The `Haskell-like type signature`_ in brief is .. code-block:: haskell - scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b]) + fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a - where we use [t] here to denote the type t with an additional leading axis. - That is, if t is an array type then [t] represents the type with an additional - leading axis, and if t is a pytree (container) type with array leaves then [t] - represents the type with the same pytree structure and corresponding leaves - each with an additional leading axis. + The semantics of ``fori_loop`` are given by this Python implementation:: - When ``a`` is an array type or None, and ``b`` is an array type, the semantics - of ``scan`` are given roughly by this Python implementation:: + def fori_loop(lower, upper, body_fun, init_val): + val = init_val + for i in range(lower, upper): + val = body_fun(i, val) + return val - def scan(f, init, xs, length=None): - if xs is None: - xs = [None] * length - carry = init - ys = [] - for x in xs: - carry, y = f(carry, x) - ys.append(y) - return carry, np.stack(ys) + Unlike that Python version, ``fori_loop`` is implemented in terms of either a + call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the + trip count is static (meaning known at tracing time, perhaps because ``lower`` + and ``upper`` are Python integer literals) then the ``fori_loop`` is + implemented in terms of ``scan`` and reverse-mode autodiff is supported; + otherwise, a ``while_loop`` is used and reverse-mode autodiff is not + supported. See those functions' docstrings for more information. - Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree - types, and so multiple arrays can be scanned over at once and produce multiple - output arrays. (None is actually an empty pytree.) - - Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to - a single XLA While HLO. That makes it useful for reducing compilation times - for jit-compiled functions, since native Python loop constructs in an ``@jit`` - function are unrolled, leading to large XLA computations. - - Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype - across all iterations (and not just be consistent up to NumPy rank/shape - broadcasting and dtype promotion rules, for example). In other words, the type - ``c`` in the type signature above represents an array with a fixed shape and - dtype (or a nested tuple/list/dict container data structure with a fixed - structure and arrays with fixed shape and dtype at the leaves). + Also unlike the Python analogue, the loop-carried value ``val`` must hold a + fixed shape and dtype across all iterations (and not just be consistent up to + NumPy rank/shape broadcasting and dtype promotion rules, for example). In + other words, the type ``a`` in the type signature above represents an array + with a fixed shape and dtype (or a nested tuple/list/dict container data + structure with a fixed structure and arrays with fixed shape and dtype at the + leaves). .. note:: - :py:func:`scan` compiles ``f``, so while it can be combined with + :py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. Args: - f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning - that ``f`` accepts two arguments where the first is a value of the loop - carry and the second is a slice of ``xs`` along its leading axis, and that - ``f`` returns a pair where the first element represents a new value for - the loop carry and the second represents a slice of the output. - init: an initial loop carry value of type ``c``, which can be a scalar, - array, or any pytree (nested Python tuple/list/dict) thereof, representing - the initial loop carry value. This value must have the same structure as - the first element of the pair returned by ``f``. - xs: the value of type ``[a]`` over which to scan along the leading axis, - where ``[a]`` can be an array or any pytree (nested Python - tuple/list/dict) thereof with consistent leading axis sizes. - length: optional integer specifying the number of loop iterations, which - must agree with the sizes of leading axes of the arrays in ``xs`` (but can - be used to perform scans where no input ``xs`` are needed). - reverse: optional boolean specifying whether to run the scan iteration - forward (the default) or in reverse, equivalent to reversing the leading - axes of the arrays in both ``xs`` and in ``ys``. - unroll: optional positive int specifying, in the underlying operation of the - scan primitive, how many scan iterations to unroll within a single - iteration of a loop. + lower: an integer representing the loop index lower bound (inclusive) + upper: an integer representing the loop index upper bound (exclusive) + body_fun: function of type ``(int, a) -> a``. + init_val: initial loop carry value of type ``a``. Returns: - A pair of type ``(c, [b])`` where the first element represents the final - loop carry value and the second element represents the stacked outputs of - the second output of ``f`` when scanned over the leading axis of the inputs. + Loop value from the final iteration, of type ``a``. .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature """ - if not callable(f): - raise TypeError("lax.scan: f argument should be a callable.") - xs_flat, xs_tree = tree_flatten(xs) + if not callable(body_fun): + raise TypeError("lax.fori_loop: body_fun argument should be callable.") + # TODO(phawkins): perhaps do more type checking here, better error messages. + lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower)) + upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper)) + if lower_dtype != upper_dtype: + msg = ("lower and upper arguments to fori_loop must have equal types, " + "got {} and {}") + raise TypeError(msg.format(lower_dtype.name, upper_dtype.name)) - try: - lengths = [x.shape[0] for x in xs_flat] - except AttributeError as err: - msg = "scan got value with no leading axis to scan over: {}." - raise ValueError( - msg.format(', '.join(str(x) for x in xs_flat - if not hasattr(x, 'shape')))) from err - - if length is not None: - length = int(length) - if not all(length == l for l in lengths): - msg = ("scan got `length` argument of {} which disagrees with " - "leading axis sizes {}.") - raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) - else: - unique_lengths = set(lengths) - if len(unique_lengths) > 1: - msg = "scan got values with different leading axis sizes: {}." - raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) - elif len(unique_lengths) == 0: - msg = "scan got no values to scan over and `length` not provided." - raise ValueError(msg) + # If we can specialize on the trip count, call scan instead of a while_loop + # to enable efficient reverse-mode differentiation. + if (isinstance(core.get_aval(lower), ConcreteArray) and + isinstance(core.get_aval(upper), ConcreteArray)): + try: + lower_ = int(lower) + upper_ = int(upper) + except TypeError: + use_scan = False else: - length, = unique_lengths - - if config.jax_disable_jit: - if length == 0: - raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.") - carry = init - ys = [] - maybe_reversed = reversed if reverse else lambda x: x - for i in maybe_reversed(range(length)): - xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat] - carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) - ys.append(y) - stack = lambda *ys: jax.numpy.stack(ys) - stacked_y = tree_map(stack, *maybe_reversed(ys)) - return carry, stacked_y - - x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat] - x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat] - x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes)) - - def _create_jaxpr(init): - init_flat, init_tree = tree_flatten(init) - in_flat, in_tree = tree_flatten((init, xs)) - - carry_avals = tuple(_map(_abstractify, init_flat)) - jaxpr, consts, out_tree = _initial_style_jaxpr( - f, in_tree, carry_avals + x_avals, "scan") - out_tree_children = out_tree.children() - if len(out_tree_children) != 2: - msg = "scan body output must be a pair, got {}." - raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves] - return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children - - # The carry input and output avals must match exactly. However, we want to account for - # the case when init contains weakly-typed values (e.g. Python scalars), with avals that - # may not match the output despite being compatible by virtue of their weak type. - # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if - # necessary, a second time with modified init values. - init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) - new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out) - if changed: - new_init = tree_unflatten(init_tree, new_init_flat) - init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init) - in_flat, jaxpr, consts, out_tree, out_tree_children = rest - - _check_tree_and_avals("scan carry output and input", - # Extract the subtree and avals for the first element of the return tuple - out_tree_children[0], carry_avals_out, - init_tree, carry_avals) - disallowed_effects = jaxpr.effects - allowed_effects - if disallowed_effects: - raise NotImplementedError( - f'Effects not supported in `scan`: {disallowed_effects}') - - out = scan_p.bind(*consts, *in_flat, - reverse=reverse, length=length, jaxpr=jaxpr, - num_consts=len(consts), num_carry=len(init_flat), - linear=(False,) * (len(consts) + len(in_flat)), - unroll=unroll) - return tree_unflatten(out_tree, out) - -def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear, - f_impl, x_avals, y_avals): - consts, init, xs = split_list(args, [num_consts, num_carry]) - - carry = init - ys = [] - - for i in range(length): - i_ = length - i - 1 if reverse else i - x = _map(partial(_index_array, i_), x_avals, xs) - out = f_impl(*consts, *carry, *x) - carry, y = split_list(out, [num_carry]) - ys.append(y) - - ys = list(reversed(ys)) if reverse else ys - ys = list(zip(*ys)) - ys = _map(_stack, y_avals, ys) - return (*carry, *ys) - -def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear, - f_impl, x_avals, y_avals): - consts, init, xs = split_list(args, [num_consts, num_carry]) - - def cond_fun(vals): - i, *_ = vals - return i < length - - def body_fun(vals): - [i], carry, ys = split_list(vals, [1, num_carry]) - i_ = length - i - 1 if reverse else i - x = _map(partial(_dynamic_index_array, i_), x_avals, xs) - out_flat = f_impl(*consts, *carry, *x) - carry_out, y_updates = split_list(out_flat, [num_carry]) - ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates) - return [i + 1] + carry_out + ys_out - - ys_init = _map(partial(_empty_array, length), y_avals) - if length == 0: - return init + ys_init + use_scan = True else: - init_val = [lax._const(length, 0)] + init + ys_init - _, *outs = while_loop(cond_fun, body_fun, init_val) - return outs + use_scan = False -def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry, - linear, block_length, f_impl, x_avals, y_avals): - consts, init, xs = split_list(args, [num_consts, num_carry]) + if use_scan: + if config.jax_disable_jit and upper_ == lower_: + # non-jit implementation of scan does not support length=0 + return init_val - num_blocks, rem = divmod(length, block_length) - assert rem == 0 - - partition = partial(_partition_leading, num_blocks, block_length) - xs_block = _map(partition, x_avals, xs) - - prepend_aval = partial(_prepend_dim_to_aval, block_length) - x_block_avals = _map(prepend_aval, x_avals) - y_block_avals = _map(prepend_aval, y_avals) - - f_impl_block = partial( - _scan_impl_unrolled, reverse=reverse, length=block_length, - num_consts=num_consts, num_carry=num_carry, linear=linear, - f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) - - outs = _scan_impl_loop( - *consts, *init, *xs_block, reverse=reverse, length=num_blocks, - num_consts=num_consts, num_carry=num_carry, linear=linear, - f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals) - - carry, ys_blocks = split_list(outs, [num_carry]) - combine = partial(_combine_leading, num_blocks, block_length) - ys = _map(combine, y_avals, ys_blocks) - return (*carry, *ys) - -def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, - unroll): - _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) - _, y_avals = split_list(jaxpr.out_avals, [num_carry]) - f_impl = core.jaxpr_as_fun(jaxpr) - - if unroll == 1: - return _scan_impl_loop( - *args, reverse=reverse, length=length, num_consts=num_consts, - num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, - y_avals=y_avals) - - consts, init, xs = split_list(args, [num_consts, num_carry]) - num_blocks, rem = divmod(length, unroll) - length_div = num_blocks * unroll - - if rem > 0: - if reverse: - split = partial(_split_leading_dim, rem) - xs_rem, xs = unzip2(_map(split, x_avals, xs)) - else: - split = partial(_split_leading_dim, length_div) - xs, xs_rem = unzip2(_map(split, x_avals, xs)) - - outs = _scan_impl_block_unrolled( - *consts, *init, *xs, reverse=reverse, length=length_div, - num_consts=num_consts, num_carry=num_carry, linear=linear, - block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) - - carry, ys = split_list(outs, [num_carry]) - - if rem > 0: - outs = _scan_impl_unrolled( - *consts, *carry, *xs_rem, reverse=reverse, length=rem, - num_consts=num_consts, num_carry=num_carry, linear=linear, - f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) - carry, ys_rem = split_list(outs, [num_carry]) - if reverse: - ys = _map(_concatenate, y_avals, ys_rem, ys) - else: - ys = _map(_concatenate, y_avals, ys, ys_rem) - - return (*carry, *ys) - -def _stack(aval, vals): - vals = [lax.expand_dims(x, (0,)) for x in vals] - return lax.concatenate(vals, 0) - -def _concatenate(aval, x1, x2): - return lax.concatenate([x1, x2], 0) - -def _split_leading_dim(i, aval, x): - assert x.ndim >= 1 - return (slicing.slice_in_dim(x, 0, i), - slicing.slice_in_dim(x, i, x.shape[0])) - -def _dynamic_index_array(i, aval, x): - return slicing.dynamic_index_in_dim(x, i, keepdims=False) - -def _index_array(i, aval, x): - return slicing.index_in_dim(x, i, keepdims=False) - -def _empty_array(sz, aval): - return lax.full((sz,) + aval.shape, 0, aval.dtype) - -def _update_array(i, aval, xs, x): - return slicing.dynamic_update_index_in_dim(xs, x, i, 0) - -def _partition_leading(sz0, sz1, aval, x): - assert x.ndim >= 1 - assert x.shape[0] == sz0 * sz1 - return lax.reshape(x, (sz0, sz1, *x.shape[1:])) - -def _combine_leading(sz0, sz1, aval, x): - assert x.ndim >= 2 - assert x.shape[0] == sz0 - assert x.shape[1] == sz1 - return lax.collapse(x, 0, 2) - -def _prepend_dim_to_aval(sz, aval): - return core.unmapped_aval(sz, core.no_axis_name, 0, aval) - -def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, - linear, unroll): - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) - ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) - return carry_avals + ys_avals, jaxpr.effects - -def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, - linear, unroll): - num_xs = len(jaxpr.in_avals) - num_carry - num_consts - num_ys = len(jaxpr.out_avals) - num_carry - nonzeros = [type(t) is not ad_util.Zero for t in tangents] - const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry]) - - # Fixpoint computation of which carry are not ad.zero: either - # non-zero from init, or the carry out is non-zero. Each iteration promotes - # at least one carry to non-zero. We need at most len(carry) iterations, - # but we need one last iteration to prepare the jaxpr based on the final - # carry_nz. - carry_nz = init_nz - for _ in range(1 + len(carry_nz)): - nonzeros = const_nz + carry_nz + xs_nz - jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr( - jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys) - carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:] - if carry_nz_out == carry_nz: - break - else: - carry_nz = _map(operator.or_, carry_nz, carry_nz_out) + (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val), + None, length=upper_ - lower_) else: - assert False, "Fixpoint not reached" - - tangents = [ad.instantiate_zeros(t) if nz else t - for t, nz in zip(tangents, nonzeros)] - - consts, init, xs = split_list(primals, [num_consts, num_carry]) - all_tangents = split_list(tangents, [num_consts, num_carry]) - consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents) - - jaxpr_jvp_rearranged = ad.rearrange_binders( - jaxpr_jvp, - [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)], - [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)]) - - consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) - jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot) - + init_linear + [True] * len(init_dot) - + xs_linear + [True] * len(xs_dot)) - - out_flat = scan_p.bind( - *(consts + consts_dot + init + init_dot + xs + xs_dot), - reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged, - num_consts=num_consts + len(consts_dot), - num_carry=num_carry + len(init_dot), - linear=jaxpr_jvp_linear, unroll=unroll) - - carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) - primals_out = carry + ys - tangents_out_iter = iter(carry_dot + ys_dot) - tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) - for p, nz in zip(primals_out, nonzeros_out)] - return primals_out, tangents_out - -def _prune_zeros(ts): - return [t for t in ts if type(t) is not ad_util.Zero] - -def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, - jaxpr, linear, unroll): - num_ys = len(jaxpr.out_avals) - num_carry - unknowns = [not t.pval.is_known() for t in tracers] - const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) - - # Fixpoint computation of which carry elements are unknown. Each iteration - # promotes at least one carry to unknown. We need at most len(carry) - # iterations, but we need one last iteration to prepare the jaxpr based on the - # final carry_uk. - carry_uk = init_uk - for _ in range(1 + len(carry_uk)): - unknowns = const_uk + carry_uk + xs_uk - jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits( - jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) - carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) - if carry_uk_out == carry_uk: - break - else: - carry_uk = _map(operator.or_, carry_uk, carry_uk_out) - else: - assert False, "Fixpoint not reached" - num_res = len(res_avals) - del res_avals, carry_uk_out - - # Instantiate those inputs which must be treated as unknown from the fixpoint. - tracers = [trace.instantiate_const(t) if uk else t - for t, uk in zip(tracers, unknowns)] - - # The residual inputs and outputs of the jaxprs produced haven't yet been - # adapted to the scan calling convention; in particular, jaxpr_known has its - # residual outputs all at the end, meaning they're extensive outputs (which is - # fully general but may be wasteful for residuals which are loop-invariant) - # while jaxpr_unknown has its corresponding residual inputs at the front (just - # as a convention with partial_eval_jaxpr_nounits), making them constant - # inputs. To make them consistent, we move the residual inputs on - # jaxpr_unknown to the end, even though we may move some back in the sequel. - jaxpr_unknown = pe.move_binders_to_back( - jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) - - # At this point, all residuals are treated as extensive outputs of jaxpr_known - # (and extensive inputs to jaxpr_unknown). But residuals that are loop- - # invariant can be hoisted out of the scan, rather than letting them get - # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't - # want to broadcast the matrix!). So, outside the loop we perform a partial - # evaluation with known 'const' inputs (but all other inputs unknown). - const_pvals = [pe.PartialVal.known(t.pval.get_known()) - for t in tracers[:num_consts] if t.pval.is_known()] - other_pvals = [pe.PartialVal.unknown(aval) - for aval in jaxpr_known.in_avals[len(const_pvals):]] - with source_info_util.reset_name_stack(): - jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( - lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals, - instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) - jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) - # The above trace_to_jaxpr_nounits call computed loop-invariant residuals - # (known values in invar_pvals_out) and also computed loop-invariant values - # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the - # previous consts). We need to collect the computed inteisive residuals, and - # move corresponding intensive residual binders in jaxpr_unknown to the front. - res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] - intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] - jaxpr_unknown = pe.move_binders_to_front( - jaxpr_unknown, - [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) - del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals - # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and - # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. - - # As another optimization, for any extensive inputs that are just forwarded to - # extensive outputs, to avoid a copy (which would be looping over - # dynamic-update-slice) we'd rather forward the input tracer/value. That means - # pruning some outputs from jaxpr_known here, and updating `out_flat` below. - fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) - # Prune fwds_known to include only extensive input to extensive output. - fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and - in_idx is not None and - in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) - else None for out_idx, in_idx in enumerate(fwds_known)] - # Drop any extensive output we can instead get by forwarding an input. - # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint - jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts - jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known) - if i is None] - jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) - del jaxpr_known_ - # We use `fwds_known` below when forming the output of scanning jaxpr_known. - - # Run the known part of the scan (if it has any outputs or effects). - known_inputs = (list(jaxpr_known_consts) + - [t.pval.get_known() for t in tracers[num_consts:] - if t.pval.is_known()]) - if not jaxpr_known.out_avals and not jaxpr_known.effects: - out_known = [] - else: - linear_known = [False] * len(known_inputs) # conservative! - out_known = scan_p.bind( - *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, - num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), - linear=tuple(linear_known), unroll=unroll) - del linear_known - # Complete the known output by filling in forwarded values using fwds_known. - out_known_iter = iter(out_known) - out_known = [next(out_known_iter) if f is None - else _maybe_put(known_inputs[f]) for f in fwds_known] - assert next(out_known_iter, None) is None - del known_inputs, out_known_iter - - # Split known outputs from residuals. - out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) - assert len(intensive_res) + len(extensive_res) == num_res - - # Create input tracers for jaxpr_unknown bind. - unknown_inputs = [t for t in tracers if not t.pval.is_known()] - intensive_res = _map(trace.new_instantiated_const, intensive_res) - extensive_res = _map(trace.new_instantiated_const, extensive_res) - # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. - carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) - ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval) - for y_aval in y_avals] - out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in itertools.chain(carry_avals, ys_avals)] - del carry_avals, y_avals - # Create equation. - linear_unknown = tuple([False] * len(intensive_res) + - [l for l, uk in zip(linear, unknowns) if uk] + - [False] * len(extensive_res)) - name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] - source = source_info_util.current().replace(name_stack=name_stack) - assert len(out_tracers) == len(jaxpr_unknown.out_avals) - eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res], - out_tracers, scan_p, - dict(reverse=reverse, length=length, unroll=unroll, - jaxpr=jaxpr_unknown, linear=linear_unknown, - num_consts=len(intensive_res) + sum(const_uk), - num_carry=sum(carry_uk)), - jaxpr_unknown.effects, source) - for t in out_tracers: t.recipe = eqn - - # Merge known and unknown outputs into final result. - return util.merge_lists(out_uk, out_known, out_tracers) - -def _maybe_put(x): - if isinstance(x, np.ndarray): - return jax.device_put(x, jax.devices('cpu')[0]) - else: - return x - -def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts, - num_carry, jaxpr, linear, unroll): - # we've only implemented transposing scans with specific lin/nonlin patterns - consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) - num_ires = len(consts_lin) - sum(consts_lin) - num_eres = len(xs_lin) - sum(xs_lin) - if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires): - raise NotImplementedError - if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: - raise NotImplementedError - if not all(init_lin): - pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 - - consts, _, xs = split_list(args, [num_consts, num_carry]) - ires, _ = split_list(consts, [num_ires]) - _, eres = split_list(xs, [sum(xs_lin)]) - assert not any(ad.is_undefined_primal(r) for r in ires) - assert not any(ad.is_undefined_primal(r) for r in eres) - - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) - ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) - ct_carry, ct_ys = split_list(cts, [num_carry]) - ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry) - ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys) - ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) - - # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) - # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) - jaxpr_trans = _transpose_scan_jaxpr( - num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes) - linear_trans = ([False] * num_ires + - [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + - [False] * num_eres) - - outs = scan_p.bind( - *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse, - length=length, jaxpr=jaxpr_trans, num_consts=num_ires, - num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans), - unroll=unroll) - ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) - return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres - -# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) -# -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) -def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes): - num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 - # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals - # if an axis isn't reduced - res1_avals, c_avals, a_avals, res2_avals = split_list( - jaxpr.in_avals, [num_res1, num_c, num_a]) - num_b = len(jaxpr.out_avals) - b_avals = list(jaxpr.out_avals) - - @lu.wrap_init - def transposed(*res1_cbar_bbar_res2): - res1, c_bar, b_bar, res2 = split_list( - res1_cbar_bbar_res2, [num_res1, num_c, num_b]) - primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + - [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) - cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, - primals, b_bar) - _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) - a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) - c_bar = _map(ad.instantiate_zeros_aval, c_avals, - _map(ad.add_tangents, c_bar, new_c_bar)) - return c_bar + a_bar - return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals) - -def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]): - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals) - return core.ClosedJaxpr(jaxpr, consts) - - -def _scan_batching_rule(axis_size, axis_name, main_type, args, dims, reverse, length, - jaxpr, num_consts, num_carry, linear, unroll): - num_ys = len(jaxpr.out_avals) - num_carry - orig_batched = [d is not batching.not_mapped for d in dims] - const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry]) - - # Fixpoint computation of which carry are batched: either - # batched from init, or the carry out is batched. Each iteration promotes - # at least one carry to batched. We need at most len(carry) iterations, - # but we need one last iteration to prepare the jaxpr based on the final - # carry_batched. - carry_batched = init_batched - for _ in range(1 + len(carry_batched)): - batched = const_batched + carry_batched + xs_batched - jaxpr_batched, batched_out = batching.batch_jaxpr( - jaxpr, axis_size, batched, - instantiate=carry_batched + [False] * num_ys, - axis_name=axis_name, - main_type=main_type) - carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] - if carry_batched_out == carry_batched: - break - else: - carry_batched = _map(operator.or_, carry_batched, carry_batched_out) - else: - assert False, "Fixpoint not reached" - - consts, init, xs = split_list(args, [num_consts, num_carry]) - consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) - new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 - else x for x, d in zip(consts, consts_bdims)] - new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched - else batching.moveaxis(x, d, 0) if now_batched else x - for x, d, was_batched, now_batched in - zip(init, init_bdims, init_batched, carry_batched)] - new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1 - else x for x, d in zip(xs, xs_bdims)] - new_args = new_consts + new_init + new_xs - - outs = scan_p.bind( - *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched, - num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll) - carry_bdims = [0 if b else batching.not_mapped for b in carry_batched] - ys_bdims = [1 if b else batching.not_mapped for b in ys_batched] - return outs, carry_bdims + ys_bdims - -def _scan_masking_rule(padded_vals, logical_shapes, reverse, length, - jaxpr, num_consts, num_carry, linear, unroll): - dynamic_length, = masking.shape_as_value((length,)) - masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry) - consts, init, xs = split_list(padded_vals, [num_consts, num_carry]) - max_length, = {x.shape[0] for x in xs} - const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) - dynamic_length = lax.convert_element_type(dynamic_length, dtypes.int_) - out_vals = scan_p.bind(dynamic_length, *consts, dtypes.int_(0), *init, *xs, - reverse=reverse, length=max_length, jaxpr=masked_jaxpr, - num_consts=1 + num_consts, num_carry=1 + num_carry, - linear=tuple([False] + const_linear + [False] + init_linear + xs_linear), - unroll=unroll) - return out_vals[1:] - -def _masked_scan_jaxpr(jaxpr, num_consts, num_carry): - fun = core.jaxpr_as_fun(jaxpr) - - @lu.wrap_init - def masked(*args): - [dynamic_length], consts, [i], carry, xs = split_list( - args, [1, num_consts, 1, num_carry]) - out = fun(*(consts + carry + xs)) - new_carry, ys = split_list(out, [num_carry]) - new_carry = [lax.select(i < dynamic_length, new_c, c) - for new_c, c in zip(new_carry, carry)] - return [i + 1] + new_carry + ys - - aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_)) - const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) - return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals) - -def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): - padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)) - return scan_p.bind(*args, jaxpr=padded_jaxpr, **params) - -def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn - ) -> Tuple[List[bool], core.JaxprEqn]: - if not config.after_neurips: - return [True] * len(eqn.params['jaxpr'].in_avals), eqn - jaxpr = eqn.params['jaxpr'] - num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] - num_xs = len(jaxpr.in_avals) - num_consts - num_carry - used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry]) - for i in range(1 + num_carry): - used_outputs = used_carry_out + used_extensive_out - jaxpr_dce, used_inputs = pe.dce_jaxpr( - jaxpr.jaxpr, used_outputs, - instantiate=[False] * num_consts + used_carry_out + [False] * num_xs) - used_consts, used_carry_in, used_extensive_in = \ - split_list(used_inputs, [num_consts, num_carry]) - if list(used_carry_in) == list(used_carry_out): - break - else: - used_carry_out = _map(operator.or_, used_carry_out, used_carry_in) - else: - assert False, "Fixpoint not reached" - core.check_jaxpr(jaxpr.jaxpr) - - new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u] - new_params = dict(eqn.params, num_consts=sum(used_consts), - num_carry=sum(used_carry_in), linear=tuple(new_linear), - jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts)) - new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) - if used], - [v for v, used in zip(eqn.outvars, used_outputs) - if used], - eqn.primitive, new_params, eqn.effects, - eqn.source_info) - assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals ) - assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals) - return used_inputs, new_eqn - -def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry, - jaxpr, linear, unroll): - tc = partial(_typecheck_param, 'scan') - tc(reverse, 'reverse', 'bool', type(reverse) is bool) - tc(num_consts, 'num_consts', 'non-negative int', - type(num_consts) is int and num_consts >= 0) - tc(num_carry, 'num_carry', 'non-negative int', - type(num_carry) is int and num_carry >= 0) - tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr) - tc(linear, 'linear', 'tuple of bool', - type(linear) is tuple and all(type(x) is bool for x in linear)) - tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0) - - length_types = (int, masking.Poly) if bind_time else (int,) - tc(length, 'length', 'non-negative int', - type(length) in length_types and length >= 0) - - if len(linear) != len(avals): - raise core.JaxprTypeError( - f'scan param linear has length {len(linear)} for {len(avals)} operands') - - const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry]) - const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list( - jaxpr.in_avals, [num_consts, num_carry]) - carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry]) - x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) - - if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): - raise core.JaxprTypeError( - f'scan input carry input and output types mismatch: ' - f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}') - if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)): - raise core.JaxprTypeError( - f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n' - f'called with consts of type\n{_avals_short(const_avals)}') - if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)): - raise core.JaxprTypeError( - f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n' - f'called with initial carry of type\n{_avals_short(init_avals)}') - if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)): - raise core.JaxprTypeError( - f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n' - f'called with sequence of type\n{_avals_short(x_avals)}') - return None, jaxpr.effects - -def scan_bind(*args, **params): - if config.jax_enable_checks: - avals = _map(core.get_aval, args) - _scan_typecheck(True, *avals, **params) - core.check_jaxpr(params['jaxpr'].jaxpr) - return core.AxisPrimitive.bind(scan_p, *args, **params) - -scan_p = core.AxisPrimitive("scan") -scan_p.multiple_results = True -scan_p.def_custom_bind(scan_bind) -scan_p.def_impl(partial(xla.apply_primitive, scan_p)) -scan_p.def_effectful_abstract_eval(_scan_abstract_eval) -ad.primitive_jvps[scan_p] = _scan_jvp -ad.reducing_transposes[scan_p] = _scan_transpose -pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval -xla.register_initial_style_primitive(scan_p) -mlir.register_lowering(scan_p, - mlir.lower_fun(_scan_impl, multiple_results=True)) -batching.axis_primitive_batchers[scan_p] = _scan_batching_rule -masking.masking_rules[scan_p] = _scan_masking_rule -core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) -pe.partial_eval_jaxpr_custom_rules[scan_p] = \ - partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan') -pe.padding_rules[scan_p] = _scan_padding_rule -pe.dce_rules[scan_p] = _scan_dce_rule + _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun), + (lower, upper, init_val)) + return result +### map and miscellanous rules @api_boundary def map(f, xs): @@ -2222,509 +1556,7 @@ def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore -def _show_diff(array1, array2): - if core.typematch(array1, array2): - return f"{array1}" - return f"DIFFERENT {array1} vs. {array2}" - -def _check_tree_and_avals(what, tree1, avals1, tree2, avals2): - """Raises TypeError if (tree1, avals1) does not match (tree2, avals2). - - Corresponding `tree` and `avals` must match in the sense that the number of - leaves in `tree` must be equal to the length of `avals`. `what` will be - prepended to details of the mismatch in TypeError. - """ - if tree1 != tree2: - raise TypeError( - f"{what} must have same type structure, got {tree1} and {tree2}.") - if not all(_map(core.typematch, avals1, avals2)): - diff = tree_map(_show_diff, tree_unflatten(tree1, avals1), - tree_unflatten(tree2, avals2)) - raise TypeError(f"{what} must have identical types, got\n{diff}.") - - -def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False): - if has_aux: - actual_tree_children = actual_tree.children() - - if len(actual_tree_children) == 2: - # select first child as result tree - actual_tree = tree_structure(actual_tree_children[0]) - else: - raise ValueError( - f"{func_name}() produced a pytree with structure " - f"{actual_tree}, but a pytree tuple with auxiliary " - f"output was expected because has_aux was set to True.") - - if actual_tree != expected_tree: - raise TypeError( - f"{func_name}() output pytree structure must match {expected_name}, " - f"got {actual_tree} and {expected_tree}.") - - -def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): - """Promote weakly-typed in_vals to be compatible with out_avals. - - Args: - in_vals : flattened list of input values. - in_avals : corresponding list of avals. - out_avals : list of target output avals. - Returns: - in_vals_new : flattened list of modified in_vals with no weak types. - changed : bool; true if in_vals required modification. - """ - if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals): - # Calling function is responsible for catching this. - return in_vals, False - weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals)) - if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)] - if not weak_mismatches: - return in_vals, False - for i in weak_mismatches: - new_dtype = dtypes.result_type(in_vals[i], out_avals[i]) - in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype) - return in_vals, True - - -_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s') - - -def _split_root_args(args, const_lengths): - params_list = split_list(args, list(const_lengths)) - return _RootTuple(*params_list[:-1]), params_list[-1] - - -@api_boundary -def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False): - """Differentiably solve for a roots of a function. - - This is a low-level routine, mostly intended for internal use in JAX. - Gradients of custom_root() are defined with respect to closed-over variables - from the provided function ``f`` via the implicit function theorem: - https://en.wikipedia.org/wiki/Implicit_function_theorem - - Args: - f: function for which to find a root. Should accept a single argument, - return a tree of arrays with the same structure as its input. - initial_guess: initial guess for a zero of f. - solve: function to solve for the roots of f. Should take two positional - arguments, f and initial_guess, and return a solution with the same - structure as initial_guess such that func(solution) = 0. In other words, - the following is assumed to be true (but not checked):: - - solution = solve(f, initial_guess) - error = f(solution) - assert all(error == 0) - - tangent_solve: function to solve the tangent system. Should take two - positional arguments, a linear function ``g`` (the function ``f`` - linearized at its root) and a tree of array(s) ``y`` with the same - structure as initial_guess, and return a solution ``x`` such that - ``g(x)=y``: - - - For scalar ``y``, use ``lambda g, y: y / g(1.0)``. - - For vector ``y``, you could use a linear solve with the Jacobian, if - dimensionality of ``y`` is not too large: - ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``. - has_aux: bool indicating whether the ``solve`` function returns - auxiliary data like solver diagnostics as a second argument. - - Returns: - The result of calling solve(f, initial_guess) with gradients defined via - implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``. - """ - guess_flat, in_args_tree = tree_flatten((initial_guess,)) - guess_avals = tuple(_map(_abstractify, guess_flat)) - f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( - f, in_args_tree, guess_avals) - - in_tree, = treedef_children(in_args_tree) - _check_tree("f", "initial_guess", out_tree, in_tree, False) - - solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( - partial(solve, f), in_args_tree, guess_avals) - _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) - - def linearize_and_solve(x, b): - unchecked_zeros, f_jvp = jax.linearize(f, x) - return tangent_solve(f_jvp, b) - - l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( - linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2) - _check_tree("tangent_solve", "x", out_tree, in_tree, False) - - all_consts = [f_consts, solve_consts, l_and_s_consts] - const_lengths = _RootTuple(*_map(len, all_consts)) - jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr) - - solution_flat = _custom_root( - const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat)) - return tree_unflatten(solution_tree, solution_flat) - - -@partial(jax.custom_jvp, nondiff_argnums=(0, 1)) -def _custom_root(const_lengths, jaxprs, *args): - params, initial_guess = _split_root_args(args, const_lengths) - solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess)) - return solution - - -@_custom_root.defjvp -def _root_jvp(const_lengths, jaxprs, primals, tangents): - params, _ = _split_root_args(primals, const_lengths) - sol = _custom_root(const_lengths, jaxprs, *primals) - - f_out_vals = len(jaxprs.f.out_avals) - solution, aux = split_list(sol, [f_out_vals]) - - params_dot, _ = _split_root_args(tangents, const_lengths) - - # F(m, u) = 0 # system of equations in u, parameterized by m - # # solution is u*(m) defined in a neighborhood - # F(m, u*(m)) = 0 # satisfied in a neighborhood - # - # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0 # implied by line above - # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m)) # rearrange - # - # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]] # jvp - - f = core.jaxpr_as_fun(jaxprs.f) - linearize_and_solve = partial( - core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s) - f_at_solution = lambda *params: f(*params, *solution) - _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped( - params.f, params_dot.f) - solution_dot = _map( - operator.neg, linearize_and_solve(*solution, *rhs)) - # append aux, create symbolic zero tangents for the aux values - solution += aux - solution_dot += _map(lax.zeros_like_array, aux) - - return solution, solution_dot - - -class _LinearSolveTuple(collections.namedtuple( - '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')): - - def transpose(self): - return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve) - - -def _split_linear_solve_args(args, const_lengths): - params_list = split_list(args, list(const_lengths)) - return _LinearSolveTuple(*params_list[:-1]), params_list[-1] - - -def _transpose_one_output(linear_fun, primals): - transpose_fun = jax.linear_transpose(linear_fun, primals) - def transposed_fun(x): - (y,) = transpose_fun(x) - return y - return transposed_fun - - -def _flatten(args): - return [x for arg in args for x in arg] - - -def _check_shapes(func_name, expected_name, actual, expected): - actual_shapes = _map(np.shape, tree_leaves(actual)) - expected_shapes = _map(np.shape, tree_leaves(expected)) - if actual_shapes != expected_shapes: - raise ValueError( - f"{func_name}() output shapes must match {expected_name}, " - f"got {actual_shapes} and {expected_shapes}") - - -@api_boundary -def custom_linear_solve( - matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False): - """Perform a matrix-free linear solve with implicitly defined gradients. - - This function allows for overriding or defining gradients for a linear - solve directly via implicit differentiation at the solution, rather than by - differentiating *through* the solve operation. This can sometimes be much faster - or more numerically stable, or differentiating through the solve operation - may not even be implemented (e.g., if ``solve`` uses ``lax.while_loop``). - - Required invariant:: - - x = solve(matvec, b) # solve the linear equation - assert matvec(x) == b # not checked - - Args: - matvec: linear function to invert. Must be differentiable. - b: constant right handle side of the equation. May be any nested structure - of arrays. - solve: higher level function that solves for solution to the linear - equation, i.e., ``solve(matvec, x) == x`` for all ``x`` of the same form - as ``b``. This function need not be differentiable. - transpose_solve: higher level function for solving the transpose linear - equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is - the transpose of the linear map ``matvec`` (computed automatically with - autodiff). Required for backwards mode automatic differentiation, unless - ``symmetric=True``, in which case ``solve`` provides the default value. - symmetric: bool indicating if it is safe to assume the linear map - corresponds to a symmetric matrix, i.e., ``matvec == vecmat``. - has_aux: bool indicating whether the ``solve`` and ``transpose_solve`` functions - return auxiliary data like solver diagnostics as a second argument. - - Returns: - Result of ``solve(matvec, b)``, with gradients defined assuming that the - solution ``x`` satisfies the linear equation ``matvec(x) == b``. - """ - if transpose_solve is None and symmetric: - transpose_solve = solve - - b_flat, in_args_tree = tree_flatten((b,)) - b_avals = tuple(_map(_abstractify, b_flat)) - - tree, = treedef_children(in_args_tree) - - def _shape_checked(fun, name, has_aux): - def f(x): - y = fun(x) - _check_shapes(name, "b", y, b_flat) - return y - - def f_aux(x): - y, aux = fun(x) - _check_shapes(name, "b", y, b_flat) - return y, aux - - return f_aux if has_aux else f - - # no auxiliary data assumed for matvec - matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( - _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, - 'custom_linear_solve') - _check_tree("matvec", "b", out_tree, tree, False) - - solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( - _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, - 'custom_linear_solve') - _check_tree("solve", "b", out_tree, tree, has_aux) - - if transpose_solve is None: - vecmat_jaxpr = tr_solve_jaxpr = None - vecmat_consts = tr_solve_consts = [] - else: - if symmetric: - vecmat = matvec - vecmat_jaxpr = matvec_jaxpr - vecmat_consts = matvec_consts - else: - vecmat = _transpose_one_output(matvec, b) - vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( - vecmat, in_args_tree, b_avals, 'custom_linear_solve') - assert out_tree == tree - - tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( - _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), - in_args_tree, b_avals, 'custom_linear_solve') - _check_tree("transpose_solve", "b", out_tree, tree, has_aux) - - all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts] - const_lengths = _LinearSolveTuple(*_map(len, all_consts)) - jaxprs = _LinearSolveTuple( - matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) - - out_flat = linear_solve_p.bind( - *(_flatten(all_consts) + b_flat), - const_lengths=const_lengths, jaxprs=jaxprs) - - return tree_unflatten(out_tree, out_flat) - - -def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): - args_to_raise = args[sum(const_lengths):] - - # raise aux_args to shaped arrays as well if present - # number of aux args is the difference in out_avals - # of solve and matvec (since they map to the same vector space) - - num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) - if num_aux > 0: - args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return _map(raise_to_shaped, args_to_raise) - - -def _custom_linear_solve_impl(*args, const_lengths, jaxprs): - params, b = _split_linear_solve_args(args, const_lengths) - x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b)) - return x - - -def _tangent_linear_map(func, params, params_dot, *x): - """Compute the tangent of a linear map. - - Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``, - this function computes ``∂A @ x``. - """ - assert any(type(p) is not ad_util.Zero for p in params_dot) - zeros = _map(ad_util.Zero.from_value, x) - _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped( - params + list(x), params_dot + zeros) - return out_tangent - - -def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs): - # A x - b = 0 - # ∂A x + A ∂x - ∂b = 0 - # ∂x = A^{-1} (∂b - ∂A x) - - kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs) - x = linear_solve_p.bind(*primals, **kwargs) - - params, _ = _split_linear_solve_args(primals, const_lengths) - params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths) - - num_x_leaves = len(b_dot) - # x is a flat tree with possible aux values appended - # since x_tree == b_tree == b_dot_tree, we can cut off - # aux values with len info provided by b_dot tree here - x_leaves, _ = split_list(x, [num_x_leaves]) - - if all(type(p) is ad_util.Zero for p in params_dot.matvec): - # no need to evaluate matvec_tangents - rhs = b_dot - else: - matvec_tangents = _tangent_linear_map( - core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves) - rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents)) - - x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs) - - # split into x tangents and aux tangents (these become zero) - dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves]) - - daux_leaves = _map(ad_util.Zero.from_value, daux_leaves) - - x_dot = dx_leaves + daux_leaves - - return x, x_dot - - -def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs): - if jaxprs.transpose_solve is None: - raise TypeError('transpose_solve required for backwards mode automatic ' - 'differentiation of custom_linear_solve') - - params, b = _split_linear_solve_args(primals, const_lengths) - # split off symbolic zeros in the cotangent if present - x_cotangent, _ = split_list(cotangent, [len(b)]) - assert all(ad.is_undefined_primal(x) for x in b) - cotangent_b_full = linear_solve_p.bind( - *(_flatten(params.transpose()) + x_cotangent), - const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose()) - # drop aux values in cotangent computation - cotangent_b, _ = split_list(cotangent_b_full, [len(b)]) - return [None] * sum(const_lengths) + cotangent_b - - -def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims, - const_lengths, jaxprs): - orig_bat = [d is not batching.not_mapped for d in dims] - - params, b = _split_linear_solve_args(args, const_lengths) - params_dims, b_dims = _split_linear_solve_args(dims, const_lengths) - params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths) - - (matvec, vecmat, solve, solve_t) = jaxprs - (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat - - num_aux = len(solve.out_avals) - len(matvec.out_avals) - # Fixpoint computation of which parts of x and b are batched; we need to - # ensure this is consistent between all four jaxprs - b_bat = orig_b_bat - x_bat = [False] * len(solve.out_avals) - for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): - # Apply vecmat and solve -> new batched parts of x - solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( - solve, axis_size, solve_bat + b_bat, instantiate=x_bat, - axis_name=axis_name, main_type=main_type) - if vecmat is None: - vecmat_jaxpr_batched = None - x_bat_out = solve_x_bat - else: - vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( - vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat, - axis_name=axis_name, main_type=main_type) - # batch all aux data by default - x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat) - - # Apply matvec and solve_t -> new batched parts of b - matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( - matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat, - axis_name=axis_name, main_type=main_type) - if solve_t is None: - solve_t_jaxpr_batched = None - b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) - else: - solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr( - solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat, - axis_name=axis_name, main_type=main_type) - assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux - solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)]) - b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, - orig_b_bat) - if x_bat_out == x_bat and b_bat_out == b_bat: - break - else: - x_bat = x_bat_out - b_bat = b_bat_out - else: - assert False, "Fixedpoint not reached" - - batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched, - solve_jaxpr_batched, solve_t_jaxpr_batched) - - # Move batched axes to the front - new_params = [ - batching.moveaxis(x, d, 0) - if d is not batching.not_mapped and d != 0 else x - for x, d in zip(_flatten(params), _flatten(params_dims)) - ] - # Broadcast out b if necessary - new_b = [ - batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else - batching.moveaxis(x, d, 0) if now_bat and d != 0 else x - for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) - ] - - outs = linear_solve_p.bind( - *(new_params + new_b), - const_lengths=const_lengths, - jaxprs=batched_jaxprs) - out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat] - return outs, out_dims - - -linear_solve_p = core.AxisPrimitive('custom_linear_solve') -linear_solve_p.multiple_results = True -linear_solve_p.def_impl(_custom_linear_solve_impl) -linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) -ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp -xla.register_initial_style_primitive(linear_solve_p) -mlir.register_lowering( - linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, - multiple_results=True)) -ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule -batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule -pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \ - partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve') - - -def _interleave(a, b, axis): - """Given two Tensors of static shape, interleave them along the first axis.""" - assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 - a_pad = [(0, 0, 0)] * a.ndim - b_pad = [(0, 0, 0)] * b.ndim - a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1) - b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1) - op = lax.bitwise_or if a.dtype == np.bool_ else lax.add - return op(lax.pad(a, lax._const(a, 0), a_pad), - lax.pad(b, lax._const(b, 0), b_pad)) +### associative_scan @api_boundary def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): @@ -2860,8 +1692,18 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): return tree_unflatten(tree, scans) +def _interleave(a, b, axis): + """Given two Tensors of static shape, interleave them along the first axis.""" + assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 + a_pad = [(0, 0, 0)] * a.ndim + b_pad = [(0, 0, 0)] * b.ndim + a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1) + b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1) + op = lax.bitwise_or if a.dtype == np.bool_ else lax.add + return op(lax.pad(a, lax._const(a, 0), a_pad), + lax.pad(b, lax._const(b, 0), b_pad)) -# Cumulative reductions. +### Cumulative reductions. def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array: """Computes a cumulative sum along `axis`.""" @@ -2956,120 +1798,3 @@ def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool, ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul) ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min) ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max) - - -def _dummy_remat_result(aval: core.AbstractValue): - """A result that will be discarded""" - if aval is core.abstract_token: - return lax.create_token() - else: - return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape) # type: ignore - -def _remat_translation_using_cond(*args, - jaxpr: core.Jaxpr): - # Implements: - # if(rng(0, 1) < 2) - # return eval_jaxpr(*args) - # else: - # return 0 - avals_out = tuple(ov.aval for ov in jaxpr.outvars) - - def remat_comp(*args): - return tuple(core.eval_jaxpr(jaxpr, (), *args)) - def dummy_comp(*args): - return tuple(_map(_dummy_remat_result, avals_out)) - - cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2)) - return cond(cond_pred, remat_comp, dummy_comp, *args) - -def _remat_translation_using_while(*args, - jaxpr: core.Jaxpr): - # Implements: - # for(counter=0, result=0; counter < rng(1, 2); counter ++) { - # result = eval_jaxpr(*args) - # } - # The loop carry is a tuple: (counter, result, args) - avals_out = tuple(ov.aval for ov in jaxpr.outvars) - dummies_like_result = tuple(_map(_dummy_remat_result, avals_out)) - carry_init = (np.int32(0), dummies_like_result, args) - def cond(carry): - counter, _, _ = carry - return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=()) - - def body(carry): - counter, _, args = carry - results = core.eval_jaxpr(jaxpr, (), *args) - return (counter + 1, tuple(results), args) - - carry_res = while_loop(cond, body, carry_init) - return carry_res[1] - - -def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): - args = _optimization_barrier(args) - return core.eval_jaxpr(jaxpr, (), *args) - - -def remat_impl(*args, - call_jaxpr: Optional[core.Jaxpr] = None, - jaxpr: Optional[core.Jaxpr] = None, - platform: str, - prevent_cse: bool, differentiated: bool, - policy, - concrete: bool = False, - name: str = "checkpoint"): - # Support either "jaxpr" (for remat2) and "call_jaxpr" (for remat) - # name is not passed for remat2, defaults to "checkpoint" - # TODO: remove call_jaxpr once we drop the remat call primitive - if jaxpr is None: - jaxpr = call_jaxpr - assert jaxpr is not None - assert not jaxpr.constvars - - del concrete, policy # Unused. - if differentiated and prevent_cse: - if config.jax_remat_opt_barrier: - translation_rule = _remat_translation_using_opt_barrier - elif platform == 'gpu': - translation_rule = _remat_translation_using_while - else: - translation_rule = _remat_translation_using_cond - else: - translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) - - return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr) - -for platform in ("cpu", "gpu", "tpu"): - for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p): # type: ignore - mlir.register_lowering(remat_primitive, - mlir.lower_fun(partial(remat_impl, - platform=platform), - multiple_results=True), - platform=platform) - - -def _optimization_barrier_abstract_eval(*args): - return args - - -def _optimization_barrier_lowering_rule(ctx, *args): - barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in) - flat_barrier_types = util.flatten(barrier_types) - - flat_args = mlir.flatten_lowering_ir_args(args) - barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args) - return util.unflatten(barrier_op.results, _map(len, barrier_types)) - - -def _optimization_barrier(arg): - flat_args, treedef = tree_flatten(arg) - return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args)) - - -optimization_barrier_p = core.Primitive('optimization_barrier') -optimization_barrier_p.multiple_results = True -optimization_barrier_p.def_impl( - partial(xla.apply_primitive, optimization_barrier_p)) -optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) -mlir.register_lowering(optimization_barrier_p, - _optimization_barrier_lowering_rule) diff --git a/jax/_src/lax/control_flow/remat_impl.py b/jax/_src/lax/control_flow/remat_impl.py new file mode 100644 index 000000000..ef05b52c0 --- /dev/null +++ b/jax/_src/lax/control_flow/remat_impl.py @@ -0,0 +1,151 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for the remat implementation.""" +from functools import partial + +from typing import Optional + +import jax +from jax import core +from jax import lax +from jax.config import config +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import xla +from jax.tree_util import tree_flatten, tree_unflatten +from jax._src import ad_checkpoint +from jax._src import util +from jax._src.util import safe_map, wrap_name +from jax._src.lax.control_flow.conditionals import cond +from jax._src.lib.mlir.dialects import mhlo +from jax._src.lax.control_flow.loops import while_loop +import numpy as np + +_map = safe_map + +def _dummy_remat_result(aval: core.AbstractValue): + """A result that will be discarded""" + if aval is core.abstract_token: + return lax.create_token() + else: + return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape) # type: ignore + +def _remat_translation_using_cond(*args, + jaxpr: core.Jaxpr): + # Implements: + # if(rng(0, 1) < 2) + # return eval_jaxpr(*args) + # else: + # return 0 + avals_out = tuple(ov.aval for ov in jaxpr.outvars) + + def remat_comp(*args): + return tuple(core.eval_jaxpr(jaxpr, (), *args)) + def dummy_comp(*args): + return tuple(_map(_dummy_remat_result, avals_out)) + + cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2)) + return cond(cond_pred, remat_comp, dummy_comp, *args) + +def _remat_translation_using_while(*args, + jaxpr: core.Jaxpr): + # Implements: + # for(counter=0, result=0; counter < rng(1, 2); counter ++) { + # result = eval_jaxpr(*args) + # } + # The loop carry is a tuple: (counter, result, args) + avals_out = tuple(ov.aval for ov in jaxpr.outvars) + dummies_like_result = tuple(_map(_dummy_remat_result, avals_out)) + carry_init = (np.int32(0), dummies_like_result, args) + def cond(carry): + counter, _, _ = carry + return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=()) + + def body(carry): + counter, _, args = carry + results = core.eval_jaxpr(jaxpr, (), *args) + return (counter + 1, tuple(results), args) + + carry_res = while_loop(cond, body, carry_init) + return carry_res[1] + + +def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): + args = _optimization_barrier(args) + return core.eval_jaxpr(jaxpr, (), *args) + + +def remat_impl(*args, + call_jaxpr: Optional[core.Jaxpr] = None, + jaxpr: Optional[core.Jaxpr] = None, + platform: str, + prevent_cse: bool, differentiated: bool, + policy, + concrete: bool = False, + name: str = "checkpoint"): + # Support either "jaxpr" (for remat2) and "call_jaxpr" (for remat) + # name is not passed for remat2, defaults to "checkpoint" + # TODO: remove call_jaxpr once we drop the remat call primitive + if jaxpr is None: + jaxpr = call_jaxpr + assert jaxpr is not None + assert not jaxpr.constvars + + del concrete, policy # Unused. + if differentiated and prevent_cse: + if config.jax_remat_opt_barrier: + translation_rule = _remat_translation_using_opt_barrier + elif platform == 'gpu': + translation_rule = _remat_translation_using_while + else: + translation_rule = _remat_translation_using_cond + else: + translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) + + return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr) + +for platform in ("cpu", "gpu", "tpu"): + for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p): # type: ignore + mlir.register_lowering(remat_primitive, + mlir.lower_fun(partial(remat_impl, + platform=platform), + multiple_results=True), + platform=platform) + + +def _optimization_barrier_abstract_eval(*args): + return args + + +def _optimization_barrier_lowering_rule(ctx, *args): + barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in) + flat_barrier_types = util.flatten(barrier_types) + + flat_args = mlir.flatten_lowering_ir_args(args) + barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args) + return util.unflatten(barrier_op.results, _map(len, barrier_types)) + + +def _optimization_barrier(arg): + flat_args, treedef = tree_flatten(arg) + return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args)) + + +optimization_barrier_p = core.Primitive('optimization_barrier') +optimization_barrier_p.multiple_results = True +optimization_barrier_p.def_impl( + partial(xla.apply_primitive, optimization_barrier_p)) +optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) +mlir.register_lowering(optimization_barrier_p, + _optimization_barrier_lowering_rule) diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py new file mode 100644 index 000000000..35446807d --- /dev/null +++ b/jax/_src/lax/control_flow/solves.py @@ -0,0 +1,470 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for the custom linear solve and utilities.""" +import collections +from functools import partial +import operator + +import jax +from jax import core +from jax import lax +from jax import linear_util as lu +from jax.core import raise_to_shaped +from jax.interpreters import ad +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import xla +from jax.tree_util import (tree_flatten, treedef_children, tree_leaves, + tree_unflatten, treedef_tuple) +from jax._src import ad_util +from jax._src.traceback_util import api_boundary +from jax._src.util import split_list, safe_map +import numpy as np + +from jax._src.lax.control_flow.common import ( + _abstractify, + _check_tree, + _initial_style_jaxpr, + ) + +_map = safe_map + +_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s') + + +def _split_root_args(args, const_lengths): + params_list = split_list(args, list(const_lengths)) + return _RootTuple(*params_list[:-1]), params_list[-1] + + +@api_boundary +def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False): + """Differentiably solve for a roots of a function. + + This is a low-level routine, mostly intended for internal use in JAX. + Gradients of custom_root() are defined with respect to closed-over variables + from the provided function ``f`` via the implicit function theorem: + https://en.wikipedia.org/wiki/Implicit_function_theorem + + Args: + f: function for which to find a root. Should accept a single argument, + return a tree of arrays with the same structure as its input. + initial_guess: initial guess for a zero of f. + solve: function to solve for the roots of f. Should take two positional + arguments, f and initial_guess, and return a solution with the same + structure as initial_guess such that func(solution) = 0. In other words, + the following is assumed to be true (but not checked):: + + solution = solve(f, initial_guess) + error = f(solution) + assert all(error == 0) + + tangent_solve: function to solve the tangent system. Should take two + positional arguments, a linear function ``g`` (the function ``f`` + linearized at its root) and a tree of array(s) ``y`` with the same + structure as initial_guess, and return a solution ``x`` such that + ``g(x)=y``: + + - For scalar ``y``, use ``lambda g, y: y / g(1.0)``. + - For vector ``y``, you could use a linear solve with the Jacobian, if + dimensionality of ``y`` is not too large: + ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``. + has_aux: bool indicating whether the ``solve`` function returns + auxiliary data like solver diagnostics as a second argument. + + Returns: + The result of calling solve(f, initial_guess) with gradients defined via + implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``. + """ + guess_flat, in_args_tree = tree_flatten((initial_guess,)) + guess_avals = tuple(_map(_abstractify, guess_flat)) + f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( + f, in_args_tree, guess_avals) + + in_tree, = treedef_children(in_args_tree) + _check_tree("f", "initial_guess", out_tree, in_tree, False) + + solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( + partial(solve, f), in_args_tree, guess_avals) + _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) + + def linearize_and_solve(x, b): + unchecked_zeros, f_jvp = jax.linearize(f, x) + return tangent_solve(f_jvp, b) + + l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( + linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2) + _check_tree("tangent_solve", "x", out_tree, in_tree, False) + + all_consts = [f_consts, solve_consts, l_and_s_consts] + const_lengths = _RootTuple(*_map(len, all_consts)) + jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr) + + solution_flat = _custom_root( + const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat)) + return tree_unflatten(solution_tree, solution_flat) + + +@partial(jax.custom_jvp, nondiff_argnums=(0, 1)) +def _custom_root(const_lengths, jaxprs, *args): + params, initial_guess = _split_root_args(args, const_lengths) + solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess)) + return solution + + +@_custom_root.defjvp +def _root_jvp(const_lengths, jaxprs, primals, tangents): + params, _ = _split_root_args(primals, const_lengths) + sol = _custom_root(const_lengths, jaxprs, *primals) + + f_out_vals = len(jaxprs.f.out_avals) + solution, aux = split_list(sol, [f_out_vals]) + + params_dot, _ = _split_root_args(tangents, const_lengths) + + # F(m, u) = 0 # system of equations in u, parameterized by m + # # solution is u*(m) defined in a neighborhood + # F(m, u*(m)) = 0 # satisfied in a neighborhood + # + # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0 # implied by line above + # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m)) # rearrange + # + # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]] # jvp + + f = core.jaxpr_as_fun(jaxprs.f) + linearize_and_solve = partial( + core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s) + f_at_solution = lambda *params: f(*params, *solution) + _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped( + params.f, params_dot.f) + solution_dot = _map( + operator.neg, linearize_and_solve(*solution, *rhs)) + # append aux, create symbolic zero tangents for the aux values + solution += aux + solution_dot += _map(lax.zeros_like_array, aux) + + return solution, solution_dot + + +class _LinearSolveTuple(collections.namedtuple( + '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')): + + def transpose(self): + return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve) + + +def _split_linear_solve_args(args, const_lengths): + params_list = split_list(args, list(const_lengths)) + return _LinearSolveTuple(*params_list[:-1]), params_list[-1] + + +def _transpose_one_output(linear_fun, primals): + transpose_fun = jax.linear_transpose(linear_fun, primals) + def transposed_fun(x): + (y,) = transpose_fun(x) + return y + return transposed_fun + + +def _flatten(args): + return [x for arg in args for x in arg] + + +def _check_shapes(func_name, expected_name, actual, expected): + actual_shapes = _map(np.shape, tree_leaves(actual)) + expected_shapes = _map(np.shape, tree_leaves(expected)) + if actual_shapes != expected_shapes: + raise ValueError( + f"{func_name}() output shapes must match {expected_name}, " + f"got {actual_shapes} and {expected_shapes}") + + +@api_boundary +def custom_linear_solve( + matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False): + """Perform a matrix-free linear solve with implicitly defined gradients. + + This function allows for overriding or defining gradients for a linear + solve directly via implicit differentiation at the solution, rather than by + differentiating *through* the solve operation. This can sometimes be much faster + or more numerically stable, or differentiating through the solve operation + may not even be implemented (e.g., if ``solve`` uses ``lax.while_loop``). + + Required invariant:: + + x = solve(matvec, b) # solve the linear equation + assert matvec(x) == b # not checked + + Args: + matvec: linear function to invert. Must be differentiable. + b: constant right handle side of the equation. May be any nested structure + of arrays. + solve: higher level function that solves for solution to the linear + equation, i.e., ``solve(matvec, x) == x`` for all ``x`` of the same form + as ``b``. This function need not be differentiable. + transpose_solve: higher level function for solving the transpose linear + equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is + the transpose of the linear map ``matvec`` (computed automatically with + autodiff). Required for backwards mode automatic differentiation, unless + ``symmetric=True``, in which case ``solve`` provides the default value. + symmetric: bool indicating if it is safe to assume the linear map + corresponds to a symmetric matrix, i.e., ``matvec == vecmat``. + has_aux: bool indicating whether the ``solve`` and ``transpose_solve`` functions + return auxiliary data like solver diagnostics as a second argument. + + Returns: + Result of ``solve(matvec, b)``, with gradients defined assuming that the + solution ``x`` satisfies the linear equation ``matvec(x) == b``. + """ + if transpose_solve is None and symmetric: + transpose_solve = solve + + b_flat, in_args_tree = tree_flatten((b,)) + b_avals = tuple(_map(_abstractify, b_flat)) + + tree, = treedef_children(in_args_tree) + + def _shape_checked(fun, name, has_aux): + def f(x): + y = fun(x) + _check_shapes(name, "b", y, b_flat) + return y + + def f_aux(x): + y, aux = fun(x) + _check_shapes(name, "b", y, b_flat) + return y, aux + + return f_aux if has_aux else f + + # no auxiliary data assumed for matvec + matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( + _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, + 'custom_linear_solve') + _check_tree("matvec", "b", out_tree, tree, False) + + solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( + _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, + 'custom_linear_solve') + _check_tree("solve", "b", out_tree, tree, has_aux) + + if transpose_solve is None: + vecmat_jaxpr = tr_solve_jaxpr = None + vecmat_consts = tr_solve_consts = [] + else: + if symmetric: + vecmat = matvec + vecmat_jaxpr = matvec_jaxpr + vecmat_consts = matvec_consts + else: + vecmat = _transpose_one_output(matvec, b) + vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( + vecmat, in_args_tree, b_avals, 'custom_linear_solve') + assert out_tree == tree + + tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( + _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), + in_args_tree, b_avals, 'custom_linear_solve') + _check_tree("transpose_solve", "b", out_tree, tree, has_aux) + + all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts] + const_lengths = _LinearSolveTuple(*_map(len, all_consts)) + jaxprs = _LinearSolveTuple( + matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) + + out_flat = linear_solve_p.bind( + *(_flatten(all_consts) + b_flat), + const_lengths=const_lengths, jaxprs=jaxprs) + + return tree_unflatten(out_tree, out_flat) + + +def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): + args_to_raise = args[sum(const_lengths):] + + # raise aux_args to shaped arrays as well if present + # number of aux args is the difference in out_avals + # of solve and matvec (since they map to the same vector space) + + num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) + if num_aux > 0: + args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) + return _map(raise_to_shaped, args_to_raise) + + +def _custom_linear_solve_impl(*args, const_lengths, jaxprs): + params, b = _split_linear_solve_args(args, const_lengths) + x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b)) + return x + + +def _tangent_linear_map(func, params, params_dot, *x): + """Compute the tangent of a linear map. + + Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``, + this function computes ``∂A @ x``. + """ + assert any(type(p) is not ad_util.Zero for p in params_dot) + zeros = _map(ad_util.Zero.from_value, x) + _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped( + params + list(x), params_dot + zeros) + return out_tangent + + +def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs): + # A x - b = 0 + # ∂A x + A ∂x - ∂b = 0 + # ∂x = A^{-1} (∂b - ∂A x) + + kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs) + x = linear_solve_p.bind(*primals, **kwargs) + + params, _ = _split_linear_solve_args(primals, const_lengths) + params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths) + + num_x_leaves = len(b_dot) + # x is a flat tree with possible aux values appended + # since x_tree == b_tree == b_dot_tree, we can cut off + # aux values with len info provided by b_dot tree here + x_leaves, _ = split_list(x, [num_x_leaves]) + + if all(type(p) is ad_util.Zero for p in params_dot.matvec): + # no need to evaluate matvec_tangents + rhs = b_dot + else: + matvec_tangents = _tangent_linear_map( + core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves) + rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents)) + + x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs) + + # split into x tangents and aux tangents (these become zero) + dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves]) + + daux_leaves = _map(ad_util.Zero.from_value, daux_leaves) + + x_dot = dx_leaves + daux_leaves + + return x, x_dot + + +def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs): + if jaxprs.transpose_solve is None: + raise TypeError('transpose_solve required for backwards mode automatic ' + 'differentiation of custom_linear_solve') + + params, b = _split_linear_solve_args(primals, const_lengths) + # split off symbolic zeros in the cotangent if present + x_cotangent, _ = split_list(cotangent, [len(b)]) + assert all(ad.is_undefined_primal(x) for x in b) + cotangent_b_full = linear_solve_p.bind( + *(_flatten(params.transpose()) + x_cotangent), + const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose()) + # drop aux values in cotangent computation + cotangent_b, _ = split_list(cotangent_b_full, [len(b)]) + return [None] * sum(const_lengths) + cotangent_b + + +def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims, + const_lengths, jaxprs): + orig_bat = [d is not batching.not_mapped for d in dims] + + params, b = _split_linear_solve_args(args, const_lengths) + params_dims, b_dims = _split_linear_solve_args(dims, const_lengths) + params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths) + + (matvec, vecmat, solve, solve_t) = jaxprs + (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat + + num_aux = len(solve.out_avals) - len(matvec.out_avals) + # Fixpoint computation of which parts of x and b are batched; we need to + # ensure this is consistent between all four jaxprs + b_bat = orig_b_bat + x_bat = [False] * len(solve.out_avals) + for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): + # Apply vecmat and solve -> new batched parts of x + solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( + solve, axis_size, solve_bat + b_bat, instantiate=x_bat, + axis_name=axis_name, main_type=main_type) + if vecmat is None: + vecmat_jaxpr_batched = None + x_bat_out = solve_x_bat + else: + vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( + vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat, + axis_name=axis_name, main_type=main_type) + # batch all aux data by default + x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat) + + # Apply matvec and solve_t -> new batched parts of b + matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( + matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat, + axis_name=axis_name, main_type=main_type) + if solve_t is None: + solve_t_jaxpr_batched = None + b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) + else: + solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr( + solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat, + axis_name=axis_name, main_type=main_type) + assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux + solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)]) + b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, + orig_b_bat) + if x_bat_out == x_bat and b_bat_out == b_bat: + break + else: + x_bat = x_bat_out + b_bat = b_bat_out + else: + assert False, "Fixedpoint not reached" + + batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched, + solve_jaxpr_batched, solve_t_jaxpr_batched) + + # Move batched axes to the front + new_params = [ + batching.moveaxis(x, d, 0) + if d is not batching.not_mapped and d != 0 else x + for x, d in zip(_flatten(params), _flatten(params_dims)) + ] + # Broadcast out b if necessary + new_b = [ + batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else + batching.moveaxis(x, d, 0) if now_bat and d != 0 else x + for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) + ] + + outs = linear_solve_p.bind( + *(new_params + new_b), + const_lengths=const_lengths, + jaxprs=batched_jaxprs) + out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat] + return outs, out_dims + + +linear_solve_p = core.AxisPrimitive('custom_linear_solve') +linear_solve_p.multiple_results = True +linear_solve_p.def_impl(_custom_linear_solve_impl) +linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) +ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp +xla.register_initial_style_primitive(linear_solve_p) +mlir.register_lowering( + linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, + multiple_results=True)) +ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule +batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule +pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \ + partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')