Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Refactor control_flow.py into several smaller pieces

parent bc877faae0
commit ed156a2f55
34
jax/_src/lax/control_flow/__init__.py
Normal file
@@ -0,0 +1,34 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the control flow primitives."""
from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
                                             cummin, cummin_p, cumprod,
                                             cumprod_p, cumsum, cumsum_p,
                                             cumred_tpu_impl, fori_loop, map,
                                             scan, scan_bind, scan_p,
                                             _scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import cond, cond_p, switch
from jax._src.lax.control_flow.remat_impl import (remat_impl,
                                                  optimization_barrier_p)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
                                              _custom_linear_solve_impl,
                                              linear_solve_p)

from jax._src.lax.control_flow.common import allowed_effects
# Private utilities used elsewhere in JAX
# TODO(sharadmv): lift them into a more common place
from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
                                              _initial_style_jaxpr,
                                              _initial_style_jaxprs_with_common_consts,
                                              _check_tree_and_avals)
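
[Note: a minimal smoke-test sketch, not part of the diff. It uses only the
public jax.lax API to show that the re-exported control-flow entry points
resolve and behave the same after the refactor.]

import jax.numpy as jnp
from jax import lax

# The refactor moves implementations into submodules; the public names are
# unchanged, so existing user code like this keeps working.
out = lax.cond(True, lambda x: x + 1, lambda x: x - 1, jnp.float32(1.0))
assert out == 2.0
total = lax.fori_loop(0, 5, lambda i, acc: acc + i, 0)  # 0+1+2+3+4
assert total == 10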
140
jax/_src/lax/control_flow/common.py
Normal file
@@ -0,0 +1,140 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the common control flow utilities."""
import os
from functools import partial

from typing import Callable, Optional, Sequence, Set

from jax import core
from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, safe_map, unzip3
from jax.tree_util import tree_map, tree_unflatten, tree_structure

map, unsafe_map = safe_map, map

allowed_effects: Set[core.Effect] = set()


def _abstractify(x):
  return core.raise_to_shaped(core.get_aval(x))

def _typecheck_param(prim, param, name, msg_required, pred):
  if not pred:
    msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
           f'{msg_required} required:')
    param_str = str(param)
    sep = os.linesep if os.linesep in param_str else ' '
    msg = sep.join([msg, param_str])
    raise core.JaxprTypeError(msg)

@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                              primitive_name: Optional[str] = None):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  return jaxpr, consts, out_tree()

@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
                         primitive_name: Optional[str] = None):
  jaxpr, consts, out_tree = _initial_style_open_jaxpr(
      fun, in_tree, in_avals, primitive_name)
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  return closed_jaxpr, consts, out_tree

@cache()
def _initial_style_jaxprs_with_common_consts(
    funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
  # When staging the branches of a conditional into jaxprs, constants are
  # extracted from each branch and converted to jaxpr arguments. To use the
  # staged jaxprs as the branches to a conditional *primitive*, we need for
  # their (input) signatures to match. This function "joins" the staged jaxprs:
  # for each one, it makes another that accepts *all* constants, but only uses
  # those that it needs (dropping the rest).

  jaxprs, all_consts, all_out_trees = \
      unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
             for fun in funs)

  newvar = core.gensym(jaxprs, suffix='_')
  all_const_avals = [map(_abstractify, consts) for consts in all_consts]
  unused_const_vars = [map(newvar, const_avals)
                       for const_avals in all_const_avals]
  def pad_jaxpr_constvars(i, jaxpr):
    prefix = util.concatenate(unused_const_vars[:i])
    suffix = util.concatenate(unused_const_vars[i + 1:])
    constvars = [*prefix, *jaxpr.constvars, *suffix]
    return jaxpr.replace(constvars=constvars)

  consts = util.concatenate(all_consts)
  jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                   for jaxpr in jaxprs]
  return closed_jaxprs, consts, all_out_trees

def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
  """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).

  Corresponding `tree` and `avals` must match in the sense that the number of
  leaves in `tree` must be equal to the length of `avals`. `what` will be
  prepended to details of the mismatch in TypeError.
  """
  if tree1 != tree2:
    raise TypeError(
        f"{what} must have same type structure, got {tree1} and {tree2}.")
  if not all(map(core.typematch, avals1, avals2)):
    diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
                    tree_unflatten(tree2, avals2))
    raise TypeError(f"{what} must have identical types, got\n{diff}.")


def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
  if has_aux:
    actual_tree_children = actual_tree.children()

    if len(actual_tree_children) == 2:
      # select first child as result tree
      actual_tree = tree_structure(actual_tree_children[0])
    else:
      raise ValueError(
          f"{func_name}() produced a pytree with structure "
          f"{actual_tree}, but a pytree tuple with auxiliary "
          f"output was expected because has_aux was set to True.")

  if actual_tree != expected_tree:
    raise TypeError(
        f"{func_name}() output pytree structure must match {expected_name}, "
        f"got {actual_tree} and {expected_tree}.")

def _prune_zeros(ts):
  return [t for t in ts if type(t) is not ad_util.Zero]

def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  return core.ClosedJaxpr(jaxpr, consts)

def _show_diff(array1, array2):
  if core.typematch(array1, array2):
    return f"{array1}"
  return f"DIFFERENT {array1} vs. {array2}"

def _avals_short(avals):
  to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
  return ' '.join(map(to_str, avals))
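
[Note: an illustrative sketch, not part of the diff. It shows the user-facing
behavior that _check_tree_and_avals enforces: the branches of a conditional
must return matching pytree structures and types, or a TypeError is raised.]

import jax.numpy as jnp
from jax import lax

def true_fun(x):
  return {'a': x}   # dict with one leaf

def false_fun(x):
  return (x, x)     # tuple with two leaves

try:
  lax.cond(True, true_fun, false_fun, jnp.float32(0.0))
except TypeError as e:
  # "true_fun and false_fun output must have same type structure, ..."
  print(e)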
675
jax/_src/lax/control_flow/conditionals.py
Normal file
@@ -0,0 +1,675 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for conditional control flow primitives."""
import collections
import functools
from functools import partial
import inspect
import itertools

from typing import Callable, Sequence

from jax import core
from jax import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import safe_map, extend_name_stack, split_list
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import numpy as np

from jax._src.lax.control_flow.common import (
    _abstractify,
    _avals_short,
    _check_tree_and_avals,
    _initial_style_jaxprs_with_common_consts,
    _make_closed_jaxpr,
    _prune_zeros,
    _typecheck_param,
    allowed_effects,
    )

_map, unsafe_map = safe_map, map


# For backward compatibility with a previous switch/cond calling convention,
# we allow a single (pytree) `operand` argument to be passed by keyword. We use
# a sentinel object as its default value to indicate when it is _not_ passed.
_no_operand_sentinel = object()

@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
           operand=_no_operand_sentinel):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, *operands):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](*operands)

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operands: Operands (A) input to whichever branch is applied.

  Returns:
    Value (B) of ``branch(*operands)`` for the branch that was selected based
    on ``index``.
  """
  if not all(callable(branch) for branch in branches):
    raise TypeError("lax.switch: branches argument should be a sequence of callables.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if len(np.shape(index)) != 0:
    raise TypeError(
        f"Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err

  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  branches = tuple(branches)

  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    return branches[0](*operands)

  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  if (config.jax_disable_jit and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](*operands)

  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals, primitive_name='switch')
  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
                          out_trees[0], jaxprs[0].out_avals,
                          out_tree, jaxpr.out_avals)
  joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `switch`: {disallowed_effects}')

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(out_trees[0], out)
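
# [Illustrative usage sketch, not part of this commit: shows the public
#  lax.switch behavior implemented above, including index clamping.]
#
#   import jax.numpy as jnp
#   from jax import lax
#
#   branches = [lambda x: x - 1, lambda x: x * 0, lambda x: x + 1]
#   lax.switch(2, branches, jnp.float32(3.))   # -> 4.0 (third branch)
#   lax.switch(9, branches, jnp.float32(3.))   # -> 4.0 (index clamped to 2)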


def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
          operand=_no_operand_sentinel, linear=None):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  ``cond()`` has equivalent semantics to this Python implementation::

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)

  ``pred`` must be a scalar type.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operands: Operands (A) input to either branch depending on ``pred``. The
      type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if isinstance(pred, Sequence) or np.ndim(pred) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of " +
        (f"type {type(pred)}" if isinstance(pred, Sequence)
         else f"shape {np.shape(pred)}."))

  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    msg = ("Pred type must be either boolean or number, got {}.")
    raise TypeError(msg.format(pred)) from err

  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      msg = ("Pred type must be either boolean or number, got {}.")
      raise TypeError(msg.format(pred_dtype))

  if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
    if pred:
      return true_fun(*operands)
    else:
      return false_fun(*operands)

  ops, ops_tree = tree_flatten(operands)
  if linear is None:
    linear_ops = [False] * len(ops)
  else:
    linear_ops, ops_tree2 = tree_flatten(linear)
    if ops_tree != ops_tree2:
      raise TypeError('linear tree and operand tree mismatch')
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals, 'cond')
  true_jaxpr, false_jaxpr = jaxprs
  out_tree, false_out_tree = out_trees

  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)
  joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  index = lax.convert_element_type(pred, np.int32)

  linear = [False] * len(consts) + linear_ops
  out = cond_p.bind(
      index, *consts, *ops,
      branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
  return tree_unflatten(out_tree, out)
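
# [Illustrative usage sketch, not part of this commit: shows the multi-operand
#  lax.cond calling convention handled above.]
#
#   import jax.numpy as jnp
#   from jax import lax
#
#   def true_fun(x, y): return x + y
#   def false_fun(x, y): return x - y
#   lax.cond(True, true_fun, false_fun, jnp.float32(1.), jnp.float32(2.))  # -> 3.0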

@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
  # detect an attempt to call the former, deprecated cond
  try:
    ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
  except TypeError:
    pass
  else:
    assert not ba.kwargs  # no catch-all **kwargs in _cond_with_per_branch
    _, _, maybe_true_fun, _, maybe_false_fun = ba.args
    if callable(maybe_true_fun) and callable(maybe_false_fun):
      return _cond_with_per_branch_args(*ba.args)

  return _cond(*args, **kwargs)

def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  Pred has to be a scalar type; collection types (list, tuple) are not
  supported.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  return _cond(pred,
               lambda op: true_fun(op[0]),
               lambda op: false_fun(op[1]),
               (true_operand, false_operand))

def _cond_abstract_eval(*args, branches, **kwargs):
  joined_effects = core.join_effects(*(b.effects for b in branches))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')
  joined_effects = core.join_effects(*(b.effects for b in branches))
  return _map(raise_to_shaped, branches[0].out_avals), joined_effects

def _bcast_select(pred, on_true, on_false):
  if np.ndim(pred) != np.ndim(on_true):
    idx = list(range(np.ndim(pred)))
    pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
  return lax.select(pred, on_true, on_false)

def _bcast_select_n(pred, *cases):
  if np.ndim(pred) != np.ndim(cases[0]):
    idx = list(range(np.ndim(pred)))
    pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx)
  return lax.select_n(pred, *cases)

def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
  index, *ops = args
  index_dim, *op_dims = dims

  if index_dim is not batching.not_mapped:
    # Convert to a lax.select. While we could get away with not broadcasting
    # some operands yet, because all outputs must be broadcast together anyway
    # for the select we broadcast the input operands for simplicity and leave
    # optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    index, *ops = (
        batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims))

    in_batched = [True] * len(branches[0].in_avals)
    out_batched = [True] * len(branches[0].out_avals)

    branches_batched = [
        batching.batch_jaxpr(
            jaxpr, axis_size, in_batched, out_batched, axis_name, main_type)[0]
        for jaxpr in branches]

    branch_outs = []
    for i, jaxpr in enumerate(branches_batched):
      # Perform a select on the inputs for safety of reverse-mode autodiff; see
      # https://github.com/google/jax/issues/1052
      predicate = lax.eq(index, lax._const(index, i))
      ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
      branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
    out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
    return out, [0 if b else None for b in out_batched]
  else:
    ops_bat = [d is not batching.not_mapped for d in op_dims]
    ops = [batching.moveaxis(x, d, 0) if b else x
           for b, x, d in zip(ops_bat, ops, op_dims)]

    branches_out_bat = [
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, main_type)[1]
        for jaxpr in branches]
    out_bat = [any(bat) for bat in zip(*branches_out_bat)]
    branches_batched = tuple(
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, main_type)[0]
        for jaxpr in branches)

    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(
        index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims

def _cond_jvp(primals, tangents, branches, linear):
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]

  index_nz, *ops_nz = nonzeros
  assert index_nz is False

  branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
                     for jaxpr in branches]
  out_nz = [any(nz) for nz in zip(*branches_out_nz)]

  branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
                       for jaxpr in branches)

  index, *ops = primals
  _, *ops_dot = tangents
  ops_dot = _prune_zeros(ops_dot)

  ops_lin = tuple(linear)
  linear_jvp = ops_lin + (True,) * len(ops_dot)
  out = cond_p.bind(
      index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)
  out_primals, out_tangents = split_list(out, [len(out_nz)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nz)]
  return out_primals, out_tangents

def _cond_partial_eval(trace, *tracers, branches, linear):
  in_unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = in_unknowns

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    # TODO(mattjj): remove this path when old remat is removed
    params = dict(branches=branches, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(
        branch_jaxpr, ops_uk, instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  branches_known, branches_unknown, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
        pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
    branches_known.append(branch_jaxpr_known)
    branches_unknown.append(branch_jaxpr_unknown)
    branch_res_avals.append(res_avals)

  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
  num_res = len(all_res_avals)

  num_known_outs = len(out_uks) - sum(out_uks)
  branches_known = _join_cond_outputs(
      branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
      branches_unknown, all_res_avals, res_avals_per_branch)
  assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

  in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
  linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
  out_consts_res = cond_p.bind(*in_consts, branches=branches_known,
                               linear=tuple(linear_known))
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  index_tracer = trace.instantiate_const(tracers[0])
  ops_tracers = [trace.instantiate_const(t)
                 for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
  res_tracers = _map(trace.new_instantiated_const, res)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
                 for aval in branches_unknown[0].out_avals]
  linear_unknown = ([False] * num_res +
                    [l for l, uk in zip(linear, in_unknowns[1:]) if uk])
  params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  eqn = pe.new_eqn_recipe(
      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
      core.no_effects, source)
  for t in out_tracers: t.recipe = eqn
  return util.merge_lists(out_uks, out_consts, out_tracers)

# When partially evaluating conditionals, each branch produces residuals
# depending on the computation carried out by the branch, and a corresponding
# staged jaxpr that accepts those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in a
# conditional. The residual-consuming jaxprs are assembled together in a jaxpr
# conditional. The following helper functions ensure that both collections of
# jaxprs (those evaluated and those staged) are valid for joint use under their
# respective conditionals.
#
# In particular, the residuals derived from each original branch may have
# distinct types. Because the branches of conditionals must have identical type
# signatures, we join residuals together across branches into a common format.

# In order to set up a type signature that all branches can conform to, it would
# suffice to concatenate all branches' residuals. But concatenation can result
# in redundant inputs and outputs, and might lead to memory allocation that
# scales unnecessarily with the branch count. This function finds common
# residual types across branches for reuse, so as to avoid redundant
# allocation. It returns a list L of types (avals) representing the collection
# of residuals merged according to type, and, for each branch, a lookup table to
# match its residuals to their positions/types in L. Example input/output:
#
# [x], [y], [x, x]             -> [x, y, x],    [[0], [1], [0, 2]]
# [x], [x], [x, x]             -> [x, x],       [[0], [0], [0, 1]]
# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
def _merge_branch_residuals(branch_res_avals):
  def enumerate_equal(xs):
    counts = {v: itertools.count() for v in set(xs)}
    return [(x, next(counts[x])) for x in xs]
  branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
  all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
  indices = {v: i for i, v in enumerate(all_tagged_avals)}
  branch_indices = [
      [indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
  all_avals = [x for x, _ in all_tagged_avals]
  return all_avals, branch_indices

# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
                       num_non_res_outputs):
  def augment_jaxpr(jaxpr, res_indices):
    @lu.wrap_init
    def f_aug(*args):
      outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
      outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
      aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
      aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
      return outs + list(aug_residuals)

    return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
  all_res_vars = _map(newvar, all_res_avals)

  def augment_jaxpr(jaxpr, res_indices):
    num_res = len(res_indices)
    res_vars = jaxpr.jaxpr.invars[:num_res]
    non_res_vars = jaxpr.jaxpr.invars[num_res:]

    aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
    aug_invars = aug_res_vars + non_res_vars
    jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
                           jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
                           jaxpr.jaxpr.effects)
    jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
    return jaxpr_aug

  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

def _ordered_unique(xs):
  d = collections.OrderedDict((x, None) for x in xs)
  return list(d.keys())

def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
  primal_avals = _map(raise_to_shaped, primal_avals)

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
    cts_in = ad.backward_pass(
        jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
    _, cts_in = split_list(cts_in, [num_res])
    return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)

  return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)

def _cond_transpose(reduce_axes, cts, *args, branches, linear):
  index, *ops = args
  in_avals = _map(raise_to_shaped, branches[0].in_avals)
  num_res = len(ops) - sum(linear)

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
  lin_in_avals = [raise_to_shaped(a, weak_type=False)
                  for a, l in zip(in_avals, linear) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
  linear_trans = (False,) * num_res + (True,) * len(cts)

  out = cond_p.bind(
      index, *res, *cts, branches=branches_trans, linear=linear_trans)
  assert all(_map(core.typecheck, lin_in_avals, out))

  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out

def _cond_typecheck(*avals, branches, linear):
  tc = partial(_typecheck_param, 'cond')
  tc(branches, 'branches', 'tuple of ClosedJaxpr',
     type(branches) is tuple and
     all(type(x) is core.ClosedJaxpr for x in branches))
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))

  if len(branches) == 0:
    raise core.JaxprTypeError('cond requires at least one branch function')
  if len(linear) + 1 != len(avals):
    raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for '
                              f'{len(avals) - 1} non-predicate operands')

  jaxpr0 = branches[0]
  jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
  jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
  joined_effects = core.join_effects(*(b.effects for b in branches))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  for i, jaxpr in enumerate(branches[1:]):
    if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
      raise core.JaxprTypeError(
          f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
          f'branch {i+1} takes {len(jaxpr.in_avals)}')
    if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
      raise core.JaxprTypeError(
          f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
          f'branch {i+1} outputs {len(jaxpr.out_avals)}')
    if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching input types: '
          f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
    if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching output types: '
          f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

  if len(avals) != 1 + len(jaxpr0.in_avals):
    raise core.JaxprTypeError(
        f'cond called with {len(avals) - 1} non-predicate operands, '
        f'but branches take {len(jaxpr0.in_avals)} inputs')

  index_aval, *op_avals = avals
  if index_aval.dtype != np.int32:
    raise core.JaxprTypeError(
        f'cond called with index of type {index_aval.dtype} instead of int32')
  if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
    raise core.JaxprTypeError(
        f'cond branches take input types {jaxpr0_in_avals_str}, '
        f'called with operands of type {_avals_short(op_avals)}')
  if any(b.effects != branches[0].effects for b in branches[1:]):
    raise core.JaxprTypeError(
        f'cond branches must have matching effect types: '
        f'{[b.effects for b in branches]}')
  joined_effects = core.join_effects(*(b.effects for b in branches))
  return None, joined_effects

def cond_bind(*args, branches, linear):
  if config.jax_enable_checks:
    avals = _map(core.get_aval, args)
    _cond_typecheck(*avals, branches=branches, linear=linear)
    for jaxpr in branches:
      core.check_jaxpr(jaxpr.jaxpr)
  return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)

cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')

def _cond_lowering(ctx, index, *args, branches, linear):
  del linear  # Unused.
  joined_effects = core.join_effects(*(branch.effects for branch in branches))
  ordered_effects = [eff for eff in joined_effects
                     if eff in core.ordered_effects]
  num_tokens = len(ordered_effects)
  tokens_in = ctx.tokens_in.subset(ordered_effects)
  output_token_types = [mlir.token_type() for _ in ordered_effects]
  output_types = [
      *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)]
  flat_output_types = util.flatten(output_types)

  # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
  # have no arguments; the computation within the block uses implicit
  # captures.
  case_op = mhlo.CaseOp(flat_output_types, index=index,
                        num_branches=len(branches))
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
  for i, jaxpr in enumerate(branches):
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      sub_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
      out_vals, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr.jaxpr, tokens_in,
          _map(mlir.ir_constants, jaxpr.consts),
          *_map(mlir.wrap_singleton_ir_values, args))
      out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
      out_vals = [*out_tokens, *out_vals]
      mhlo.ReturnOp(util.flatten(out_vals))

  tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types))
  tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
  ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
  return outputs

mlir.register_lowering(cond_p, _cond_lowering)
File diff suppressed because it is too large
151
jax/_src/lax/control_flow/remat_impl.py
Normal file
@@ -0,0 +1,151 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the remat implementation."""
from functools import partial

from typing import Optional

import jax
from jax import core
from jax import lax
from jax.config import config
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_checkpoint
from jax._src import util
from jax._src.util import safe_map, wrap_name
from jax._src.lax.control_flow.conditionals import cond
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lax.control_flow.loops import while_loop
import numpy as np

_map = safe_map

def _dummy_remat_result(aval: core.AbstractValue):
  """A result that will be discarded."""
  if aval is core.abstract_token:
    return lax.create_token()
  else:
    return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape)  # type: ignore

def _remat_translation_using_cond(*args,
                                  jaxpr: core.Jaxpr):
  # Implements:
  #  if(rng(0, 1) < 2)
  #    return eval_jaxpr(*args)
  #  else:
  #    return 0
  avals_out = tuple(ov.aval for ov in jaxpr.outvars)

  def remat_comp(*args):
    return tuple(core.eval_jaxpr(jaxpr, (), *args))
  def dummy_comp(*args):
    return tuple(_map(_dummy_remat_result, avals_out))

  cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2))
  return cond(cond_pred, remat_comp, dummy_comp, *args)

def _remat_translation_using_while(*args,
                                   jaxpr: core.Jaxpr):
  # Implements:
  #  for(counter=0, result=0; counter < rng(1, 2); counter ++) {
  #     result = eval_jaxpr(*args)
  #  }
  # The loop carry is a tuple: (counter, result, args)
  avals_out = tuple(ov.aval for ov in jaxpr.outvars)
  dummies_like_result = tuple(_map(_dummy_remat_result, avals_out))
  carry_init = (np.int32(0), dummies_like_result, args)
  def cond(carry):
    counter, _, _ = carry
    return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

  def body(carry):
    counter, _, args = carry
    results = core.eval_jaxpr(jaxpr, (), *args)
    return (counter + 1, tuple(results), args)

  carry_res = while_loop(cond, body, carry_init)
  return carry_res[1]


def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
  args = _optimization_barrier(args)
  return core.eval_jaxpr(jaxpr, (), *args)


def remat_impl(*args,
               call_jaxpr: Optional[core.Jaxpr] = None,
               jaxpr: Optional[core.Jaxpr] = None,
               platform: str,
               prevent_cse: bool, differentiated: bool,
               policy,
               concrete: bool = False,
               name: str = "checkpoint"):
  # Support either "jaxpr" (for remat2) or "call_jaxpr" (for remat).
  # name is not passed for remat2; it defaults to "checkpoint".
  # TODO: remove call_jaxpr once we drop the remat call primitive
  if jaxpr is None:
    jaxpr = call_jaxpr
  assert jaxpr is not None
  assert not jaxpr.constvars

  del concrete, policy  # Unused.
  if differentiated and prevent_cse:
    if config.jax_remat_opt_barrier:
      translation_rule = _remat_translation_using_opt_barrier
    elif platform == 'gpu':
      translation_rule = _remat_translation_using_while
    else:
      translation_rule = _remat_translation_using_cond
  else:
    translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

  return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)

for platform in ("cpu", "gpu", "tpu"):
  for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p):  # type: ignore
    mlir.register_lowering(remat_primitive,
                           mlir.lower_fun(partial(remat_impl,
                                                  platform=platform),
                                          multiple_results=True),
                           platform=platform)


def _optimization_barrier_abstract_eval(*args):
  return args


def _optimization_barrier_lowering_rule(ctx, *args):
  barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_barrier_types = util.flatten(barrier_types)

  flat_args = mlir.flatten_lowering_ir_args(args)
  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
  return util.unflatten(barrier_op.results, _map(len, barrier_types))


def _optimization_barrier(arg):
  flat_args, treedef = tree_flatten(arg)
  return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))


optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
    partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
                       _optimization_barrier_lowering_rule)
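
[Note: a usage sketch, not part of the diff. remat_impl above is the lowering
behind jax.checkpoint / jax.remat: when the rematerialized computation has
been differentiated and CSE must be prevented, it wraps the jaxpr in an
optimization barrier (or, on the fallback paths, a trivial while loop or
cond) so XLA cannot de-duplicate the recomputation.]

import jax
import jax.numpy as jnp

@jax.checkpoint
def f(x):
  return jnp.sin(jnp.sin(x))

# Under grad, f's forward values are recomputed on the backward pass, and
# the lowering above keeps XLA from CSE-ing that recomputation away.
jax.grad(f)(1.0)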
470
jax/_src/lax/control_flow/solves.py
Normal file
@@ -0,0 +1,470 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the custom linear solve and utilities."""
import collections
from functools import partial
import operator

import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax.core import raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, treedef_children, tree_leaves,
                           tree_unflatten, treedef_tuple)
from jax._src import ad_util
from jax._src.traceback_util import api_boundary
from jax._src.util import split_list, safe_map
import numpy as np

from jax._src.lax.control_flow.common import (
    _abstractify,
    _check_tree,
    _initial_style_jaxpr,
    )

_map = safe_map

_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')


def _split_root_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _RootTuple(*params_list[:-1]), params_list[-1]


@api_boundary
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that ``f(solution) = 0``. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
    has_aux: bool indicating whether the ``solve`` function returns
      auxiliary data like solver diagnostics as a second argument.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)

  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree, False)

  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, f), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

  def linearize_and_solve(x, b):
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
  _check_tree("tangent_solve", "x", out_tree, in_tree, False)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  solution_flat = _custom_root(
      const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
  return tree_unflatten(solution_tree, solution_flat)
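
# [Illustrative usage sketch, not part of this commit: a differentiable
#  square root via custom_root, with Newton iterations as the forward solve.
#  The helper names here are hypothetical.]
#
#   import jax
#   import jax.numpy as jnp
#   from jax import lax
#
#   def sqrt_via_newton(x):
#     f = lambda y: y ** 2 - x          # root of f is sqrt(x)
#     def solve(f, y0):
#       def body(y, _):
#         return y - f(y) / jax.grad(f)(y), None
#       y, _ = lax.scan(body, y0, None, length=10)
#       return y
#     tangent_solve = lambda g, y: y / g(1.0)  # scalar tangent system
#     return lax.custom_root(f, jnp.ones_like(x), solve, tangent_solve)
#
#   jax.grad(sqrt_via_newton)(2.0)      # ~ 1 / (2 * sqrt(2))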


@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
  params, initial_guess = _split_root_args(args, const_lengths)
  solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
  return solution


@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  params, _ = _split_root_args(primals, const_lengths)
  sol = _custom_root(const_lengths, jaxprs, *primals)

  f_out_vals = len(jaxprs.f.out_avals)
  solution, aux = split_list(sol, [f_out_vals])

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxprs.f)
  linearize_and_solve = partial(
      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
  f_at_solution = lambda *params: f(*params, *solution)
  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  solution_dot = _map(
      operator.neg, linearize_and_solve(*solution, *rhs))
  # append aux, create symbolic zero tangents for the aux values
  solution += aux
  solution_dot += _map(lax.zeros_like_array, aux)

  return solution, solution_dot


class _LinearSolveTuple(collections.namedtuple(
    '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):

  def transpose(self):
    return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)


def _split_linear_solve_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _LinearSolveTuple(*params_list[:-1]), params_list[-1]


def _transpose_one_output(linear_fun, primals):
  transpose_fun = jax.linear_transpose(linear_fun, primals)
  def transposed_fun(x):
    (y,) = transpose_fun(x)
    return y
  return transposed_fun


def _flatten(args):
  return [x for arg in args for x in arg]


def _check_shapes(func_name, expected_name, actual, expected):
  actual_shapes = _map(np.shape, tree_leaves(actual))
  expected_shapes = _map(np.shape, tree_leaves(expected))
  if actual_shapes != expected_shapes:
    raise ValueError(
        f"{func_name}() output shapes must match {expected_name}, "
        f"got {actual_shapes} and {expected_shapes}")


@api_boundary
def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much
  faster or more numerically stable; moreover, differentiating through the
  solve operation may not even be implemented (e.g., if ``solve`` uses
  ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right-hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for solution to the linear
      equation, i.e., ``solve(matvec, x) == x`` for all ``x`` of the same form
      as ``b``. This function need not be differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is
      the transpose of the linear map ``matvec`` (computed automatically with
      autodiff). Required for backwards mode automatic differentiation, unless
      ``symmetric=True``, in which case ``solve`` provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
    has_aux: bool indicating whether the ``solve`` and ``transpose_solve``
      functions return auxiliary data like solver diagnostics as a second
      argument.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.
  """
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))

  tree, = treedef_children(in_args_tree)

  def _shape_checked(fun, name, has_aux):
    def f(x):
      y = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y

    def f_aux(x):
      y, aux = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y, aux

    return f_aux if has_aux else f

  # no auxiliary data assumed for matvec
  matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
      'custom_linear_solve')
  _check_tree("matvec", "b", out_tree, tree, False)

  solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
      'custom_linear_solve')
  _check_tree("solve", "b", out_tree, tree, has_aux)

  if transpose_solve is None:
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      vecmat = _transpose_one_output(matvec, b)
      vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals, 'custom_linear_solve')
      assert out_tree == tree

    tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
        _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
        in_args_tree, b_avals, 'custom_linear_solve')
    _check_tree("transpose_solve", "b", out_tree, tree, has_aux)

  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = linear_solve_p.bind(
      *(_flatten(all_consts) + b_flat),
      const_lengths=const_lengths, jaxprs=jaxprs)

  return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
|
||||
args_to_raise = args[sum(const_lengths):]
|
||||
|
||||
# raise aux_args to shaped arrays as well if present
|
||||
# number of aux args is the difference in out_avals
|
||||
# of solve and matvec (since they map to the same vector space)
|
||||
|
||||
num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
|
||||
if num_aux > 0:
|
||||
args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
|
||||
return _map(raise_to_shaped, args_to_raise)
|
||||
|
||||
|
||||
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
|
||||
params, b = _split_linear_solve_args(args, const_lengths)
|
||||
x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
|
||||
return x
|
||||
|
||||
|
||||
def _tangent_linear_map(func, params, params_dot, *x):
|
||||
"""Compute the tangent of a linear map.
|
||||
|
||||
Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
|
||||
this function computes ``∂A @ x``.
|
||||
"""
|
||||
assert any(type(p) is not ad_util.Zero for p in params_dot)
|
||||
zeros = _map(ad_util.Zero.from_value, x)
|
||||
_, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
|
||||
params + list(x), params_dot + zeros)
|
||||
return out_tangent
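
# For intuition (an illustrative note, not used by the implementation): if
# ``func(p, x) = p * x`` elementwise, then ``A = diag(p)``, and for a tangent
# ``p_dot`` this helper returns ``p_dot * x``, which is exactly ``∂A @ x``.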


def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
  # A x - b = 0
  # ∂A x + A ∂x - ∂b = 0
  # ∂x = A^{-1} (∂b - ∂A x)
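  # The tangent ∂x therefore solves the *same* linear system with a new
  # right-hand side, so we can reuse the user's ``solve`` (which need not be
  # differentiable) via another bind of linear_solve_p below.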
  kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs)
  x = linear_solve_p.bind(*primals, **kwargs)

  params, _ = _split_linear_solve_args(primals, const_lengths)
  params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

  num_x_leaves = len(b_dot)
  # x is a flat list with possible aux values appended; since
  # x_tree == b_tree == b_dot_tree, we can cut off the aux values here using
  # the length of the b_dot tree.
  x_leaves, _ = split_list(x, [num_x_leaves])

  if all(type(p) is ad_util.Zero for p in params_dot.matvec):
    # no need to evaluate matvec_tangents
    rhs = b_dot
  else:
    matvec_tangents = _tangent_linear_map(
        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec,
        *x_leaves)
    rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))

  x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)

  # split into x tangents and aux tangents (these become zero)
  dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])

  daux_leaves = _map(ad_util.Zero.from_value, daux_leaves)

  x_dot = dx_leaves + daux_leaves

  return x, x_dot


def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  # split off symbolic zeros in the cotangent if present
  x_cotangent, _ = split_list(cotangent, [len(b)])
  assert all(ad.is_undefined_primal(x) for x in b)
  cotangent_b_full = linear_solve_p.bind(
      *(_flatten(params.transpose()) + x_cotangent),
      const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
  # drop aux values in cotangent computation
  cotangent_b, _ = split_list(cotangent_b_full, [len(b)])
  return [None] * sum(const_lengths) + cotangent_b
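
# Note on the rule above: ``.transpose()`` on the params/jaxprs tuples swaps
# matvec <-> vecmat and solve <-> transpose_solve, so the cotangent of b is
# obtained by solving the transposed system with the user's transpose_solve.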


def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
  orig_bat = [d is not batching.not_mapped for d in dims]

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  num_aux = len(solve.out_avals) - len(matvec.out_avals)
  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs.
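  # (The batched-ness flags only ever flip from False to True, so a fixed
  # point must be reached within 1 + len(orig_b_bat) + len(solve.out_avals)
  # iterations -- hence the range bound on the loop below.)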
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
        axis_name=axis_name, main_type=main_type)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat,
          axis_name=axis_name, main_type=main_type)
      # batch all aux data by default
      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                       solve_x_bat)

    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat,
        axis_name=axis_name, main_type=main_type)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
          solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat,
          axis_name=axis_name, main_type=main_type)
      assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
      solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                       solve_t_b_bat, orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    assert False, "Fixed point not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
                                     solve_jaxpr_batched, solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(
      *(new_params + new_b),
      const_lengths=const_lengths,
      jaxprs=batched_jaxprs)
  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
  return outs, out_dims


linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
                                   multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')
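
# How the registered rules compose (an illustrative sketch, reusing the
# hypothetical ``matvec``/``solve``/``b`` from the docstring example above):
#
#   loss = lambda b: lax.custom_linear_solve(
#       matvec, b, solve, symmetric=True).sum()
#   jax.grad(loss)(b)  # reverse mode goes through _linear_solve_transpose_rule;
#                      # with symmetric=True, ``solve`` doubles as transpose_solve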
|
Loading…
x
Reference in New Issue
Block a user