# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Control flow primitives.
"""

import collections
import functools
from functools import partial
import inspect
import itertools
import operator
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar

import numpy as np

import jax
from jax._src import api
from jax import core
from jax._src import ad_checkpoint
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax._src.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
                           split_list, cache, extend_name_stack, wrap_name)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           treedef_children, treedef_tuple, tree_map,
                           tree_leaves, tree_structure)
from jax._src import ad_util
from jax.config import config

_map = safe_map
zip = safe_zip
_reduce = functools.reduce

T = TypeVar('T')
Array = Any
BooleanNumeric = Any  # A bool, or a Boolean array.

@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                              primitive_name: Optional[str] = None):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  return jaxpr, consts, out_tree()

@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
                         primitive_name: Optional[str] = None):
  jaxpr, consts, out_tree = _initial_style_open_jaxpr(
      fun, in_tree, in_avals, primitive_name)
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  return closed_jaxpr, consts, out_tree

@cache()
def _initial_style_jaxprs_with_common_consts(
    funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
  # When staging the branches of a conditional into jaxprs, constants are
  # extracted from each branch and converted to jaxpr arguments. To use the
  # staged jaxprs as the branches to a conditional *primitive*, we need for
  # their (input) signatures to match. This function "joins" the staged jaxprs:
  # for each one, it makes another that accepts *all* constants, but only uses
  # those that it needs (dropping the rest).

  jaxprs, all_consts, all_out_trees = \
      unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
             for fun in funs)

  newvar = core.gensym(jaxprs, suffix='_')
  all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
                     for consts in all_consts]
  unused_const_vars = [[newvar(aval) for aval in const_avals]
                       for const_avals in all_const_avals]

  def pad_jaxpr_constvars(i, jaxpr):
    prefix = util.concatenate(unused_const_vars[:i])
    suffix = util.concatenate(unused_const_vars[i + 1:])
    constvars = [*prefix, *jaxpr.constvars, *suffix]
    return jaxpr.replace(constvars=constvars)

  consts = util.concatenate(all_consts)
  jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                   for jaxpr in jaxprs]
  return closed_jaxprs, consts, all_out_trees
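
# A minimal usage sketch of the joining above (illustrative only, not part of
# the library), assuming a single array operand; `_abstractify` is defined
# just below:
#
#   import jax.numpy as jnp
#   c0, c1 = jnp.ones(3), jnp.zeros(3)              # constants closed over
#   ops, in_tree = tree_flatten((jnp.arange(3.0),))
#   in_avals = tuple(_map(_abstractify, ops))
#   jaxprs, consts, _ = _initial_style_jaxprs_with_common_consts(
#       (lambda x: x + c0, lambda x: x * c1), in_tree, in_avals, 'example')
#   # len(consts) == 2, and each closed jaxpr takes both constants as its
#   # leading arguments, using only the one its branch actually needs.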

def _abstractify(x):
  return raise_to_shaped(core.get_aval(x))

def _typecheck_param(prim, param, name, msg_required, pred):
  if not pred:
    msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
           f'{msg_required} required:')
    param_str = str(param)
    sep = os.linesep if os.linesep in param_str else ' '
    msg = sep.join([msg, param_str])
    raise core.JaxprTypeError(msg)


### fori_loop and while_loop

def _fori_cond_fun(loop_carry):
  i, upper, _ = loop_carry
  return lax.lt(i, upper)

@cache()
def _fori_body_fun(body_fun):
  def while_body_fun(loop_carry):
    i, upper, x = loop_carry
    return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
  return while_body_fun

@cache()
def _fori_scan_body_fun(body_fun):
  def scanned_fun(loop_carry, _):
    i, x = loop_carry
    return (i + 1, body_fun(i, x)), None
  return scanned_fun

@api_boundary
def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.

  The type signature in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  Unlike that Python version, ``fori_loop`` is implemented in terms of either a
  call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the
  trip count is static (meaning known at tracing time, perhaps because ``lower``
  and ``upper`` are Python integer literals) then the ``fori_loop`` is
  implemented in terms of ``scan`` and reverse-mode autodiff is supported;
  otherwise, a ``while_loop`` is used and reverse-mode autodiff is not
  supported. See those functions' docstrings for more information.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.
  """
  if not callable(body_fun):
    raise TypeError("lax.fori_loop: body_fun argument should be callable.")
  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
  upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
  if lower_dtype != upper_dtype:
    msg = ("lower and upper arguments to fori_loop must have equal types, "
           "got {} and {}")
    raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))

  # If we can specialize on the trip count, call scan instead of a while_loop
  # to enable efficient reverse-mode differentiation.
  if (isinstance(core.get_aval(lower), ConcreteArray) and
      isinstance(core.get_aval(upper), ConcreteArray)):
    try:
      lower_ = int(lower)
      upper_ = int(upper)
    except TypeError:
      use_scan = False
    else:
      use_scan = True
  else:
    use_scan = False

  if use_scan:
    if config.jax_disable_jit and upper_ == lower_:
      # non-jit implementation of scan does not support length=0
      return init_val

    (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
                          None, length=upper_ - lower_)
  else:
    _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
                              (lower, upper, init_val))
  return result
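
# Example usage (an illustrative sketch, not part of the library): summing the
# integers 0 through 9. With Python-int bounds the trip count is static, so
# this reduces to `scan` and supports reverse-mode autodiff:
#
#   total = fori_loop(0, 10, lambda i, acc: acc + i, 0)  # total == 45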

@api_boundary
def while_loop(cond_fun: Callable[[T], BooleanNumeric],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of ``body_fun``, of type ``a``.
  """
  if not (callable(body_fun) and callable(cond_fun)):
    raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
  if config.jax_disable_jit:
    try:
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
      # transformation on it), so we fall back to the primitive version.
      pass

  def _create_jaxpr(init_val):
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals, "while_cond")
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals, "while_loop")
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
      msg = "cond_fun must return a boolean scalar, but got pytree {}."
      raise TypeError(msg.format(cond_tree))
    pred_aval = cond_jaxpr.out_avals[0]
    if (not isinstance(pred_aval, ShapedArray)
        or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

  # The body input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
  if changed:
    new_init_val, = tree_unflatten(in_tree, new_init_vals)
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
  cond_jaxpr, cond_consts, body_consts, body_tree = rest

  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  if cond_jaxpr.effects or body_jaxpr.effects:
    raise NotImplementedError('Effects not supported in `while`.')
  outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)
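
# Example usage (an illustrative sketch, not part of the library): doubling a
# value until it exceeds a bound. The carry keeps a fixed shape and dtype
# across iterations, as required:
#
#   out = while_loop(lambda x: x < 100, lambda x: x * 2, 1)  # out == 128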

def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
  del args, kwargs
  if cond_jaxpr.effects or body_jaxpr.effects:
    raise NotImplementedError('Effects not supported in `while_loop`.')
  joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects


def _while_loop_batching_rule(axis_size, axis_name, main_type, args, dims,
                              cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
  cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts])
  cconst_dims, bconst_dims, init_dims = split_list(dims, [cond_nconsts, body_nconsts])

  carry_bat = init_bat
  # Fixpoint computation of which carry elements are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations to
  # reach a fixpoint.
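  # For example (illustrative note): with carry (a, b) where only `a` is
  # batched on input but the body computes the new `b` from `a`, the first
  # pass marks `b` as batched too, and the next pass confirms the fixpoint.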
  for _ in range(1 + len(carry_bat)):
    _, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat,
        axis_name=axis_name, main_type=main_type)
    if carry_bat == carry_bat_out:
      break
    carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  # Knowing how the carry is batched now, we can determine if the predicate is
  # batched.
  _, (pred_bat,) = batching.batch_jaxpr(
      cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False,
      axis_name=axis_name, main_type=main_type)

  if pred_bat:
    # If the predicate is batched, we have to batch *all* of the carry
    # regardless of whether the body needs it.
    carry_bat = [True] * len(carry_bat)
    carry_dims = [0] * len(carry_bat)
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims,
        carry_dims, axis_name=axis_name, main_type=main_type)
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, [0],
        axis_name=axis_name, main_type=main_type)
  else:
    # If the predicate is not batched, we can look at the `cond_jaxpr`'s out
    # shape to determine the rank of the predicate. From this rank we pick the
    # dims of the carry to be batched to ensure that the predicate shape is a
    # prefix of the carry in and out shapes. We can then batch the `body_jaxpr`
    # according to these new batch dims.
    cond_rank = len(cond_jaxpr.out_avals[0].shape)
    carry_dims = [cond_rank if b else None for b in carry_bat]
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims,
        axis_name=axis_name, main_type=main_type)
    # Now we need to rebatch the `cond_jaxpr` according to the new dims of the
    # carry.
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,),
        axis_name=axis_name, main_type=main_type)

  # To prepare the `init` to the `while_p`, we broadcast values if they are
  # unbatched and need to have an out axis. If their current batch axis does not
  # match the one it needs to be for the translation rule to work, we move it
  # into place.
  new_init = []
  for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
    if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
      new_init.append(batching.broadcast(x, axis_size, new_axis))
    elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
      new_init.append(x)
    else:
      assert new_axis is not batching.not_mapped
      new_init.append(batching.moveaxis(x, old_axis, new_axis))

  outs = while_p.bind(*(cconsts + bconsts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  return outs, carry_dims

def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])

  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  nonzeros = cconst_nz + body_nonzeros
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts

  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
      [num_carry], [len(init_dot)])

  newvar = core.gensym([cond_jaxpr.jaxpr])
  invars_aug = (
      cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
  cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
                                    invars_aug,
                                    cond_jaxpr.jaxpr.outvars,
                                    cond_jaxpr.jaxpr.eqns,
                                    cond_jaxpr.jaxpr.effects)
  cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)

  out = while_p.bind(
      *(cconst + bconst + bconst_dot + init + init_dot),
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + len(bconst_dot),
      body_jaxpr=body_jvp_rearranged)

  out_carry, out_carry_dot = split_list(out, [num_carry])
  out_tangents_iter = iter(out_carry_dot)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_carry, nonzeros_out)]
  return out_carry, out_tangents

def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
  # As long as some carry elements (and hence outputs) are known and the output
  # of `cond_jaxpr` is known, we use a portion of the loop body to compute the
  # known outputs of the `while_loop`. For the unknown outputs we generate a
  # jaxpr to run the whole while, including recomputing the known parts,
  # basically like building in checkpointing/rematerialization. This means that
  # we don't actually save any computation by partial evaluation if there are
  # unknown outputs.
  #
  # What this achieves is twofold: jax.linearize works, and we can give a proper
  # error for reverse differentiation of `while`.

  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  cond_consts_uk, body_consts_uk, carry_init_uk = \
      split_list(unknowns, [cond_nconsts, body_nconsts])

  # Fixpoint computation of unknown carry. Each iteration promotes at least one
  # carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk, body_res_avals = pe.partial_eval_jaxpr_nounits(  # type: ignore
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits(  # type: ignore
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)

  if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while.
  in_consts = [t.pval.get_known() for uk, t in
               zip(cond_consts_uk + body_consts_uk + carry_uk, tracers)
               if not uk]
  cond_nconsts_known = len(cond_consts_uk) - sum(cond_consts_uk)
  body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk)
  num_known_outs = len(carry_uk) - sum(carry_uk)
  # TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs
  body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:num_known_outs]
  out_known = while_p.bind(
      *in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known,
      body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known)
  del body_jaxpr_known

  # Run the whole while_loop to get all the outputs, then merge with known ones
  out_tracers_ = trace.default_process_primitive(while_p, tracers, params)
  out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk]
  return util.merge_lists(carry_uk, out_known, out_tracers)

def _while_transpose_error(*_, **kwargs):
  raise ValueError("Reverse-mode differentiation does not work for "
                   "lax.while_loop or lax.fori_loop. "
                   "Try using lax.scan instead.")

while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'while_loop')


def _pred_bcast_select_mhlo(
    pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
    ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
  if x_y_aval is core.abstract_unit:
    return []
  elif x_y_aval is core.abstract_token:
    x, = xs
    y, = ys
    return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
  else:
    assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
    x, = xs
    y, = ys
    assert x.type == y.type, (x.type, y.type)
    assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
        pred_aval.shape, x_y_aval)
    bcast_pred = mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
        pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
    return mhlo.SelectOp(bcast_pred, x, y).results


def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
                    body_nconsts):
  pred_aval = cond_jaxpr.out_avals[0]
  batched = bool(pred_aval.shape)

  loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_loop_carry_types = util.flatten(loop_carry_types)

  flat_args = mlir.flatten_lowering_ir_args(args)
  while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)

  # Loop condition
  cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
  if cond_jaxpr.effects:
    raise NotImplementedError('`while_loop` with effects in `cond` not supported.')
  with ir.InsertionPoint(cond_block):
    flat_cond_args = [
        cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
    x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
    cond_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(name_stack, 'cond'))
    ((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
                                       _map(mlir.ir_constants, cond_jaxpr.consts),
                                       *(x + z))
    if batched:
      pred_ctx = mlir.LoweringRuleContext(
          module_context=ctx.module_context,
          primitive=None,
          avals_in=[pred_aval],
          avals_out=[pred_aval.update(shape=())],
          tokens_in=mlir.TokenSet(),
          tokens_out=None)
      pred, = lax._unary_reduce_lower(
          mhlo.OrOp,
          lambda dtype: np.array(False, dtype),
          pred_ctx,
          pred,
          axes=tuple(range(len(pred_aval.shape))))
    mhlo.ReturnOp([pred])

  # Loop body
  body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        body_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
    x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
    body_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(name_stack, 'body'))
    if body_jaxpr.effects:
      raise NotImplementedError('`while_loop` with effects in `body` not supported.')
    new_z, _ = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr, mlir.TokenSet(),
                                  _map(mlir.ir_constants, body_jaxpr.consts),
                                  *(y + z))
    if batched:
      body_pred_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack,
                                           'body_pred'))
      ((body_pred,),), _ = mlir.jaxpr_subcomp(
          body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
          _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
      new_z = _map(
          partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
          body_jaxpr.out_avals)

    mhlo.ReturnOp([*util.flatten(x), *util.flatten(y), *util.flatten(new_z)])

  outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
  _, _, z = util.split_list(outputs, [cond_nconsts, body_nconsts])
  return z

mlir.register_lowering(while_p, _while_lowering)


### cond and switch

# For backward compatibility with a previous switch/cond calling convention,
# we allow a single (pytree) `operand` argument to be passed by keyword. We use
# a sentinel object as its default value to indicate when it is _not_ passed.
_no_operand_sentinel = object()

@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
           operand=_no_operand_sentinel):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, *operands):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](*operands)

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operands: Operands (A) input to whichever branch is applied.

  Returns:
    Value (B) of ``branch(*operands)`` for the branch that was selected based
    on ``index``.
  """
  if not all(callable(branch) for branch in branches):
    raise TypeError("lax.switch: branches argument should be a sequence of callables.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if len(np.shape(index)) != 0:
    raise TypeError(
        f"Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err

  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  branches = tuple(branches)

  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    return branches[0](*operands)

  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  if (config.jax_disable_jit and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](*operands)

  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals, primitive_name='switch')
  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
                          out_trees[0], jaxprs[0].out_avals,
                          out_tree, jaxpr.out_avals)
  if any(b.effects for b in jaxprs):
    raise NotImplementedError('Effects not supported in `switch`.')

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(out_trees[0], out)
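
# Example usage (an illustrative sketch, not part of the library):
#
#   branches = (lambda x: -x, lambda x: x + 1, lambda x: x * 2)
#   y = switch(2, branches, 3.0)   # applies x * 2, giving 6.0
#
# An out-of-bounds index such as 5 would be clamped to 2 before dispatch.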

def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
          operand=_no_operand_sentinel, linear=None):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  ``cond()`` has equivalent semantics to this Python implementation::

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)

  ``pred`` must be a scalar type.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operands: Operands (A) input to either branch depending on ``pred``. The
      type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if isinstance(pred, Sequence) or np.ndim(pred) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of " +
        (f"type {type(pred)}" if isinstance(pred, Sequence)
         else f"shape {np.shape(pred)}."))

  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    msg = ("Pred type must be either boolean or number, got {}.")
    raise TypeError(msg.format(pred)) from err

  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      msg = ("Pred type must be either boolean or number, got {}.")
      raise TypeError(msg.format(pred_dtype))

  if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
    if pred:
      return true_fun(*operands)
    else:
      return false_fun(*operands)

  ops, ops_tree = tree_flatten(operands)
  if linear is None:
    linear_ops = [False] * len(ops)
  else:
    linear_ops, ops_tree2 = tree_flatten(linear)
    if ops_tree != ops_tree2:
      raise TypeError('linear tree and operand tree mismatch')
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals, 'cond')
  true_jaxpr, false_jaxpr = jaxprs
  out_tree, false_out_tree = out_trees

  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)
  if any(b.effects for b in jaxprs):
    raise NotImplementedError('Effects not supported in `cond`.')

  index = lax.convert_element_type(pred, np.int32)

  linear = [False] * len(consts) + linear_ops
  out = cond_p.bind(
      index, *consts, *ops,
      branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
  return tree_unflatten(out_tree, out)

@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
  # detect an attempt to call the former, deprecated cond
  try:
    ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
  except TypeError:
    pass
  else:
    assert not ba.kwargs  # no catch-all **kwargs in _cond_with_per_branch
    _, _, maybe_true_fun, _, maybe_false_fun = ba.args
    if callable(maybe_true_fun) and callable(maybe_false_fun):
      return _cond_with_per_branch_args(*ba.args)

  return _cond(*args, **kwargs)
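
# Example usage (an illustrative sketch, not part of the library): an
# absolute-value computation where both branches are traced, so they must
# return outputs of matching shape, dtype, and pytree structure:
#
#   y = cond(x > 0, lambda x: x, lambda x: -x, x)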

def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  ``pred`` must be a scalar type; collection types (list, tuple) are not
  supported.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  return _cond(pred,
               lambda op: true_fun(op[0]),
               lambda op: false_fun(op[1]),
               (true_operand, false_operand))
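
# Illustrative note on the wrapper above: under the deprecated convention,
#
#   cond(pred, x, true_fun, y, false_fun)
#
# is equivalent to the current convention's
#
#   cond(pred, lambda xy: true_fun(xy[0]), lambda xy: false_fun(xy[1]), (x, y))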

def _cond_abstract_eval(*args, branches, **kwargs):
  if any(b.effects for b in branches):
    raise NotImplementedError('Effects not supported in `cond`.')
  joined_effects = core.join_effects(*(b.effects for b in branches))
  return _map(raise_to_shaped, branches[0].out_avals), joined_effects

def _bcast_select(pred, on_true, on_false):
  if np.ndim(pred) != np.ndim(on_true):
    idx = list(range(np.ndim(pred)))
    pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
  return lax.select(pred, on_true, on_false)

def _bcast_select_n(pred, *cases):
  if np.ndim(pred) != np.ndim(cases[0]):
    idx = list(range(np.ndim(pred)))
    pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx)
  return lax.select_n(pred, *cases)
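
# Illustrative example (not part of the library): a pred of shape (3,)
# choosing among cases of shape (3, 4) is first broadcast to (3, 4), since
# `lax.select_n` needs the predicate shape to match the case shapes:
#
#   pred = jnp.array([0, 1, 2])             # selects a case per row
#   out = _bcast_select_n(pred, a, b, c)    # a, b, c each of shape (3, 4)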
|
|
|
|
|
2021-08-26 13:34:01 -07:00
|
|
|
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
|
2020-10-17 14:33:26 -04:00
|
|
|
index, *ops = args
|
2021-02-05 16:25:23 -08:00
|
|
|
index_dim, *op_dims = dims
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-02-05 16:25:23 -08:00
|
|
|
if index_dim is not batching.not_mapped:
|
|
|
|
# Convert to a lax.select. While we could get away with not broadcasting
|
|
|
|
# some operands yet, because all outputs must be broadcast together anyway
|
|
|
|
# for the select we broadcast the input operands for simplicity and leave
|
|
|
|
# optimizations to XLA.
|
|
|
|
# TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
|
2022-03-22 17:39:31 -07:00
|
|
|
index, *ops = [
|
|
|
|
batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)]
|
|
|
|
|
|
|
|
in_batched = [a is not core.abstract_unit for a in branches[0].in_avals]
|
|
|
|
out_batched = [a is not core.abstract_unit for a in branches[0].out_avals]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-02-05 16:25:23 -08:00
|
|
|
branches_batched = [
|
2022-03-22 17:39:31 -07:00
|
|
|
batching.batch_jaxpr(
|
|
|
|
jaxpr, axis_size, in_batched, out_batched, axis_name, main_type)[0]
|
2021-02-05 16:25:23 -08:00
|
|
|
for jaxpr in branches]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
branch_outs = []
|
2021-02-05 16:25:23 -08:00
|
|
|
for i, jaxpr in enumerate(branches_batched):
|
|
|
|
# Perform a select on the inputs for safety of reverse-mode autodiff; see
|
|
|
|
# https://github.com/google/jax/issues/1052
|
2022-02-07 14:12:28 -08:00
|
|
|
predicate = lax.eq(index, lax._const(index, i))
|
|
|
|
ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x))
|
2021-02-05 16:25:23 -08:00
|
|
|
if x is not core.unit else x for x in ops]
|
|
|
|
branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
|
2022-02-09 11:02:31 -08:00
|
|
|
out = [_bcast_select_n(index, *outs) if outs[0] is not core.unit else outs[0]
|
2021-02-05 16:25:23 -08:00
|
|
|
for outs in zip(*branch_outs)]
|
2022-03-22 17:39:31 -07:00
|
|
|
return out, [0 if b else None for b in out_batched]
|
2020-10-17 14:33:26 -04:00
|
|
|
else:
|
2021-02-05 16:25:23 -08:00
|
|
|
ops_bat = [d is not batching.not_mapped for d in op_dims]
|
|
|
|
ops = [batching.moveaxis(x, d, 0) if b else x
|
|
|
|
for b, x, d in zip(ops_bat, ops, op_dims)]
|
|
|
|
|
|
|
|
branches_out_bat = [
|
2021-08-26 13:34:01 -07:00
|
|
|
batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, main_type)[1]
|
2021-02-05 16:25:23 -08:00
|
|
|
for jaxpr in branches]
|
|
|
|
out_bat = [any(bat) for bat in zip(*branches_out_bat)]
|
|
|
|
branches_batched = tuple(
|
2021-08-26 13:34:01 -07:00
|
|
|
batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, main_type)[0]
|
2021-02-05 16:25:23 -08:00
|
|
|
for jaxpr in branches)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
out_dims = [0 if b else batching.not_mapped for b in out_bat]
|
|
|
|
out = cond_p.bind(
|
|
|
|
index, *ops, branches=branches_batched, linear=linear)
|
|
|
|
return out, out_dims
|
|
|
|
|
|
|
|
def _cond_jvp(primals, tangents, branches, linear):
|
|
|
|
nonzeros = [type(t) is not ad_util.Zero for t in tangents]
|
|
|
|
|
|
|
|
index_nz, *ops_nz = nonzeros
|
|
|
|
assert index_nz is False
|
|
|
|
|
|
|
|
branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
|
|
|
|
for jaxpr in branches]
|
|
|
|
out_nz = [any(nz) for nz in zip(*branches_out_nz)]
|
|
|
|
|
|
|
|
branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
|
|
|
|
for jaxpr in branches)
|
|
|
|
|
|
|
|
index, *ops = primals
|
|
|
|
_, *ops_dot = tangents
|
|
|
|
ops_dot = _prune_zeros(ops_dot)
|
|
|
|
|
|
|
|
ops_lin = tuple(linear)
|
|
|
|
linear_jvp = ops_lin + (True,) * len(ops_dot)
|
|
|
|
out = cond_p.bind(
|
|
|
|
index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)
|
|
|
|
out_primals, out_tangents = split_list(out, [len(out_nz)])
|
|
|
|
out_tangents_iter = iter(out_tangents)
|
|
|
|
out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
|
|
|
|
for p, nz in zip(out_primals, out_nz)]
|
|
|
|
return out_primals, out_tangents
|
|
|
|
|
|
|
|
def _cond_partial_eval(trace, *tracers, branches, linear):
|
2022-04-23 15:05:26 -07:00
|
|
|
in_unknowns = [t.pval[0] is not None for t in tracers]
|
|
|
|
index_uk, *ops_uk = in_unknowns
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
if index_uk:
|
|
|
|
# When the branch index is unknown, we stage out the whole cond.
|
2022-04-23 15:05:26 -07:00
|
|
|
# TODO(mattjj): remove this path when old remat is removed
|
2020-10-17 14:33:26 -04:00
|
|
|
params = dict(branches=branches, linear=linear)
|
|
|
|
return trace.default_process_primitive(cond_p, tracers, params)
|
|
|
|
|
|
|
|
branches_out_uks = []
|
|
|
|
for branch_jaxpr in branches:
|
2022-04-23 15:05:26 -07:00
|
|
|
_, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(
|
|
|
|
branch_jaxpr, ops_uk, instantiate=False)
|
2020-10-17 14:33:26 -04:00
|
|
|
branches_out_uks.append(out_uks)
|
|
|
|
out_uks = [any(uks) for uks in zip(*branches_out_uks)]
|
|
|
|
|
2022-04-23 15:05:26 -07:00
|
|
|
branches_known, branches_unknown, branch_res_avals = [], [], []
|
2020-10-17 14:33:26 -04:00
|
|
|
for branch_jaxpr in branches:
|
2022-04-23 15:05:26 -07:00
|
|
|
branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
|
|
|
|
pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
|
|
|
|
branches_known.append(branch_jaxpr_known)
|
|
|
|
branches_unknown.append(branch_jaxpr_unknown)
|
2020-10-17 14:33:26 -04:00
|
|
|
branch_res_avals.append(res_avals)
|
|
|
|
|
2022-04-23 15:05:26 -07:00
|
|
|
all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
|
2020-10-17 14:33:26 -04:00
|
|
|
num_res = len(all_res_avals)
|
|
|
|
|
2022-04-23 15:05:26 -07:00
|
|
|
num_known_outs = len(out_uks) - sum(out_uks)
|
|
|
|
branches_known = _join_cond_outputs(
|
|
|
|
branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
|
|
|
|
branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
|
|
|
|
branches_unknown, all_res_avals, res_avals_per_branch)
|
|
|
|
assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
|
|
|
|
for j in branches_known[1:])
|
|
|
|
|
|
|
|
in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
|
|
|
|
linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
|
|
|
|
out_consts_res = cond_p.bind(*in_consts, branches=branches_known,
|
|
|
|
linear=tuple(linear_known))
|
2020-10-17 14:33:26 -04:00
|
|
|
out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])
|
|
|
|
|
|
|
|
index_tracer = trace.instantiate_const(tracers[0])
|
2022-04-23 15:05:26 -07:00
|
|
|
ops_tracers = [trace.instantiate_const(t)
|
|
|
|
for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
|
2020-10-17 14:33:26 -04:00
|
|
|
res_tracers = _map(trace.new_instantiated_const, res)
|
2022-04-23 15:05:26 -07:00
|
|
|
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
|
|
|
|
for aval in branches_unknown[0].out_avals]
|
|
|
|
linear_unknown = ([False] * num_res +
|
|
|
|
[l for l, uk in zip(linear, in_unknowns[1:]) if uk])
|
|
|
|
params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
|
2021-10-28 11:06:58 -07:00
|
|
|
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
|
|
|
|
source = source_info_util.current().replace(name_stack=name_stack)
|
2020-10-17 14:33:26 -04:00
|
|
|
eqn = pe.new_eqn_recipe(
|
|
|
|
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
|
2022-02-28 13:36:39 -08:00
|
|
|
core.no_effects, source)
|
2020-10-17 14:33:26 -04:00
|
|
|
for t in out_tracers: t.recipe = eqn
|
2022-04-23 15:05:26 -07:00
|
|
|
return util.merge_lists(out_uks, out_consts, out_tracers)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# When partially evaluating conditionals, each branch produces residuals
|
|
|
|
# depending on the computation carried out by the branch, and a corresponding
|
|
|
|
# staged jaxpr that accepts those residuals as its first few inputs. The
|
|
|
|
# residual-producing branches are staged as jaxprs and bound right away in a
|
|
|
|
# conditional. The residual-consuming jaxprs are assembled together in a jaxpr
|
|
|
|
# conditional. The following helper functions ensure that both collections of
|
|
|
|
# jaxprs (those evaluated and those staged) are valid for joint use under their
|
|
|
|
# respective conditionals.
|
|
|
|
#
|
|
|
|
# In particular, the residuals derived from each original branch may have
|
|
|
|
# distinct types. Because the branches of conditionals must have identical type
|
|
|
|
# signatures, we join residuals together across branches into a common format.
|
|
|
|
|
|
|
|
# In order to set up a type signature that all branches can conform to, it would
|
|
|
|
# suffice to concatenate all branches' residuals. But concatenation can result
|
|
|
|
# in redundant inputs and outputs, and might lead to memory allocation that
|
|
|
|
# scales unnecessarily with the branch count. This function finds common
|
|
|
|
# residual types across branches for reuse, so as to avoid redundant
|
|
|
|
# allocation. It returns a list L of types (avals) representing the collection
|
|
|
|
# of residuals merged according to type, and, for each branch, a lookup table to
|
|
|
|
# match its residuals to their positions/types in L. Example input/output:
|
|
|
|
#
|
|
|
|
# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]]
|
|
|
|
# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]]
|
|
|
|
# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
|
|
|
|
def _merge_branch_residuals(branch_res_avals):
|
|
|
|
def enumerate_equal(xs):
|
|
|
|
counts = {v: itertools.count() for v in set(xs)}
|
|
|
|
return [(x, next(counts[x])) for x in xs]
|
|
|
|
branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
|
|
|
|
all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
|
|
|
|
indices = {v: i for i, v in enumerate(all_tagged_avals)}
|
|
|
|
branch_indices = [
|
|
|
|
[indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
|
|
|
|
all_avals = [x for x, _ in all_tagged_avals]
|
|
|
|
return all_avals, branch_indices


# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
                       num_non_res_outputs):
  def augment_jaxpr(jaxpr, res_indices):
    @lu.wrap_init
    def f_aug(*args):
      outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
      outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
      aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
      aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
      return outs + list(aug_residuals)

    return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
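

# Illustrative note, a hypothetical addition: util.subvals is what scatters a
# branch's actual residuals into the zero-filled slots. For instance, with
# three merged residual slots z0, z1, z2 and one populated residual r at
# index 1:
#
#   util.subvals([z0, z1, z2], zip([1], [r]))  ==>  [z0, r, z2]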


# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
  all_res_vars = _map(newvar, all_res_avals)

  def augment_jaxpr(jaxpr, res_indices):
    num_res = len(res_indices)
    res_vars = jaxpr.jaxpr.invars[:num_res]
    non_res_vars = jaxpr.jaxpr.invars[num_res:]

    aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
    aug_invars = aug_res_vars + non_res_vars
    jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
                           jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
                           jaxpr.jaxpr.effects)
    jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
    return jaxpr_aug

  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))


def _ordered_unique(xs):
  d = collections.OrderedDict((x, None) for x in xs)
  return list(d.keys())


def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
  primal_avals = _map(raise_to_shaped, primal_avals)

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
    cts_in = ad.backward_pass(
        jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
    _, cts_in = split_list(cts_in, [num_res])
    return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)

  return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)


def _cond_transpose(reduce_axes, cts, *args, branches, linear):
  index, *ops = args
  in_avals = _map(raise_to_shaped, branches[0].in_avals)
  num_res = len(ops) - sum(linear)

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
  lin_in_avals = [raise_to_shaped(a, weak_type=False)
                  for a, l in zip(in_avals, linear) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
  linear_trans = (False,) * num_res + (True,) * len(cts)

  out = cond_p.bind(
      index, *res, *cts, branches=branches_trans, linear=linear_trans)
  assert all(_map(core.typecheck, lin_in_avals, out))

  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out
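

# Illustrative sketch, a hypothetical addition: this transpose rule is what
# reverse-mode AD through `cond` bottoms out in. Differentiating a two-branch
# conditional picks up the gradient of whichever branch the predicate selects:
def _example_cond_grad():
  f = lambda x: cond(x > 0, lambda x: 2. * x, lambda x: -x, x)
  assert jax.grad(f)(3.0) == 2.0    # true branch selected
  assert jax.grad(f)(-3.0) == -1.0  # false branch selected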


def _avals_short(avals):
  to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
  return ' '.join(_map(to_str, avals))


def _cond_typecheck(*avals, branches, linear):
  tc = partial(_typecheck_param, 'cond')
  tc(branches, 'branches', 'tuple of ClosedJaxpr',
     type(branches) is tuple and
     all(type(x) is core.ClosedJaxpr for x in branches))
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))

  if len(branches) == 0:
    raise core.JaxprTypeError('cond requires at least one branch function')
  if len(linear) + 1 != len(avals):
    raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for '
                              f'{len(avals) - 1} non-predicate operands')

  jaxpr0 = branches[0]
  jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
  jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
  if any(b.effects for b in branches):
    raise NotImplementedError('Effects not supported in `cond`.')

  for i, jaxpr in enumerate(branches[1:]):
    if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
      raise core.JaxprTypeError(
        f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
        f'branch {i+1} takes {len(jaxpr.in_avals)}')
    if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
      raise core.JaxprTypeError(
        f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
        f'branch {i+1} outputs {len(jaxpr.out_avals)}')
    if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
      raise core.JaxprTypeError(
        f'cond branches 0 and {i+1} have mismatching input types: '
        f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
    if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
      raise core.JaxprTypeError(
        f'cond branches 0 and {i+1} have mismatching output types: '
        f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

  if len(avals) != 1 + len(jaxpr0.in_avals):
    raise core.JaxprTypeError(
      f'cond called with {len(avals) - 1} non-predicate operands, '
      f'but branches take {len(jaxpr0.in_avals)} inputs')

  index_aval, *op_avals = avals
  if index_aval.dtype != np.int32:
    raise core.JaxprTypeError(
      f'cond called with index of type {index_aval.dtype} instead of int32')
  if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
    raise core.JaxprTypeError(
      f'cond branches take input types {jaxpr0_in_avals_str}, '
      f'called with operands of type {_avals_short(op_avals)}')
  if any((b.effects != branches[0].effects for b in branches[1:])):
    raise core.JaxprTypeError(
      f'cond branches must have matching effect types: '
      f'{[b.effects for b in branches]}')
  joined_effects = core.join_effects(*(b.effects for b in branches))
  return None, joined_effects


def cond_bind(*args, branches, linear):
  if config.jax_enable_checks:
    avals = _map(core.get_aval, args)
    _cond_typecheck(*avals, branches=branches, linear=linear)
    for jaxpr in branches:
      core.check_jaxpr(jaxpr.jaxpr)
  return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)

cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')


def _cond_lowering(ctx, index, *args, branches, linear):
  del linear  # Unused.
  output_types = _map(mlir.aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)

  # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
  # have no arguments; the computation within the block uses implicit
  # captures.

  # TODO(phawkins): avoid build_generic when CaseOp is fixed.
  case_op = mhlo.CaseOp.build_generic(
      flat_output_types, [index], regions=len(branches))
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
  for i, jaxpr in enumerate(branches):
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      if jaxpr.effects:
        raise NotImplementedError('Cannot lower effectful `cond`.')
      sub_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
      out_vals, _ = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr.jaxpr, mlir.TokenSet(),
          _map(mlir.ir_constants, jaxpr.consts),
          *_map(mlir.wrap_singleton_ir_values, args))
      mhlo.ReturnOp(util.flatten(out_vals))

  return util.unflatten(case_op.results, _map(len, output_types))

mlir.register_lowering(cond_p, _cond_lowering)
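

# Illustrative sketch, a hypothetical addition: staging a conditional shows
# the single cond equation (carrying its `branches` and `linear` params) that
# the rules registered above operate on.
def _example_cond_jaxpr():
  f = lambda x: cond(x > 0, lambda x: x + 1., lambda x: x - 1., x)
  print(jax.make_jaxpr(f)(0.5))  # one cond eqn holding two branch jaxprs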


### scan

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When ``a`` is an array type or None, and ``b`` is an array type, the semantics
  of ``scan`` are given roughly by this Python implementation::

    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays. (None is actually an empty pytree.)

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
  across all iterations (and not just be consistent up to NumPy rank/shape
  broadcasting and dtype promotion rules, for example). In other words, the type
  ``c`` in the type signature above represents an array with a fixed shape and
  dtype (or a nested tuple/list/dict container data structure with a fixed
  structure and arrays with fixed shape and dtype at the leaves).

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.
    unroll: optional positive int specifying, in the underlying operation of the
      scan primitive, how many scan iterations to unroll within a single
      iteration of a loop.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
  if not callable(f):
    raise TypeError("lax.scan: f argument should be a callable.")
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
      msg.format(', '.join(str(x) for x in xs_flat
                           if not hasattr(x, 'shape')))) from err

  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  if config.jax_disable_jit:
    if length == 0:
      raise ValueError("zero-length scan is not supported in disable_jit() "
                       "mode because the output type is unknown.")
    carry = init
    ys = []
    maybe_reversed = reversed if reverse else lambda x: x
    for i in maybe_reversed(range(length)):
      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
                            else jax.numpy.stack((y, *ys)))
    stacked_y = tree_map(stack, *maybe_reversed(ys))
    return carry, stacked_y

  x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
  x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
  x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))

  def _create_jaxpr(init):
    init_flat, init_tree = tree_flatten(init)
    in_flat, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(_map(_abstractify, init_flat))
    jaxpr, consts, out_tree = _initial_style_jaxpr(
        f, in_tree, carry_avals + x_avals, "scan")
    out_tree_children = out_tree.children()
    if len(out_tree_children) != 2:
      msg = "scan body output must be a pair, got {}."
      raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
    carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves]
    return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat,
            jaxpr, consts, out_tree, out_tree_children)

  # The carry input and output avals must match exactly. However, we want to
  # account for the case when init contains weakly-typed values (e.g. Python
  # scalars), with avals that may not match the output despite being compatible
  # by virtue of their weak type. To do this, we compute the jaxpr in two
  # passes: first with the raw inputs, and if necessary, a second time with
  # modified init values.
  init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
  if changed:
    new_init = tree_unflatten(init_tree, new_init_flat)
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
  in_flat, jaxpr, consts, out_tree, out_tree_children = rest

  _check_tree_and_avals("scan carry output and input",
                        # Extract the subtree and avals for the first element
                        # of the return tuple
                        out_tree_children[0], carry_avals_out,
                        init_tree, carry_avals)
  if jaxpr.effects:
    raise NotImplementedError('Effects not supported in `scan`.')

  out = scan_p.bind(*consts, *in_flat,
                    reverse=reverse, length=length, jaxpr=jaxpr,
                    num_consts=len(consts), num_carry=len(init_flat),
                    linear=(False,) * (len(consts) + len(in_flat)),
                    unroll=unroll)
  return tree_unflatten(out_tree, out)
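

# Illustrative usage sketch, a hypothetical addition: a running sum, where the
# carry is the total so far and each per-step output is the prefix sum. Note
# the weakly-typed Python-scalar init, which the two-pass jaxpr construction
# above accommodates.
def _example_scan_cumsum():
  import jax.numpy as jnp
  total, prefix = scan(lambda c, x: (c + x, c + x), 0.0, jnp.arange(4.0))
  # total == 6.0 and prefix == [0., 1., 3., 6.]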


def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  consts, init, xs = split_list(args, [num_consts, num_carry])

  carry = init
  ys = []

  for i in range(length):
    i_ = length - i - 1 if reverse else i
    x = _map(partial(_index_array, i_), x_avals, xs)
    out = f_impl(*consts, *carry, *x)
    carry, y = split_list(out, [num_carry])
    ys.append(y)

  ys = list(reversed(ys)) if reverse else ys
  ys = list(zip(*ys))
  ys = _map(_stack, y_avals, ys)
  return (*carry, *ys)


def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
                    f_impl, x_avals, y_avals):
  consts, init, xs = split_list(args, [num_consts, num_carry])

  def cond_fun(vals):
    i, *_ = vals
    return i < length

  def body_fun(vals):
    [i], carry, ys = split_list(vals, [1, num_carry])
    i_ = length - i - 1 if reverse else i
    x = _map(partial(_dynamic_index_array, i_), x_avals, xs)
    out_flat = f_impl(*consts, *carry, *x)
    carry_out, y_updates = split_list(out_flat, [num_carry])
    ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
    return [i + 1] + carry_out + ys_out

  ys_init = _map(partial(_empty_array, length), y_avals)
  if length == 0:
    return init + ys_init
  else:
    init_val = [lax._const(length, 0)] + init + ys_init
    _, *outs = while_loop(cond_fun, body_fun, init_val)
    return outs


def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  consts, init, xs = split_list(args, [num_consts, num_carry])

  num_blocks, rem = divmod(length, block_length)
  assert rem == 0

  partition = partial(_partition_leading, num_blocks, block_length)
  xs_block = _map(partition, x_avals, xs)

  prepend_aval = partial(_prepend_dim_to_aval, block_length)
  x_block_avals = _map(prepend_aval, x_avals)
  y_block_avals = _map(prepend_aval, y_avals)

  f_impl_block = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  outs = _scan_impl_loop(
      *consts, *init, *xs_block, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)

  carry, ys_blocks = split_list(outs, [num_carry])
  combine = partial(_combine_leading, num_blocks, block_length)
  ys = _map(combine, y_avals, ys_blocks)
  return (*carry, *ys)


def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
               unroll):
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
  f_impl = core.jaxpr_as_fun(jaxpr)

  if unroll == 1:
    return _scan_impl_loop(
        *args, reverse=reverse, length=length, num_consts=num_consts,
        num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
        y_avals=y_avals)

  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, rem = divmod(length, unroll)
  length_div = num_blocks * unroll

  if rem > 0:
    if reverse:
      split = partial(_split_leading_dim, rem)
      xs_rem, xs = unzip2(_map(split, x_avals, xs))
    else:
      split = partial(_split_leading_dim, length_div)
      xs, xs_rem = unzip2(_map(split, x_avals, xs))

  outs = _scan_impl_block_unrolled(
      *consts, *init, *xs, reverse=reverse, length=length_div,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  carry, ys = split_list(outs, [num_carry])

  if rem > 0:
    outs = _scan_impl_unrolled(
        *consts, *carry, *xs_rem, reverse=reverse, length=rem,
        num_consts=num_consts, num_carry=num_carry, linear=linear,
        f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
    carry, ys_rem = split_list(outs, [num_carry])
    if reverse:
      ys = _map(_concatenate, y_avals, ys_rem, ys)
    else:
      ys = _map(_concatenate, y_avals, ys, ys_rem)

  return (*carry, *ys)
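

# Illustrative note, a hypothetical addition: with length=10 and unroll=4,
# divmod(10, 4) == (2, 2), so the loop above runs 2 iterations of a 4-step
# unrolled body (covering length_div == 8 steps) and the remaining 2 steps
# run fully unrolled outside the loop.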


def _stack(aval, vals):
  if aval is core.abstract_unit:
    return core.unit
  else:
    vals = [lax.expand_dims(x, (0,)) for x in vals]
    return lax.concatenate(vals, 0)

def _concatenate(aval, x1, x2):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.concatenate([x1, x2], 0)

def _split_leading_dim(i, aval, x):
  if aval is core.abstract_unit:
    return (core.unit, core.unit)
  else:
    assert x.ndim >= 1
    return (slicing.slice_in_dim(x, 0, i),
            slicing.slice_in_dim(x, i, x.shape[0]))

def _dynamic_index_array(i, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return slicing.dynamic_index_in_dim(x, i, keepdims=False)

def _index_array(i, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return slicing.index_in_dim(x, i, keepdims=False)

def _empty_array(sz, aval):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.full((sz,) + aval.shape, 0, aval.dtype)

def _update_array(i, aval, xs, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return slicing.dynamic_update_index_in_dim(xs, x, i, 0)

def _partition_leading(sz0, sz1, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    assert x.ndim >= 1
    assert x.shape[0] == sz0 * sz1
    return lax.reshape(x, (sz0, sz1, *x.shape[1:]))

def _combine_leading(sz0, sz1, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    assert x.ndim >= 2
    assert x.shape[0] == sz0
    assert x.shape[1] == sz1
    return lax.collapse(x, 0, 2)
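

# Illustrative sketch, a hypothetical addition: _partition_leading and
# _combine_leading are shape-only inverses, e.g. splitting a leading axis of
# size 6 into 2 blocks of 3 and collapsing it back.
def _example_partition_combine():
  aval = ShapedArray((6, 3), np.float32)
  x = lax.full((6, 3), 1., np.float32)
  blocks = _partition_leading(2, 3, aval, x)  # shape (2, 3, 3)
  y = _combine_leading(2, 3, aval, blocks)    # shape (6, 3) again
  assert y.shape == x.shape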


def _prepend_dim_to_aval(sz, aval):
  return core.unmapped_aval(sz, core.no_axis_name, 0, aval)

def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                        linear, unroll):
  if jaxpr.effects:
    raise NotImplementedError('Effects not supported in `scan`.')
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
  return carry_avals + ys_avals, core.no_effects


def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
              linear, unroll):
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

  # Fixpoint computation of which carry elements are not ad.Zero: either
  # non-zero from init, or the carry out is non-zero. Each iteration promotes
  # at least one carry to non-zero. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_nz.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    nonzeros = const_nz + carry_nz + xs_nz
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
    carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
    if carry_nz_out == carry_nz:
      break
    else:
      carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
  else:
    assert False, "Fixpoint not reached"

  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  all_tangents = split_list(tangents, [num_consts, num_carry])
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
      [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])

  consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
                           + init_linear + [True] * len(init_dot)
                           + xs_linear + [True] * len(xs_dot))

  out_flat = scan_p.bind(
      *(consts + consts_dot + init + init_dot + xs + xs_dot),
      reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts + len(consts_dot),
      num_carry=num_carry + len(init_dot),
      linear=jaxpr_jvp_linear, unroll=unroll)

  carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
  primals_out = carry + ys
  tangents_out_iter = iter(carry_dot + ys_dot)
  tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(primals_out, nonzeros_out)]
  return primals_out, tangents_out
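

# Illustrative sketch, a hypothetical addition: the JVP rule above is what a
# jax.jvp through scan bottoms out in. For a running product, the tangent scan
# applies the product rule step by step:
def _example_scan_jvp():
  import jax.numpy as jnp
  f = lambda xs: scan(lambda c, x: (c * x, None), 1.0, xs)[0]
  primal, tangent = jax.jvp(f, (jnp.ones(3),), (jnp.ones(3),))
  # primal == 1.0 and tangent == 3.0, by the product rule at xs == ones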


def _prune_zeros(ts):
  return [t for t in ts if type(t) is not ad_util.Zero]


def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                       jaxpr, linear, unroll):
  num_ys = len(jaxpr.out_avals) - num_carry
  unknowns = [not t.pval.is_known() for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Fixpoint computation of which carry elements are unknown. Each iteration
  # promotes at least one carry to unknown. We need at most len(carry)
  # iterations, but we need one last iteration to prepare the jaxpr based on the
  # final carry_uk.
  carry_uk = init_uk
  for _ in range(1 + len(carry_uk)):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out, ys_uk = split_list(out_uk, [num_carry])
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  num_res = len(res_avals)
  del res_avals, carry_uk_out

  # Instantiate those inputs which must be treated as unknown from the fixpoint.
  tracers = [trace.instantiate_const(t) if uk else t
             for t, uk in zip(tracers, unknowns)]

  # The residual inputs and outputs of the jaxprs produced haven't yet been
  # adapted to the scan calling convention; in particular, jaxpr_known has its
  # residual outputs all at the end, meaning they're extensive outputs (which is
  # fully general but may be wasteful for residuals which are loop-invariant)
  # while jaxpr_unknown has its corresponding residual inputs at the front (just
  # as a convention with partial_eval_jaxpr_nounits), making them constant
  # inputs. To make them consistent, we move the residual inputs on
  # jaxpr_unknown to the end, even though we may move some back in the sequel.
  jaxpr_unknown = pe.move_binders_to_back(
      jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))

  # At this point, all residuals are treated as extensive outputs of jaxpr_known
  # (and extensive inputs to jaxpr_unknown). But residuals that are loop-
  # invariant can be hoisted out of the scan, rather than letting them get
  # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
  # want to broadcast the matrix!). So, outside the loop we perform a partial
  # evaluation with known 'const' inputs (but all other inputs unknown).
  const_pvals = [pe.PartialVal.known(t.pval.get_known())
                 for t in tracers[:num_consts] if t.pval.is_known()]
  other_pvals = [pe.PartialVal.unknown(aval)
                 for aval in jaxpr_known.in_avals[len(const_pvals):]]
  with source_info_util.reset_name_stack():
    jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
        lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals,
        instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
  jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
  # The above trace_to_jaxpr_nounits call computed loop-invariant residuals
  # (known values in invar_pvals_out) and also computed loop-invariant values
  # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
  # previous consts). We need to collect the computed intensive residuals, and
  # move corresponding intensive residual binders in jaxpr_unknown to the front.
  res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
  intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
  jaxpr_unknown = pe.move_binders_to_front(
      jaxpr_unknown,
      [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
  del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
  # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
  # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.

  # As another optimization, for any extensive inputs that are just forwarded to
  # extensive outputs, to avoid a copy (which would be looping over
  # dynamic-update-slice) we'd rather forward the input tracer/value. That means
  # pruning some outputs from jaxpr_known here, and updating `out_flat` below.
  fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
  # Prune fwds_known to include only extensive input to extensive output.
  fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
                in_idx is not None and
                in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
                else None for out_idx, in_idx in enumerate(fwds_known)]
  # Drop any extensive output we can instead get by forwarding an input.
  # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
  jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
  jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known)
                          if i is None]
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
  del jaxpr_known_
  # We use `fwds_known` below when forming the output of scanning jaxpr_known.

  # Run the known part of the scan (if it has any outputs or effects).
  known_inputs = (list(jaxpr_known_consts) +
                  [t.pval.get_known() for t in tracers[num_consts:]
                   if t.pval.is_known()])
  if not jaxpr_known.out_avals and not jaxpr_known.effects:
    out_known = []
  else:
    linear_known = [False] * len(known_inputs)  # conservative!
    out_known = scan_p.bind(
        *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
        num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk),
        linear=tuple(linear_known), unroll=unroll)
    del linear_known
  # Complete the known output by filling in forwarded values using fwds_known.
  out_known_iter = iter(out_known)
  out_known = [next(out_known_iter) if f is None
               else _maybe_put(known_inputs[f]) for f in fwds_known]
  assert next(out_known_iter, None) is None
  del known_inputs, out_known_iter

  # Split known outputs from residuals.
  out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
  assert len(intensive_res) + len(extensive_res) == num_res

  # Create input tracers for jaxpr_unknown bind.
  unknown_inputs = [t for t in tracers if not t.pval.is_known()]
  intensive_res = _map(trace.new_instantiated_const, intensive_res)
  extensive_res = _map(trace.new_instantiated_const, extensive_res)
  # Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
  carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
  ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval)
              for y_aval in y_avals]
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in itertools.chain(carry_avals, ys_avals)]
  del carry_avals, y_avals
  # Create equation.
  linear_unknown = tuple([False] * len(intensive_res) +
                         [l for l, uk in zip(linear, unknowns) if uk] +
                         [False] * len(extensive_res))
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
                          out_tracers, scan_p,
                          dict(reverse=reverse, length=length, unroll=unroll,
                               jaxpr=jaxpr_unknown, linear=linear_unknown,
                               num_consts=len(intensive_res) + sum(const_uk),
                               num_carry=sum(carry_uk)),
                          jaxpr_unknown.effects, source)
  for t in out_tracers: t.recipe = eqn

  # Merge known and unknown outputs into final result.
  return util.merge_lists(out_uk, out_known, out_tracers)
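

# Illustrative sketch, a hypothetical addition: this partial-eval rule fires
# when a scan is linearized. The known (primal) half runs as one scan whose
# residuals feed the staged unknown (tangent) scan that f_lin later executes.
def _example_scan_partial_eval():
  import jax.numpy as jnp
  f = lambda c, x: (c * x, c + x)
  g = lambda xs: scan(f, 1.0, xs)[1]
  ys, g_lin = jax.linearize(g, jnp.arange(1., 4.))
  # g_lin closes over the residuals produced by the known scan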


def _maybe_put(x):
  if isinstance(x, np.ndarray):
    return jax.device_put(x, jax.devices('cpu')[0])
  else:
    return x


def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,
                    num_carry, jaxpr, linear, unroll):
  # we've only implemented transposing scans with specific lin/nonlin patterns
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_ires = len(consts_lin) - sum(consts_lin)
  num_eres = len(xs_lin) - sum(xs_lin)
  if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
    raise NotImplementedError
  if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
    raise NotImplementedError
  if not all(init_lin):
    pass  # TODO(mattjj): error check https://github.com/google/jax/issues/1963

  consts, _, xs = split_list(args, [num_consts, num_carry])
  ires, _ = split_list(consts, [num_ires])
  _, eres = split_list(xs, [sum(xs_lin)])
  assert not any(ad.is_undefined_primal(r) for r in ires)
  assert not any(ad.is_undefined_primal(r) for r in eres)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])

  # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_scan_jaxpr(
      num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes)
  linear_trans = ([False] * num_ires +
                  [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
                  [False] * num_eres)

  outs = scan_p.bind(
      *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
      length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
      unroll=unroll)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
  return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
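

# Illustrative sketch, a hypothetical addition: transposing a scan runs it in
# the opposite direction (note `reverse=not reverse` above), which is what
# jax.grad through a scanned sum relies on:
def _example_scan_grad():
  import jax.numpy as jnp
  total = lambda xs: scan(lambda c, x: (c + x, None), 0.0, xs)[0]
  g = jax.grad(total)(jnp.arange(3.0))
  # g == [1., 1., 1.]: the carry cotangent flows back through each step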
|
|
|
|
|
|
|
|
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
|
|
|
|
# -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
|
AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.
In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.
If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.
Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.
Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
- reductions aren't fused into any first-order primitives (e.g. a `pdot`
should have a named contracting axis added rather than being followed by a
`psum`; this can be implemented by putting these primitives into
`reducing_transposes`)
- reductions are performed eagerly, even over axes that are mapped to
hardware resources (the optimal thing to do would be to reduce eagerly
over any vectorized axis component while delaying the reduction over any
hardware-mapped component until the end of the overall backward pass; this
would require a way to represent these partially-reduced values)
PiperOrigin-RevId: 383685336
2021-07-08 12:05:56 -07:00
|
|
|
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
  num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
  # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
  # if an axis isn't reduced
  res1_avals, c_avals, a_avals, res2_avals = split_list(
      jaxpr.in_avals, [num_res1, num_c, num_a])
  num_b = len(jaxpr.out_avals)
  b_avals = list(jaxpr.out_avals)

  @lu.wrap_init
  def transposed(*res1_cbar_bbar_res2):
    res1, c_bar, b_bar, res2 = split_list(
        res1_cbar_bbar_res2, [num_res1, num_c, num_b])
    primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
               [ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
                                 primals, b_bar)
    _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
    a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
    c_bar = _map(ad.instantiate_zeros_aval, c_avals,
                 _map(ad.add_tangents, c_bar, new_c_bar))
    return c_bar + a_bar
  return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)


def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  return core.ClosedJaxpr(jaxpr, consts)


def _scan_batching_rule(axis_size, axis_name, main_type, args, dims, reverse, length,
                        jaxpr, num_consts, num_carry, linear, unroll):
  num_ys = len(jaxpr.out_avals) - num_carry
  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry])

  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_batched.
  carry_batched = init_batched
  for _ in range(1 + len(carry_batched)):
    batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, axis_size, batched,
        instantiate=carry_batched + [False] * num_ys,
        axis_name=axis_name,
        main_type=main_type)
    carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
    if carry_batched_out == carry_batched:
      break
    else:
      carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, consts_bdims)]
  new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched
              else batching.moveaxis(x, d, 0) if now_batched else x
              for x, d, was_batched, now_batched in
              zip(init, init_bdims, init_batched, carry_batched)]
  new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1
            else x for x, d in zip(xs, xs_bdims)]
  new_args = new_consts + new_init + new_xs

  outs = scan_p.bind(
      *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
      num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims


def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
                       jaxpr, num_consts, num_carry, linear, unroll):
  dynamic_length, = masking.shape_as_value((length,))
  masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
  consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
  max_length, = {x.shape[0] for x in xs}
  const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  dynamic_length = lax.convert_element_type(dynamic_length, dtypes.int_)
  out_vals = scan_p.bind(dynamic_length, *consts, dtypes.int_(0), *init, *xs,
                         reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
                         num_consts=1 + num_consts, num_carry=1 + num_carry,
                         linear=tuple([False] + const_linear + [False] + init_linear + xs_linear),
                         unroll=unroll)
  return out_vals[1:]


def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  fun = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    out = fun(*(consts + carry + xs))
    new_carry, ys = split_list(out, [num_carry])
    new_carry = [lax.select(i < dynamic_length, new_c, c)
                 for new_c, c in zip(new_carry, carry)]
    return [i + 1] + new_carry + ys

  aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_))
  const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)


def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
  padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
  return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)


def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
                    jaxpr, linear, unroll):
  tc = partial(_typecheck_param, 'scan')
  tc(reverse, 'reverse', 'bool', type(reverse) is bool)
  tc(num_consts, 'num_consts', 'non-negative int',
     type(num_consts) is int and num_consts >= 0)
  tc(num_carry, 'num_carry', 'non-negative int',
     type(num_carry) is int and num_carry >= 0)
  tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr)
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))
  tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0)

  length_types = (int, masking.Poly) if bind_time else (int,)
  tc(length, 'length', 'non-negative int',
     type(length) in length_types and length >= 0)

  if len(linear) != len(avals):
    raise core.JaxprTypeError(
        f'scan param linear has length {len(linear)} for {len(avals)} operands')

  const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
  const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry])
  x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)

  if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
    raise core.JaxprTypeError(
        f'scan carry input and output types mismatch: '
        f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
  if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)):
    raise core.JaxprTypeError(
        f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
        f'called with consts of type\n{_avals_short(const_avals)}')
  if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)):
    raise core.JaxprTypeError(
        f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
        f'called with initial carry of type\n{_avals_short(init_avals)}')
  if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)):
    raise core.JaxprTypeError(
        f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
        f'called with sequence of type\n{_avals_short(x_avals)}')
  return None, core.no_effects


def scan_bind(*args, **params):
  if config.jax_enable_checks:
    avals = _map(core.get_aval, args)
    _scan_typecheck(True, *avals, **params)
    core.check_jaxpr(params['jaxpr'].jaxpr)
  return core.AxisPrimitive.bind(scan_p, *args, **params)


scan_p = core.AxisPrimitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.register_initial_style_primitive(scan_p)
mlir.register_lowering(scan_p,
                       mlir.lower_fun(_scan_impl, multiple_results=True))
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.padding_rules[scan_p] = _scan_padding_rule


@api_boundary
def map(f, xs):
  """Map a function over leading array axes.

  Like Python's builtin map, except inputs and outputs are in the form of
  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
  need to apply a function element by element for reduced memory usage or
  heterogeneous computation with other control flow primitives.

  When ``xs`` is an array type, the semantics of ``map`` are given by this
  Python implementation::

    def map(f, xs):
      return np.stack([f(x) for x in xs])

  Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
  the same advantages over a Python loop apply: ``xs`` may be an arbitrary
  nested pytree type, and the mapped computation is compiled only once.

  Args:
    f: a Python function to apply element-wise over the first axis or axes of
      ``xs``.
    xs: values over which to map along the leading axis.

  Returns:
    Mapped values.
  """
  g = lambda _, x: ((), f(x))
  _, ys = scan(g, (), xs)
  return ys
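
# A minimal usage sketch of ``map`` (illustrative only; assumes ``jnp`` is
# ``jax.numpy``):
#
#   ys = map(lambda x: x * 2, jnp.arange(4))  # == jnp.array([0, 2, 4, 6])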


def _concat_masking_rule(padded_vals, logical_shapes, dimension):
  result = lax.concatenate(padded_vals, dimension)  # fragmented
  offset = 0
  for padded_val, logical_shape in zip(padded_vals, logical_shapes):
    result = _memcpy(dimension, logical_shape[dimension], padded_val,
                     result, offset)
    offset = offset + logical_shape[dimension]
  return result


def _memcpy(axis, num, src, dst, offset):
  def body(i, dst):
    update = slicing.dynamic_index_in_dim(src, i, axis)
    return slicing.dynamic_update_index_in_dim(dst, update, i + offset, axis)
  return fori_loop(0, num, body, dst)
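
# Illustrative sketch (hypothetical values): _memcpy(axis=0, num=2, src=[7, 8, 9],
# dst=[0, 0, 0, 0], offset=1) loops i = 0, 1 and produces [0, 7, 8, 0]; i.e. the
# first `num` slices of `src` are copied into `dst` starting at `offset`.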


masking.masking_rules[lax.concatenate_p] = _concat_masking_rule  # type: ignore


def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype,
                                     algorithm):
  """Calls RBG in a loop and stacks the results."""
  key, = batched_args
  bd, = batch_dims
  if bd is batching.not_mapped:
    return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
                                        algorithm=algorithm), (None, None)
  key = batching.moveaxis(key, bd, 0)
  map_body = lambda k: lax.rng_bit_generator_p.bind(
      k, shape=shape, dtype=dtype, algorithm=algorithm)
  stacked_keys, stacked_bits = map(map_body, key)
  return (stacked_keys, stacked_bits), (0, 0)


batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule


def _show_diff(array1, array2):
  if core.typematch(array1, array2):
    return f"{array1}"
  return f"DIFFERENT {array1} vs. {array2}"


def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
  """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).

  Corresponding `tree` and `avals` must match in the sense that the number of
  leaves in `tree` must be equal to the length of `avals`. `what` will be
  prepended to details of the mismatch in TypeError.
  """
  if tree1 != tree2:
    raise TypeError(
        f"{what} must have same type structure, got {tree1} and {tree2}.")
  if not all(_map(core.typematch, avals1, avals2)):
    diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
                    tree_unflatten(tree2, avals2))
    raise TypeError(f"{what} must have identical types, got\n{diff}.")


def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
  if has_aux:
    actual_tree_children = actual_tree.children()

    if len(actual_tree_children) == 2:
      # select first child as result tree
      actual_tree = tree_structure(actual_tree_children[0])
    else:
      raise ValueError(
          f"{func_name}() produced a pytree with structure "
          f"{actual_tree}, but a pytree tuple with auxiliary "
          f"output was expected because has_aux was set to True.")

  if actual_tree != expected_tree:
    raise TypeError(
        f"{func_name}() output pytree structure must match {expected_name}, "
        f"got {actual_tree} and {expected_tree}.")


def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
  """Promote weakly-typed in_vals to be compatible with out_avals.

  Args:
    in_vals : flattened list of input values.
    in_avals : corresponding list of avals.
    out_avals : list of target output avals.
  Returns:
    in_vals_new : flattened list of modified in_vals with no weak types.
    changed : bool; true if in_vals required modification.
  """
  if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
    # Calling function is responsible for catching this.
    return in_vals, False
  weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
                     if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
  if not weak_mismatches:
    return in_vals, False
  for i in weak_mismatches:
    new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
    in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
  return in_vals, True
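
# An illustrative sketch of the promotion above (values hypothetical): if a
# loop carry starts as the weakly-typed Python scalar 0 but the corresponding
# output aval is float32, the input is rewritten via
# lax.convert_element_type(0, np.float32) so carry-in and carry-out types agree.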


_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')


def _split_root_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _RootTuple(*params_list[:-1]), params_list[-1]


@api_boundary
def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
  """Differentiably solve for the roots of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that func(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
    has_aux: bool indicating whether the ``solve`` function returns
      auxiliary data like solver diagnostics as a second argument.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)

  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree, False)

  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, f), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

  def linearize_and_solve(x, b):
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
  _check_tree("tangent_solve", "x", out_tree, in_tree, False)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  solution_flat = _custom_root(
      const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
  return tree_unflatten(solution_tree, solution_flat)
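
# A minimal usage sketch (illustrative only; ``newton_solve`` is a
# hypothetical solver supplied by the caller):
#
#   def f(x):
#     return x ** 2 - 2.0           # root at sqrt(2)
#
#   def solve(f, x0):
#     return newton_solve(f, x0)    # need not be differentiable
#
#   def tangent_solve(g, y):
#     return y / g(1.0)             # scalar tangent system
#
#   sqrt_two = custom_root(f, 1.0, solve, tangent_solve)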


@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
  params, initial_guess = _split_root_args(args, const_lengths)
  solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
  return solution


@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  params, _ = _split_root_args(primals, const_lengths)
  sol = _custom_root(const_lengths, jaxprs, *primals)

  f_out_vals = len(jaxprs.f.out_avals)
  solution, aux = split_list(sol, [f_out_vals])

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxprs.f)
  linearize_and_solve = partial(
      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
  f_at_solution = lambda *params: f(*params, *solution)
  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  solution_dot = _map(
      operator.neg, linearize_and_solve(*solution, *rhs))
  # append aux, create symbolic zero tangents for the aux values
  solution += aux
  solution_dot += _map(lax.zeros_like_array, aux)

  return solution, solution_dot


class _LinearSolveTuple(collections.namedtuple(
    '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):

  def transpose(self):
    return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)


def _split_linear_solve_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _LinearSolveTuple(*params_list[:-1]), params_list[-1]


def _transpose_one_output(linear_fun, primals):
  transpose_fun = jax.linear_transpose(linear_fun, primals)
  def transposed_fun(x):
    (y,) = transpose_fun(x)
    return y
  return transposed_fun


def _flatten(args):
  return [x for arg in args for x in arg]


def _check_shapes(func_name, expected_name, actual, expected):
  actual_shapes = _map(np.shape, tree_leaves(actual))
  expected_shapes = _map(np.shape, tree_leaves(expected))
  if actual_shapes != expected_shapes:
    raise ValueError(
        f"{func_name}() output shapes must match {expected_name}, "
        f"got {actual_shapes} and {expected_shapes}")


@api_boundary
def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much
  faster or more numerically stable; differentiating through the solve
  operation may not even be supported (e.g., if ``solve`` uses
  ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for the solution to the linear
      equation, i.e., ``matvec(solve(matvec, b)) == b`` for all ``b`` of the
      same form as the solution ``x``. This function need not be
      differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
    has_aux: bool indicating whether the ``solve`` and ``transpose_solve``
      functions return auxiliary data like solver diagnostics as a second
      argument.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.
  """
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))

  tree, = treedef_children(in_args_tree)

  def _shape_checked(fun, name, has_aux):
    def f(x):
      y = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y

    def f_aux(x):
      y, aux = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y, aux

    return f_aux if has_aux else f

  # no auxiliary data assumed for matvec
  matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
      'custom_linear_solve')
  _check_tree("matvec", "b", out_tree, tree, False)

  solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree,
      b_avals, 'custom_linear_solve')
  _check_tree("solve", "b", out_tree, tree, has_aux)

  if transpose_solve is None:
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      vecmat = _transpose_one_output(matvec, b)
      vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals, 'custom_linear_solve')
      assert out_tree == tree

    tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
        _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
        in_args_tree, b_avals, 'custom_linear_solve')
    _check_tree("transpose_solve", "b", out_tree, tree, has_aux)

  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = linear_solve_p.bind(
      *(_flatten(all_consts) + b_flat),
      const_lengths=const_lengths, jaxprs=jaxprs)

  return tree_unflatten(out_tree, out_flat)
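
# A minimal usage sketch with an explicit symmetric matrix (assumes ``jnp``
# is ``jax.numpy``; the matrix and right hand side are illustrative):
#
#   A = jnp.array([[2.0, 0.0], [0.0, 4.0]])
#   b = jnp.array([1.0, 1.0])
#   matvec = lambda x: A @ x
#   solve = lambda mv, rhs: jnp.linalg.solve(A, rhs)
#   x = custom_linear_solve(matvec, b, solve, symmetric=True)
#   # matvec(x) recovers b, and gradients w.r.t. values closed over by
#   # matvec flow via implicit differentiation rather than through solve.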


def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
  args_to_raise = args[sum(const_lengths):]

  # raise aux_args to shaped arrays as well if present
  # number of aux args is the difference in out_avals
  # of solve and matvec (since they map to the same vector space)
  num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
  if num_aux > 0:
    args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
  return _map(raise_to_shaped, args_to_raise)


def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
  params, b = _split_linear_solve_args(args, const_lengths)
  x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
  return x


def _tangent_linear_map(func, params, params_dot, *x):
  """Compute the tangent of a linear map.

  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
  this function computes ``∂A @ x``.
  """
  assert any(type(p) is not ad_util.Zero for p in params_dot)
  zeros = _map(ad_util.Zero.from_value, x)
  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
      params + list(x), params_dot + zeros)
  return out_tangent


def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
  # A x - b = 0
  # ∂A x + A ∂x - ∂b = 0
  # ∂x = A^{-1} (∂b - ∂A x)

  kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs)
  x = linear_solve_p.bind(*primals, **kwargs)

  params, _ = _split_linear_solve_args(primals, const_lengths)
  params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

  num_x_leaves = len(b_dot)
  # x is a flat tree with possible aux values appended
  # since x_tree == b_tree == b_dot_tree, we can cut off
  # aux values with len info provided by b_dot tree here
  x_leaves, _ = split_list(x, [num_x_leaves])

  if all(type(p) is ad_util.Zero for p in params_dot.matvec):
    # no need to evaluate matvec_tangents
    rhs = b_dot
  else:
    matvec_tangents = _tangent_linear_map(
        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves)
    rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))

  x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)

  # split into x tangents and aux tangents (these become zero)
  dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])

  daux_leaves = _map(ad_util.Zero.from_value, daux_leaves)

  x_dot = dx_leaves + daux_leaves

  return x, x_dot


def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  # split off symbolic zeros in the cotangent if present
  x_cotangent, _ = split_list(cotangent, [len(b)])
  assert all(ad.is_undefined_primal(x) for x in b)
  cotangent_b_full = linear_solve_p.bind(
      *(_flatten(params.transpose()) + x_cotangent),
      const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
  # drop aux values in cotangent computation
  cotangent_b, _ = split_list(cotangent_b_full, [len(b)])
  return [None] * sum(const_lengths) + cotangent_b


def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
  orig_bat = [d is not batching.not_mapped for d in dims]

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  num_aux = len(solve.out_avals) - len(matvec.out_avals)
  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, axis_size, solve_bat + b_bat, instantiate=x_bat,
        axis_name=axis_name, main_type=main_type)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat,
          axis_name=axis_name, main_type=main_type)
      # batch all aux data by default
      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat)

    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat,
        axis_name=axis_name, main_type=main_type)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
          solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat,
          axis_name=axis_name, main_type=main_type)
      assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
      solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
                       orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    assert False, "Fixpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
                                     solve_jaxpr_batched, solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(
      *(new_params + new_b),
      const_lengths=const_lengths,
      jaxprs=batched_jaxprs)
  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
  return outs, out_dims


linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
                                   multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')


def _interleave(a, b, axis):
  """Given two Tensors of static shape, interleave them along the given axis."""
  assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
  a_pad = [(0, 0, 0)] * a.ndim
  b_pad = [(0, 0, 0)] * b.ndim
  a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
  b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
  op = lax.bitwise_or if a.dtype == np.bool_ else lax.add
  return op(lax.pad(a, lax._const(a, 0), a_pad),
            lax.pad(b, lax._const(b, 0), b_pad))
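
# Sketch of the padding trick above (hypothetical 1-D inputs): with axis=0,
# a = [a0, a1] and b = [b0], interior padding expands them to [a0, 0, a1] and
# [0, b0, 0], so the elementwise op yields the interleaving [a0, b0, a1].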


@api_boundary
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
  """Performs a scan with an associative binary operation, in parallel.

  For an introduction to associative scans, see [BLE1990]_.

  Args:
    fn: A Python callable implementing an associative binary operation with
      signature ``r = fn(a, b)``. Function ``fn`` must be associative, i.e., it
      must satisfy the equation
      ``fn(a, fn(b, c)) == fn(fn(a, b), c)``.

      The inputs and result are (possibly nested Python tree structures of)
      array(s) matching ``elems``. Each array has a dimension in place
      of the ``axis`` dimension. ``fn`` should be applied elementwise over
      the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
      elementwise function.)

      The result ``r`` has the same shape (and structure) as the two inputs
      ``a`` and ``b``.
    elems: A (possibly nested Python tree structure of) array(s), each with
      an ``axis`` dimension of size ``num_elems``.
    reverse: A boolean stating if the scan should be reversed with respect to
      the ``axis`` dimension.
    axis: an integer identifying the axis over which the scan should occur.

  Returns:
    A (possibly nested Python tree structure of) array(s) of the same shape
    and structure as ``elems``, in which the ``k``'th element of ``axis`` is
    the result of recursively applying ``fn`` to combine the first ``k``
    elements of ``elems`` along ``axis``. For example, given
    ``elems = [a, b, c, ...]``, the result would be
    ``[a, fn(a, b), fn(fn(a, b), c), ...]``.

  Example 1: partial sums of an array of numbers:

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
  DeviceArray([0, 1, 3, 6], dtype=int32)

  Example 2: partial products of an array of matrices

  >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
  >>> partial_prods = lax.associative_scan(jnp.matmul, mats)
  >>> partial_prods.shape
  (4, 2, 2)

  Example 3: reversed partial sums of an array of numbers

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
  DeviceArray([6, 6, 5, 3], dtype=int32)

  .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
    Technical Report CMU-CS-90-190, School of Computer Science, Carnegie
    Mellon University.
  """
  if not callable(fn):
    raise TypeError("lax.associative_scan: fn argument should be callable.")
  elems_flat, tree = tree_flatten(elems)

  if reverse:
    elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]

  def combine(a_flat, b_flat):
    # Lower `fn` to operate on flattened sequences of elems.
    a = tree_unflatten(tree, a_flat)
    b = tree_unflatten(tree, b_flat)
    c = fn(a, b)
    c_flat, _ = tree_flatten(c)
    return c_flat

  # Check that all inputs have a consistent leading dimension `num_elems`.
  axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
  num_elems = int(elems_flat[0].shape[axis])
  if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
    raise ValueError('Array inputs to associative_scan must have the same '
                     'first dimension. (saw: {})'
                     .format([elem.shape for elem in elems_flat]))

  # Summary of algorithm:
  #
  # Consider elements of `_scan(elems)` at odd indices. That's the same as first
  # summing successive pairs of elements of `elems` and performing a scan on
  # that half-sized tensor. We perform the latter scan by recursion.
  #
  # Now consider the even elements of `_scan(elems)`. These can be computed
  # from the odd elements of `_scan(elems)` by adding each odd element of
  # `_scan(elems)` to the matching even element in the original `elems`.
  #
  # We return the odd and even elements interleaved.
  #
  # For the base case of the recursion we return the first element
  # of `elems` followed by the sum of the first two elements computed as
  # a (small two-down-to-one) reduction step.
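  #
  # A concrete illustrative trace, taking the operation to be addition on
  # elems = [a, b, c, d]:
  #   reduced_elems = [a+b, c+d]                       # pairwise combine
  #   odd_elems  = _scan([a+b, c+d]) = [a+b, a+b+c+d]  # recursive scan
  #   even_elems = [odd_elems[0] + c] = [a+b+c], then prepend elems[0] = a
  #   result     = [a, a+b, a+b+c, a+b+c+d]            # interleave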
  def _scan(elems):
    """Perform scan on `elems`."""

    num_elems = elems[0].shape[axis]

    if num_elems < 2:
      return elems

    # Combine adjacent pairs of elements.
    reduced_elems = combine(
        [slicing.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems],
        [slicing.slice_in_dim(elem, 1, None, stride=2, axis=axis)
         for elem in elems])

    # Recursively compute scan for partially reduced tensors.
    odd_elems = _scan(reduced_elems)

    if num_elems % 2 == 0:
      even_elems = combine(
          [slicing.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
          [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
    else:
      even_elems = combine(
          odd_elems,
          [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [
        lax.concatenate([slicing.slice_in_dim(elem, 0, 1, axis=axis), result],
                        dimension=axis)
        for (elem, result) in zip(elems, even_elems)]
    return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))

  scans = _scan(elems_flat)

  if reverse:
    scans = [lax.rev(scanned, [axis]) for scanned in scans]

  return tree_unflatten(tree, scans)


# Cumulative reductions.

def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative sum along `axis`."""
  return cumsum_p.bind(operand, axis=int(axis), reverse=bool(reverse))

def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative product along `axis`."""
  return cumprod_p.bind(operand, axis=int(axis), reverse=bool(reverse))

def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative maximum along `axis`."""
  return cummax_p.bind(operand, axis=int(axis), reverse=bool(reverse))

def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative minimum along `axis`."""
  return cummin_p.bind(operand, axis=int(axis), reverse=bool(reverse))
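
# Illustrative usage (assumes ``jnp`` is ``jax.numpy``):
#
#   cumsum(jnp.array([1, 2, 3]))                 # -> [1, 3, 6]
#   cumsum(jnp.array([1, 2, 3]), reverse=True)   # -> [6, 5, 3]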

def _cumred_shape_rule(x, *, axis: int, reverse: bool):
  if axis < 0 or axis >= x.ndim:
    raise ValueError(
        "axis {} is out of bounds for array of shape {}".format(axis, x.shape))
  return x.shape

def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
  return [cumsum(t, axis=axis, reverse=not reverse)]


def cumred_tpu_impl(window_reduce: Callable, x, *, axis: int, reverse: bool):
  # On TPU, an implementation using reduce_window is handled specially by the
  # compiler and is efficient. On other backends, it is O(n^2).
  n = x.shape[axis]
  if n == 0:
    return x
  padding = [(0, 0)] * x.ndim
  padding[axis] = (0, n - 1) if reverse else (n - 1, 0)
  strides = [1] * x.ndim
  window_dims = [1] * x.ndim
  window_dims[axis] = n
  return window_reduce(x, window_dims, strides, padding)
|
|
|
|
|
|
|
|
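# Window intuition (illustrative): with n elements along `axis`, a window of
# length n, unit strides, and (n - 1, 0) low padding, output position i
# reduces over original elements 0..i -- an inclusive forward scan. With
# reverse=True the padding flips and position i reduces over elements i..n-1.
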
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
                       reverse: bool):
  operand, = batched_args
  bdim, = batch_dims
  # Shift the reduction axis past the inserted batch dimension if needed.
  axis = axis if axis < bdim else axis + 1
  return prim.bind(operand, axis=axis, reverse=reverse), bdim
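# Example (illustrative): with bdim=0 and an unbatched reduction over axis=1,
# the batched operand has an extra leading dimension, so the reduction must
# run over axis=2; axes in front of the batch dimension are unaffected.
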
def _cumred_dtype_rule(name, operand, *args, **kw):
  if not dtypes.issubdtype(operand.dtype, np.number):
    raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
                    "of number.".format(name, np.dtype(operand.dtype).name))
  return dtypes.canonicalize_dtype(operand.dtype)

def _cumulative_reduction_primitive(name,
                                    reduce_fn,
                                    tpu_reduce_window_fn):
  reducer_p = lax.standard_primitive(
      _cumred_shape_rule, partial(_cumred_dtype_rule, name),
      name)
  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule,
                                                   reducer_p)
  # Default lowering: the associative scan, O(n) work and O(log n) depth.
  mlir.register_lowering(
      reducer_p,
      mlir.cache_lowering(
          mlir.lower_fun(partial(associative_scan, reduce_fn),
                         multiple_results=False)))
  # On TPU, lower via reduce_window instead (see cumred_tpu_impl above).
  mlir.register_lowering(
      reducer_p,
      mlir.lower_fun(partial(cumred_tpu_impl, tpu_reduce_window_fn),
                     multiple_results=False),
      platform='tpu')
  return reducer_p
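# Sketch (hypothetical, not part of the library): a new cumulative reduction
# would be registered the same way, e.g. a running logaddexp, given a combine
# function and a matching reduce_window implementation for the TPU path:
#
#   cumlogaddexp_p = _cumulative_reduction_primitive(
#       "cumlogaddexp", logaddexp_combine, reduce_window_logaddexp)
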
cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, windowed_reductions._reduce_window_sum)
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, windowed_reductions._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax.max, windowed_reductions._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax.min, windowed_reductions._reduce_window_min)

def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  return api.jvp(partial(associative_scan, combine_fn, axis=axis,
                         reverse=reverse),
                 primals, tangents)
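# Worked case (illustrative): for cumprod of x = [x0, x1] the primal outputs
# are [x0, x0*x1], and at tangent v the JVP is [v0, v0*x1 + x0*v1] -- exactly
# what differentiating the associative-scan formulation yields.
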
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)
# (cumsum needs no entry here: it is linear, so ad.deflinear2 above already
# supplies its JVP and transpose.)

def _dummy_remat_result(aval: core.AbstractValue):
  """A result that will be discarded."""
  if aval is core.abstract_token:
    return lax.create_token()
  elif aval is core.abstract_unit:
    return core.unit
  else:
    return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape)  # type: ignore
def _remat_translation_using_cond(*args,
                                  jaxpr: core.Jaxpr):
  # Implements:
  #  if(rng(0, 1) < 2)
  #    return eval_jaxpr(*args)
  #  else:
  #    return 0
  # The predicate is always true, but the compiler cannot prove it, so the
  # rematerialized computation cannot be CSE'd with the original one.
  avals_out = tuple(ov.aval for ov in jaxpr.outvars)

  def remat_comp(*args):
    return tuple(core.eval_jaxpr(jaxpr, (), *args))
  def dummy_comp(*args):
    return tuple(_map(_dummy_remat_result, avals_out))

  cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2))
  return cond(cond_pred, remat_comp, dummy_comp, *args)
def _remat_translation_using_while(*args,
                                   jaxpr: core.Jaxpr):
  # Implements:
  #  for(counter=0, result=0; counter < rng(1, 2); counter ++) {
  #     result = eval_jaxpr(*args)
  #  }
  # The loop runs exactly once, but the compiler cannot prove it, which again
  # prevents CSE of the rematerialized computation.
  # The loop carry is a tuple: (counter, result, args)
  avals_out = tuple(ov.aval for ov in jaxpr.outvars)
  dummies_like_result = tuple(_map(_dummy_remat_result, avals_out))
  carry_init = (np.int32(0), dummies_like_result, args)

  def cond(carry):
    counter, _, _ = carry
    return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

  def body(carry):
    counter, _, args = carry
    results = core.eval_jaxpr(jaxpr, (), *args)
    return (counter + 1, tuple(results), args)

  carry_res = while_loop(cond, body, carry_init)
  return carry_res[1]
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
  # Pass the inputs through an optimization barrier so XLA cannot CSE the
  # rematerialized computation with the original one.
  args = _optimization_barrier(args)
  return core.eval_jaxpr(jaxpr, (), *args)
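# User-facing sketch (illustrative; uses the public checkpoint API): these
# translation rules are reached when differentiating through jax.checkpoint:
#
#   @jax.checkpoint
#   def f(x):
#     return jax.numpy.sin(jax.numpy.sin(x))
#
#   jax.grad(f)(2.0)  # inner sin values are recomputed, not stored
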
def remat_impl(*args,
               call_jaxpr: Optional[core.Jaxpr] = None,
               jaxpr: Optional[core.Jaxpr] = None,
               platform: str,
               prevent_cse: bool, differentiated: bool,
               policy,
               concrete: bool = False,
               name: str = "checkpoint"):
  # Support either "jaxpr" (for remat2) or "call_jaxpr" (for remat).
  # `name` is not passed for remat2 and defaults to "checkpoint".
  # TODO: remove call_jaxpr once we drop the remat call primitive.
  if jaxpr is None:
    jaxpr = call_jaxpr
  assert jaxpr is not None
  assert not jaxpr.constvars

  del concrete, policy  # Unused.
  if differentiated and prevent_cse:
    if config.jax_remat_opt_barrier:
      translation_rule = _remat_translation_using_opt_barrier
    elif platform == 'gpu':
      translation_rule = _remat_translation_using_while
    else:
      translation_rule = _remat_translation_using_cond
  else:
    translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

  return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)
for platform in ("cpu", "gpu", "tpu"):
  for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p):  # type: ignore
    mlir.register_lowering(remat_primitive,
                           mlir.lower_fun(partial(remat_impl,
                                                  platform=platform),
                                          multiple_results=True),
                           platform=platform)

def _optimization_barrier_abstract_eval(*args):
  # The barrier is an identity on abstract values.
  return args
def _optimization_barrier_lowering_rule(ctx, *args):
  barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_barrier_types = util.flatten(barrier_types)

  flat_args = mlir.flatten_lowering_ir_args(args)
  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
  return util.unflatten(barrier_op.results, _map(len, barrier_types))
def _optimization_barrier(arg):
  flat_args, treedef = tree_flatten(arg)
  return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
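# Behavioral note (illustrative): the barrier computes the identity on any
# pytree of arrays, e.g.
#
#   x, y = _optimization_barrier((x, y))
#
# but XLA may not hoist, sink, or CSE computations across it, which is what
# makes the rematerialization above stick.
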
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
    partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
                       _optimization_barrier_lowering_rule)