Refactor control_flow.py into several smaller pieces

This commit is contained in:
Sharad Vikram 2022-06-02 11:50:03 -07:00
parent bc877faae0
commit ed156a2f55
6 changed files with 2390 additions and 2195 deletions

View File

@ -0,0 +1,34 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the control flow primitives."""
from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
cummin, cummin_p, cumprod,
cumprod_p, cumsum, cumsum_p,
cumred_tpu_impl, fori_loop, map,
scan, scan_bind, scan_p,
_scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import cond, cond_p, switch
from jax._src.lax.control_flow.remat_impl import (remat_impl,
optimization_barrier_p)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
_custom_linear_solve_impl,
linear_solve_p)
from jax._src.lax.control_flow.common import allowed_effects
# Private utilities used elsewhere in JAX
# TODO(sharadmv): lift them into a more common place
from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
_initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts,
_check_tree_and_avals)

View File

@ -0,0 +1,140 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the common control flow utilities."""
import os
from functools import partial
from typing import Callable, Optional, Sequence, Set
from jax import core
from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, safe_map, unzip3
from jax.tree_util import tree_map, tree_unflatten, tree_structure
map, unsafe_map = safe_map, map
allowed_effects: Set[core.Effect] = set()
def _abstractify(x):
return core.raise_to_shaped(core.get_aval(x))
def _typecheck_param(prim, param, name, msg_required, pred):
if not pred:
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
f'{msg_required} required:')
param_str = str(param)
sep = os.linesep if os.linesep in param_str else ' '
msg = sep.join([msg, param_str])
raise core.JaxprTypeError(msg)
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: Optional[str] = None):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
return jaxpr, consts, out_tree()
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: Optional[str] = None):
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
fun, in_tree, in_avals, primitive_name)
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
return closed_jaxpr, consts, out_tree
@cache()
def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
# When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for
# their (input) signatures to match. This function "joins" the staged jaxprs:
# for each one, it makes another that accepts *all* constants, but only uses
# those that it needs (dropping the rest).
jaxprs, all_consts, all_out_trees = \
unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
for fun in funs)
newvar = core.gensym(jaxprs, suffix='_')
all_const_avals = [map(_abstractify, consts) for consts in all_consts]
unused_const_vars = [map(newvar, const_avals)
for const_avals in all_const_avals]
def pad_jaxpr_constvars(i, jaxpr):
prefix = util.concatenate(unused_const_vars[:i])
suffix = util.concatenate(unused_const_vars[i + 1:])
constvars = [*prefix, *jaxpr.constvars, *suffix]
return jaxpr.replace(constvars=constvars)
consts = util.concatenate(all_consts)
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
for jaxpr in jaxprs]
return closed_jaxprs, consts, all_out_trees
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
Corresponding `tree` and `avals` must match in the sense that the number of
leaves in `tree` must be equal to the length of `avals`. `what` will be
prepended to details of the mismatch in TypeError.
"""
if tree1 != tree2:
raise TypeError(
f"{what} must have same type structure, got {tree1} and {tree2}.")
if not all(map(core.typematch, avals1, avals2)):
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
tree_unflatten(tree2, avals2))
raise TypeError(f"{what} must have identical types, got\n{diff}.")
def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
if has_aux:
actual_tree_children = actual_tree.children()
if len(actual_tree_children) == 2:
# select first child as result tree
actual_tree = tree_structure(actual_tree_children[0])
else:
raise ValueError(
f"{func_name}() produced a pytree with structure "
f"{actual_tree}, but a pytree tuple with auxiliary "
f"output was expected because has_aux was set to True.")
if actual_tree != expected_tree:
raise TypeError(
f"{func_name}() output pytree structure must match {expected_name}, "
f"got {actual_tree} and {expected_tree}.")
def _prune_zeros(ts):
return [t for t in ts if type(t) is not ad_util.Zero]
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
def _show_diff(array1, array2):
if core.typematch(array1, array2):
return f"{array1}"
return f"DIFFERENT {array1} vs. {array2}"
def _avals_short(avals):
to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
return ' '.join(map(to_str, avals))

View File

@ -0,0 +1,675 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for conditional control flow primitives."""
import collections
import functools
from functools import partial
import inspect
import itertools
from typing import Callable, Sequence
from jax import core
from jax import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import safe_map, extend_name_stack, split_list
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_avals_short,
_check_tree_and_avals,
_initial_style_jaxprs_with_common_consts,
_make_closed_jaxpr,
_prune_zeros,
_typecheck_param,
allowed_effects,
)
_map, unsafe_map = safe_map, map
# For backward compatibility with a previous switch/cond calling convention,
# we allow a single (pytree) `operand` argument to be passed by keyword. We use
# a sentinel object as its default value to indicate when it is _not_ passed.
_no_operand_sentinel = object()
@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
operand=_no_operand_sentinel):
"""Apply exactly one of ``branches`` given by ``index``.
If ``index`` is out of bounds, it is clamped to within bounds.
Has the semantics of the following Python::
def switch(index, branches, *operands):
index = clamp(0, index, len(branches) - 1)
return branches[index](*operands)
Args:
index: Integer scalar type, indicating which branch function to apply.
branches: Sequence of functions (A -> B) to be applied based on ``index``.
operands: Operands (A) input to whichever branch is applied.
Returns:
Value (B) of ``branch(*operands)`` for the branch that was selected based
on ``index``.
"""
if not all(callable(branch) for branch in branches):
raise TypeError("lax.switch: branches argument should be a sequence of callables.")
if operand is not _no_operand_sentinel:
if operands:
raise TypeError("if 'operand' keyword is passed then no positional "
f"operands can be passed, got operand={operand} "
f"and positional operands {operands}")
operands = (operand,)
del operand
if len(np.shape(index)) != 0:
raise TypeError(
f"Branch index must be scalar, "
f"got {index} of shape {np.shape(index)}.")
try:
index_dtype = dtypes.result_type(index)
except TypeError as err:
msg = f"Index type must be an integer, got {index}."
raise TypeError(msg) from err
if index_dtype.kind not in 'iu':
raise TypeError(
f"Index type must be an integer, got {index} as {index_dtype}")
branches = tuple(branches)
if len(branches) == 0:
raise ValueError("Empty branch sequence")
elif len(branches) == 1:
return branches[0](*operands)
index = lax.convert_element_type(index, np.int32)
lo = np.array(0, np.int32)
hi = np.array(len(branches) - 1, np.int32)
index = lax.clamp(lo, index, hi)
if (config.jax_disable_jit and
isinstance(core.get_aval(index), ConcreteArray)):
return branches[int(index)](*operands)
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
out_trees[0], jaxprs[0].out_avals,
out_tree, jaxpr.out_avals)
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
disallowed_effects = joined_effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `switch`: {disallowed_effects}')
linear = (False,) * (len(consts) + len(ops))
out = cond_p.bind(
index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
return tree_unflatten(out_trees[0], out)
def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.
``cond()`` has equivalent semantics to this Python implementation::
def cond(pred, true_fun, false_fun, *operands):
if pred:
return true_fun(*operands)
else:
return false_fun(*operands)
``pred`` must be a scalar type.
Args:
pred: Boolean scalar type, indicating which branch function to apply.
true_fun: Function (A -> B), to be applied if ``pred`` is True.
false_fun: Function (A -> B), to be applied if ``pred`` is False.
operands: Operands (A) input to either branch depending on ``pred``. The
type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
thereof.
Returns:
Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
depending on the value of ``pred``. The type can be a scalar, array, or any
pytree (nested Python tuple/list/dict) thereof.
"""
if not (callable(true_fun) and callable(false_fun)):
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
if operand is not _no_operand_sentinel:
if operands:
raise TypeError("if 'operand' keyword is passed then no positional "
f"operands can be passed, got operand={operand} "
f"and positional operands {operands}")
operands = (operand,)
del operand
if isinstance(pred, Sequence) or np.ndim(pred) != 0:
raise TypeError(
f"Pred must be a scalar, got {pred} of " +
(f"type {type(pred)}" if isinstance(pred, Sequence)
else f"shape {np.shape(pred)}."))
try:
pred_dtype = dtypes.result_type(pred)
except TypeError as err:
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred)) from err
if pred_dtype.kind != 'b':
if pred_dtype.kind in 'iuf':
pred = pred != 0
else:
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred_dtype))
if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
if pred:
return true_fun(*operands)
else:
return false_fun(*operands)
ops, ops_tree = tree_flatten(operands)
if linear is None:
linear_ops = [False] * len(ops)
else:
linear_ops, ops_tree2 = tree_flatten(linear)
if ops_tree != ops_tree2:
raise TypeError('linear tree and operand tree mismatch')
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees
_check_tree_and_avals("true_fun and false_fun output",
out_tree, true_jaxpr.out_avals,
false_out_tree, false_jaxpr.out_avals)
joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
disallowed_effects = joined_effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
index = lax.convert_element_type(pred, np.int32)
linear = [False] * len(consts) + linear_ops
out = cond_p.bind(
index, *consts, *ops,
branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
return tree_unflatten(out_tree, out)
@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
# detect an attempt to call the former, deprecated cond
try:
ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
except TypeError:
pass
else:
assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch
_, _, maybe_true_fun, _, maybe_false_fun = ba.args
if callable(maybe_true_fun) and callable(maybe_false_fun):
return _cond_with_per_branch_args(*ba.args)
return _cond(*args, **kwargs)
def _cond_with_per_branch_args(pred,
true_operand, true_fun: Callable,
false_operand, false_fun: Callable):
"""Conditionally apply ``true_fun`` or ``false_fun``.
Has equivalent semantics to this Python implementation::
def cond(pred, true_operand, true_fun, false_operand, false_fun):
if pred:
return true_fun(true_operand)
else:
return false_fun(false_operand)
Pred has to be a scalar type, collection types (list, tuple) are not supported
"""
if not (callable(true_fun) and callable(false_fun)):
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
return _cond(pred,
lambda op: true_fun(op[0]),
lambda op: false_fun(op[1]),
(true_operand, false_operand))
def _cond_abstract_eval(*args, branches, **kwargs):
joined_effects = core.join_effects(*(b.effects for b in branches))
disallowed_effects = joined_effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
joined_effects = core.join_effects(*(b.effects for b in branches))
return _map(raise_to_shaped, branches[0].out_avals), joined_effects
def _bcast_select(pred, on_true, on_false):
if np.ndim(pred) != np.ndim(on_true):
idx = list(range(np.ndim(pred)))
pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
return lax.select(pred, on_true, on_false)
def _bcast_select_n(pred, *cases):
if np.ndim(pred) != np.ndim(cases[0]):
idx = list(range(np.ndim(pred)))
pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx)
return lax.select_n(pred, *cases)
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
index, *ops = args
index_dim, *op_dims = dims
if index_dim is not batching.not_mapped:
# Convert to a lax.select. While we could get away with not broadcasting
# some operands yet, because all outputs must be broadcast together anyway
# for the select we broadcast the input operands for simplicity and leave
# optimizations to XLA.
# TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
index, *ops = (
batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims))
in_batched = [True] * len(branches[0].in_avals)
out_batched = [True] * len(branches[0].out_avals)
branches_batched = [
batching.batch_jaxpr(
jaxpr, axis_size, in_batched, out_batched, axis_name, main_type)[0]
for jaxpr in branches]
branch_outs = []
for i, jaxpr in enumerate(branches_batched):
# Perform a select on the inputs for safety of reverse-mode autodiff; see
# https://github.com/google/jax/issues/1052
predicate = lax.eq(index, lax._const(index, i))
ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
return out, [0 if b else None for b in out_batched]
else:
ops_bat = [d is not batching.not_mapped for d in op_dims]
ops = [batching.moveaxis(x, d, 0) if b else x
for b, x, d in zip(ops_bat, ops, op_dims)]
branches_out_bat = [
batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, main_type)[1]
for jaxpr in branches]
out_bat = [any(bat) for bat in zip(*branches_out_bat)]
branches_batched = tuple(
batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, main_type)[0]
for jaxpr in branches)
out_dims = [0 if b else batching.not_mapped for b in out_bat]
out = cond_p.bind(
index, *ops, branches=branches_batched, linear=linear)
return out, out_dims
def _cond_jvp(primals, tangents, branches, linear):
nonzeros = [type(t) is not ad_util.Zero for t in tangents]
index_nz, *ops_nz = nonzeros
assert index_nz is False
branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
for jaxpr in branches]
out_nz = [any(nz) for nz in zip(*branches_out_nz)]
branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
for jaxpr in branches)
index, *ops = primals
_, *ops_dot = tangents
ops_dot = _prune_zeros(ops_dot)
ops_lin = tuple(linear)
linear_jvp = ops_lin + (True,) * len(ops_dot)
out = cond_p.bind(
index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)
out_primals, out_tangents = split_list(out, [len(out_nz)])
out_tangents_iter = iter(out_tangents)
out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
for p, nz in zip(out_primals, out_nz)]
return out_primals, out_tangents
def _cond_partial_eval(trace, *tracers, branches, linear):
in_unknowns = [t.pval[0] is not None for t in tracers]
index_uk, *ops_uk = in_unknowns
if index_uk:
# When the branch index is unknown, we stage out the whole cond.
# TODO(mattjj): remove this path when old remat is removed
params = dict(branches=branches, linear=linear)
return trace.default_process_primitive(cond_p, tracers, params)
branches_out_uks = []
for branch_jaxpr in branches:
_, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(
branch_jaxpr, ops_uk, instantiate=False)
branches_out_uks.append(out_uks)
out_uks = [any(uks) for uks in zip(*branches_out_uks)]
branches_known, branches_unknown, branch_res_avals = [], [], []
for branch_jaxpr in branches:
branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
branches_known.append(branch_jaxpr_known)
branches_unknown.append(branch_jaxpr_unknown)
branch_res_avals.append(res_avals)
all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
num_res = len(all_res_avals)
num_known_outs = len(out_uks) - sum(out_uks)
branches_known = _join_cond_outputs(
branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
branches_unknown, all_res_avals, res_avals_per_branch)
assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
for j in branches_known[1:])
in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
out_consts_res = cond_p.bind(*in_consts, branches=branches_known,
linear=tuple(linear_known))
out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])
index_tracer = trace.instantiate_const(tracers[0])
ops_tracers = [trace.instantiate_const(t)
for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
res_tracers = _map(trace.new_instantiated_const, res)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in branches_unknown[0].out_avals]
linear_unknown = ([False] * num_res +
[l for l, uk in zip(linear, in_unknowns[1:]) if uk])
params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
core.no_effects, source)
for t in out_tracers: t.recipe = eqn
return util.merge_lists(out_uks, out_consts, out_tracers)
# When partially evaluating conditionals, each branch produces residuals
# depending on the computation carried out by the branch, and a corresponding
# staged jaxpr that accepts those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in a
# conditional. The residual-consuming jaxprs are assembled together in a jaxpr
# conditional. The following helper functions ensure that both collections of
# jaxprs (those evaluated and those staged) are valid for joint use under their
# respective conditionals.
#
# In particular, the residuals derived from each original branch may have
# distinct types. Because the branches of conditionals must have identical type
# signatures, we join residuals together across branches into a common format.
# In order to set up a type signature that all branches can conform to, it would
# suffice to concatenate all branches' residuals. But concatenation can result
# in redundant inputs and outputs, and might lead to memory allocation that
# scales unnecessarily with the branch count. This function finds common
# residual types across branches for reuse, so as to avoid redundant
# allocation. It returns a list L of types (avals) representing the collection
# of residuals merged according to type, and, for each branch, a lookup table to
# match its residuals to their positions/types in L. Example input/output:
#
# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]]
# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]]
# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
def _merge_branch_residuals(branch_res_avals):
def enumerate_equal(xs):
counts = {v: itertools.count() for v in set(xs)}
return [(x, next(counts[x])) for x in xs]
branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
indices = {v: i for i, v in enumerate(all_tagged_avals)}
branch_indices = [
[indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
all_avals = [x for x, _ in all_tagged_avals]
return all_avals, branch_indices
# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
num_non_res_outputs):
def augment_jaxpr(jaxpr, res_indices):
@lu.wrap_init
def f_aug(*args):
outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
return outs + list(aug_residuals)
return _make_closed_jaxpr(f_aug, jaxpr.in_avals)
return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
res_aval_indices_per_jaxpr):
newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
all_res_vars = _map(newvar, all_res_avals)
def augment_jaxpr(jaxpr, res_indices):
num_res = len(res_indices)
res_vars = jaxpr.jaxpr.invars[:num_res]
non_res_vars = jaxpr.jaxpr.invars[num_res:]
aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
aug_invars = aug_res_vars + non_res_vars
jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects)
jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return jaxpr_aug
return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())
def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = _map(raise_to_shaped, primal_avals)
@lu.wrap_init
def transposed(*args):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
cts_in = ad.backward_pass(
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
index, *ops = args
in_avals = _map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)
branches_trans = tuple(
_transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
lin_in_avals = [raise_to_shaped(a, weak_type=False)
for a, l in zip(in_avals, linear) if l]
assert all(core.typematch(out_aval, lin_in_aval)
for jaxpr in branches_trans
for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))
res = ops[:num_res]
cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
linear_trans = (False,) * num_res + (True,) * len(cts)
out = cond_p.bind(
index, *res, *cts, branches=branches_trans, linear=linear_trans)
assert all(_map(core.typecheck, lin_in_avals, out))
out_iter = iter(out)
out = [next(out_iter) if l else None for l in linear]
assert next(out_iter, None) is None
return [None] + out
def _cond_typecheck(*avals, branches, linear):
tc = partial(_typecheck_param, 'cond')
tc(branches, 'branches', 'tuple of ClosedJaxpr',
type(branches) is tuple and
all(type(x) is core.ClosedJaxpr for x in branches))
tc(linear, 'linear', 'tuple of bool',
type(linear) is tuple and all(type(x) is bool for x in linear))
if len(branches) == 0:
raise core.JaxprTypeError('cond requires at least one branch function')
if len(linear) + 1 != len(avals):
raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for '
f'{len(avals) - 1} non-predicate operands')
jaxpr0 = branches[0]
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
joined_effects = core.join_effects(*(b.effects for b in branches))
disallowed_effects = joined_effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
for i, jaxpr in enumerate(branches[1:]):
if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
raise core.JaxprTypeError(
f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
f'branch {i+1} takes {len(jaxpr.in_avals)}')
if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
raise core.JaxprTypeError(
f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
f'branch {i+1} outputs {len(jaxpr.out_avals)}')
if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching input types: '
f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching output types: '
f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')
if len(avals) != 1 + len(jaxpr0.in_avals):
raise core.JaxprTypeError(
f'cond called with {len(avals) - 1} non-predicate operands, '
f'but branches take {len(jaxpr0.in_avals)} inputs')
index_aval, *op_avals = avals
if index_aval.dtype != np.int32:
raise core.JaxprTypeError(
f'cond called with index of type {index_aval.dtype} instead of int32')
if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
raise core.JaxprTypeError(
f'cond branches take input types {jaxpr0_in_avals_str}, '
f'called with operands of type {_avals_short(op_avals)}')
if any(b.effects != branches[0].effects for b in branches[1:]):
raise core.JaxprTypeError(
f'cond branches must have matching effect types: '
f'{[b.effects for b in branches]}')
joined_effects = core.join_effects(*(b.effects for b in branches))
return None, joined_effects
def cond_bind(*args, branches, linear):
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_cond_typecheck(*avals, branches=branches, linear=linear)
for jaxpr in branches:
core.check_jaxpr(jaxpr.jaxpr)
return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)
cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
def _cond_lowering(ctx, index, *args, branches, linear):
del linear # Unused.
joined_effects = core.join_effects(*(branch.effects for branch in branches))
ordered_effects = [eff for eff in joined_effects
if eff in core.ordered_effects]
num_tokens = len(ordered_effects)
tokens_in = ctx.tokens_in.subset(ordered_effects)
output_token_types = [mlir.token_type() for _ in ordered_effects]
output_types = [
*output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)]
flat_output_types = util.flatten(output_types)
# mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
# have no arguments; the computation within the block uses implicit
# captures.
case_op = mhlo.CaseOp(flat_output_types, index=index,
num_branches=len(branches))
name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
for i, jaxpr in enumerate(branches):
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
out_vals, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr.jaxpr, tokens_in,
_map(mlir.ir_constants, jaxpr.consts),
*_map(mlir.wrap_singleton_ir_values, args))
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
out_vals = [*out_tokens, *out_vals]
mhlo.ReturnOp(util.flatten(out_vals))
tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types))
tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
return outputs
mlir.register_lowering(cond_p, _cond_lowering)

View File

@ -0,0 +1,151 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the remat implementation."""
from functools import partial
from typing import Optional
import jax
from jax import core
from jax import lax
from jax.config import config
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_checkpoint
from jax._src import util
from jax._src.util import safe_map, wrap_name
from jax._src.lax.control_flow.conditionals import cond
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lax.control_flow.loops import while_loop
import numpy as np
_map = safe_map
def _dummy_remat_result(aval: core.AbstractValue):
"""A result that will be discarded"""
if aval is core.abstract_token:
return lax.create_token()
else:
return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape) # type: ignore
def _remat_translation_using_cond(*args,
jaxpr: core.Jaxpr):
# Implements:
# if(rng(0, 1) < 2)
# return eval_jaxpr(*args)
# else:
# return 0
avals_out = tuple(ov.aval for ov in jaxpr.outvars)
def remat_comp(*args):
return tuple(core.eval_jaxpr(jaxpr, (), *args))
def dummy_comp(*args):
return tuple(_map(_dummy_remat_result, avals_out))
cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2))
return cond(cond_pred, remat_comp, dummy_comp, *args)
def _remat_translation_using_while(*args,
jaxpr: core.Jaxpr):
# Implements:
# for(counter=0, result=0; counter < rng(1, 2); counter ++) {
# result = eval_jaxpr(*args)
# }
# The loop carry is a tuple: (counter, result, args)
avals_out = tuple(ov.aval for ov in jaxpr.outvars)
dummies_like_result = tuple(_map(_dummy_remat_result, avals_out))
carry_init = (np.int32(0), dummies_like_result, args)
def cond(carry):
counter, _, _ = carry
return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())
def body(carry):
counter, _, args = carry
results = core.eval_jaxpr(jaxpr, (), *args)
return (counter + 1, tuple(results), args)
carry_res = while_loop(cond, body, carry_init)
return carry_res[1]
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
args = _optimization_barrier(args)
return core.eval_jaxpr(jaxpr, (), *args)
def remat_impl(*args,
call_jaxpr: Optional[core.Jaxpr] = None,
jaxpr: Optional[core.Jaxpr] = None,
platform: str,
prevent_cse: bool, differentiated: bool,
policy,
concrete: bool = False,
name: str = "checkpoint"):
# Support either "jaxpr" (for remat2) and "call_jaxpr" (for remat)
# name is not passed for remat2, defaults to "checkpoint"
# TODO: remove call_jaxpr once we drop the remat call primitive
if jaxpr is None:
jaxpr = call_jaxpr
assert jaxpr is not None
assert not jaxpr.constvars
del concrete, policy # Unused.
if differentiated and prevent_cse:
if config.jax_remat_opt_barrier:
translation_rule = _remat_translation_using_opt_barrier
elif platform == 'gpu':
translation_rule = _remat_translation_using_while
else:
translation_rule = _remat_translation_using_cond
else:
translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)
return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)
for platform in ("cpu", "gpu", "tpu"):
for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p): # type: ignore
mlir.register_lowering(remat_primitive,
mlir.lower_fun(partial(remat_impl,
platform=platform),
multiple_results=True),
platform=platform)
def _optimization_barrier_abstract_eval(*args):
return args
def _optimization_barrier_lowering_rule(ctx, *args):
barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
flat_barrier_types = util.flatten(barrier_types)
flat_args = mlir.flatten_lowering_ir_args(args)
barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
return util.unflatten(barrier_op.results, _map(len, barrier_types))
def _optimization_barrier(arg):
flat_args, treedef = tree_flatten(arg)
return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)

View File

@ -0,0 +1,470 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the custom linear solve and utilities."""
import collections
from functools import partial
import operator
import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax.core import raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
tree_unflatten, treedef_tuple)
from jax._src import ad_util
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_check_tree,
_initial_style_jaxpr,
)
_map = safe_map
_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')
def _split_root_args(args, const_lengths):
params_list = split_list(args, list(const_lengths))
return _RootTuple(*params_list[:-1]), params_list[-1]
@api_boundary
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
"""Differentiably solve for a roots of a function.
This is a low-level routine, mostly intended for internal use in JAX.
Gradients of custom_root() are defined with respect to closed-over variables
from the provided function ``f`` via the implicit function theorem:
https://en.wikipedia.org/wiki/Implicit_function_theorem
Args:
f: function for which to find a root. Should accept a single argument,
return a tree of arrays with the same structure as its input.
initial_guess: initial guess for a zero of f.
solve: function to solve for the roots of f. Should take two positional
arguments, f and initial_guess, and return a solution with the same
structure as initial_guess such that func(solution) = 0. In other words,
the following is assumed to be true (but not checked)::
solution = solve(f, initial_guess)
error = f(solution)
assert all(error == 0)
tangent_solve: function to solve the tangent system. Should take two
positional arguments, a linear function ``g`` (the function ``f``
linearized at its root) and a tree of array(s) ``y`` with the same
structure as initial_guess, and return a solution ``x`` such that
``g(x)=y``:
- For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
- For vector ``y``, you could use a linear solve with the Jacobian, if
dimensionality of ``y`` is not too large:
``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
has_aux: bool indicating whether the ``solve`` function returns
auxiliary data like solver diagnostics as a second argument.
Returns:
The result of calling solve(f, initial_guess) with gradients defined via
implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
"""
guess_flat, in_args_tree = tree_flatten((initial_guess,))
guess_avals = tuple(_map(_abstractify, guess_flat))
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
f, in_args_tree, guess_avals)
in_tree, = treedef_children(in_args_tree)
_check_tree("f", "initial_guess", out_tree, in_tree, False)
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
partial(solve, f), in_args_tree, guess_avals)
_check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)
def linearize_and_solve(x, b):
unchecked_zeros, f_jvp = jax.linearize(f, x)
return tangent_solve(f_jvp, b)
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
_check_tree("tangent_solve", "x", out_tree, in_tree, False)
all_consts = [f_consts, solve_consts, l_and_s_consts]
const_lengths = _RootTuple(*_map(len, all_consts))
jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)
solution_flat = _custom_root(
const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
return tree_unflatten(solution_tree, solution_flat)
@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
params, initial_guess = _split_root_args(args, const_lengths)
solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
return solution
@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
params, _ = _split_root_args(primals, const_lengths)
sol = _custom_root(const_lengths, jaxprs, *primals)
f_out_vals = len(jaxprs.f.out_avals)
solution, aux = split_list(sol, [f_out_vals])
params_dot, _ = _split_root_args(tangents, const_lengths)
# F(m, u) = 0 # system of equations in u, parameterized by m
# # solution is u*(m) defined in a neighborhood
# F(m, u*(m)) = 0 # satisfied in a neighborhood
#
# ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0 # implied by line above
# ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m)) # rearrange
#
# ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]] # jvp
f = core.jaxpr_as_fun(jaxprs.f)
linearize_and_solve = partial(
core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
f_at_solution = lambda *params: f(*params, *solution)
_, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
params.f, params_dot.f)
solution_dot = _map(
operator.neg, linearize_and_solve(*solution, *rhs))
# append aux, create symbolic zero tangents for the aux values
solution += aux
solution_dot += _map(lax.zeros_like_array, aux)
return solution, solution_dot
class _LinearSolveTuple(collections.namedtuple(
'_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):
def transpose(self):
return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)
def _split_linear_solve_args(args, const_lengths):
params_list = split_list(args, list(const_lengths))
return _LinearSolveTuple(*params_list[:-1]), params_list[-1]
def _transpose_one_output(linear_fun, primals):
transpose_fun = jax.linear_transpose(linear_fun, primals)
def transposed_fun(x):
(y,) = transpose_fun(x)
return y
return transposed_fun
def _flatten(args):
return [x for arg in args for x in arg]
def _check_shapes(func_name, expected_name, actual, expected):
actual_shapes = _map(np.shape, tree_leaves(actual))
expected_shapes = _map(np.shape, tree_leaves(expected))
if actual_shapes != expected_shapes:
raise ValueError(
f"{func_name}() output shapes must match {expected_name}, "
f"got {actual_shapes} and {expected_shapes}")
@api_boundary
def custom_linear_solve(
matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False):
"""Perform a matrix-free linear solve with implicitly defined gradients.
This function allows for overriding or defining gradients for a linear
solve directly via implicit differentiation at the solution, rather than by
differentiating *through* the solve operation. This can sometimes be much faster
or more numerically stable, or differentiating through the solve operation
may not even be implemented (e.g., if ``solve`` uses ``lax.while_loop``).
Required invariant::
x = solve(matvec, b) # solve the linear equation
assert matvec(x) == b # not checked
Args:
matvec: linear function to invert. Must be differentiable.
b: constant right handle side of the equation. May be any nested structure
of arrays.
solve: higher level function that solves for solution to the linear
equation, i.e., ``solve(matvec, x) == x`` for all ``x`` of the same form
as ``b``. This function need not be differentiable.
transpose_solve: higher level function for solving the transpose linear
equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is
the transpose of the linear map ``matvec`` (computed automatically with
autodiff). Required for backwards mode automatic differentiation, unless
``symmetric=True``, in which case ``solve`` provides the default value.
symmetric: bool indicating if it is safe to assume the linear map
corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
has_aux: bool indicating whether the ``solve`` and ``transpose_solve`` functions
return auxiliary data like solver diagnostics as a second argument.
Returns:
Result of ``solve(matvec, b)``, with gradients defined assuming that the
solution ``x`` satisfies the linear equation ``matvec(x) == b``.
"""
if transpose_solve is None and symmetric:
transpose_solve = solve
b_flat, in_args_tree = tree_flatten((b,))
b_avals = tuple(_map(_abstractify, b_flat))
tree, = treedef_children(in_args_tree)
def _shape_checked(fun, name, has_aux):
def f(x):
y = fun(x)
_check_shapes(name, "b", y, b_flat)
return y
def f_aux(x):
y, aux = fun(x)
_check_shapes(name, "b", y, b_flat)
return y, aux
return f_aux if has_aux else f
# no auxiliary data assumed for matvec
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
_shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("matvec", "b", out_tree, tree, False)
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("solve", "b", out_tree, tree, has_aux)
if transpose_solve is None:
vecmat_jaxpr = tr_solve_jaxpr = None
vecmat_consts = tr_solve_consts = []
else:
if symmetric:
vecmat = matvec
vecmat_jaxpr = matvec_jaxpr
vecmat_consts = matvec_consts
else:
vecmat = _transpose_one_output(matvec, b)
vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
vecmat, in_args_tree, b_avals, 'custom_linear_solve')
assert out_tree == tree
tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
in_args_tree, b_avals, 'custom_linear_solve')
_check_tree("transpose_solve", "b", out_tree, tree, has_aux)
all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
const_lengths = _LinearSolveTuple(*_map(len, all_consts))
jaxprs = _LinearSolveTuple(
matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
out_flat = linear_solve_p.bind(
*(_flatten(all_consts) + b_flat),
const_lengths=const_lengths, jaxprs=jaxprs)
return tree_unflatten(out_tree, out_flat)
def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
args_to_raise = args[sum(const_lengths):]
# raise aux_args to shaped arrays as well if present
# number of aux args is the difference in out_avals
# of solve and matvec (since they map to the same vector space)
num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
if num_aux > 0:
args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
return _map(raise_to_shaped, args_to_raise)
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
params, b = _split_linear_solve_args(args, const_lengths)
x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
return x
def _tangent_linear_map(func, params, params_dot, *x):
"""Compute the tangent of a linear map.
Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
this function computes ``A @ x``.
"""
assert any(type(p) is not ad_util.Zero for p in params_dot)
zeros = _map(ad_util.Zero.from_value, x)
_, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
params + list(x), params_dot + zeros)
return out_tangent
def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
# A x - b = 0
# ∂A x + A ∂x - ∂b = 0
# ∂x = A^{-1} (∂b - ∂A x)
kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs)
x = linear_solve_p.bind(*primals, **kwargs)
params, _ = _split_linear_solve_args(primals, const_lengths)
params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)
num_x_leaves = len(b_dot)
# x is a flat tree with possible aux values appended
# since x_tree == b_tree == b_dot_tree, we can cut off
# aux values with len info provided by b_dot tree here
x_leaves, _ = split_list(x, [num_x_leaves])
if all(type(p) is ad_util.Zero for p in params_dot.matvec):
# no need to evaluate matvec_tangents
rhs = b_dot
else:
matvec_tangents = _tangent_linear_map(
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves)
rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
# split into x tangents and aux tangents (these become zero)
dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])
daux_leaves = _map(ad_util.Zero.from_value, daux_leaves)
x_dot = dx_leaves + daux_leaves
return x, x_dot
def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
if jaxprs.transpose_solve is None:
raise TypeError('transpose_solve required for backwards mode automatic '
'differentiation of custom_linear_solve')
params, b = _split_linear_solve_args(primals, const_lengths)
# split off symbolic zeros in the cotangent if present
x_cotangent, _ = split_list(cotangent, [len(b)])
assert all(ad.is_undefined_primal(x) for x in b)
cotangent_b_full = linear_solve_p.bind(
*(_flatten(params.transpose()) + x_cotangent),
const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
# drop aux values in cotangent computation
cotangent_b, _ = split_list(cotangent_b_full, [len(b)])
return [None] * sum(const_lengths) + cotangent_b
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
const_lengths, jaxprs):
orig_bat = [d is not batching.not_mapped for d in dims]
params, b = _split_linear_solve_args(args, const_lengths)
params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)
(matvec, vecmat, solve, solve_t) = jaxprs
(matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat
num_aux = len(solve.out_avals) - len(matvec.out_avals)
# Fixpoint computation of which parts of x and b are batched; we need to
# ensure this is consistent between all four jaxprs
b_bat = orig_b_bat
x_bat = [False] * len(solve.out_avals)
for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
# Apply vecmat and solve -> new batched parts of x
solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
axis_name=axis_name, main_type=main_type)
if vecmat is None:
vecmat_jaxpr_batched = None
x_bat_out = solve_x_bat
else:
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat,
axis_name=axis_name, main_type=main_type)
# batch all aux data by default
x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat)
# Apply matvec and solve_t -> new batched parts of b
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat,
axis_name=axis_name, main_type=main_type)
if solve_t is None:
solve_t_jaxpr_batched = None
b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
else:
solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat,
axis_name=axis_name, main_type=main_type)
assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
orig_b_bat)
if x_bat_out == x_bat and b_bat_out == b_bat:
break
else:
x_bat = x_bat_out
b_bat = b_bat_out
else:
assert False, "Fixedpoint not reached"
batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
solve_jaxpr_batched, solve_t_jaxpr_batched)
# Move batched axes to the front
new_params = [
batching.moveaxis(x, d, 0)
if d is not batching.not_mapped and d != 0 else x
for x, d in zip(_flatten(params), _flatten(params_dims))
]
# Broadcast out b if necessary
new_b = [
batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
]
outs = linear_solve_p.bind(
*(new_params + new_b),
const_lengths=const_lengths,
jaxprs=batched_jaxprs)
out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
return outs, out_dims
linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')