# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Control flow primitives.
"""

import collections
import functools
import inspect
import itertools
import operator
import os
from typing import Any, Callable, Sequence, TypeVar

import numpy as np

import jax
from jax import api
from jax import core
from jax import dtypes
from jax._src import source_info_util
from jax import util
from jax._src.lax import lax
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax.util import (partial, unzip2, unzip3, safe_map, safe_zip, split_list,
                      cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           treedef_children, treedef_tuple, tree_multimap,
                           tree_leaves)
from jax import ad_util
from jax.config import config

xops = xla_client.ops

_map = safe_map
zip = safe_zip
_reduce = functools.reduce

T = TypeVar('T')
Array = Any

@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  return jaxpr, consts, out_tree()

@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
  jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  return closed_jaxpr, consts, out_tree

@cache()
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
                                             in_tree, in_avals):
  # When staging the branches of a conditional into jaxprs, constants are
  # extracted from each branch and converted to jaxpr arguments. To use the
  # staged jaxprs as the branches to a conditional *primitive*, we need for
  # their (input) signatures to match. This function "joins" the staged jaxprs:
  # for each one, it makes another that accepts *all* constants, but only uses
  # those that it needs (dropping the rest).
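  # As a hypothetical illustration: if branch 0 stages out with consts [c0] and
  # branch 1 with consts [c1], the joined jaxprs each take constvars (c0, c1);
  # branch 0 never reads c1 and branch 1 never reads c0, so the two input
  # signatures match.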

  jaxprs, all_consts, all_out_trees = unzip3(
      _initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs)

  newvar = core.gensym(jaxprs, suffix='_')
  all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
                     for consts in all_consts]
  unused_const_vars = [[newvar(aval) for aval in const_avals]
                       for const_avals in all_const_avals]

  def pad_jaxpr_constvars(i, jaxpr):
    prefix = util.concatenate(unused_const_vars[:i])
    suffix = util.concatenate(unused_const_vars[i + 1:])
    constvars = [*prefix, *jaxpr.constvars, *suffix]
    return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
                      outvars=jaxpr.outvars, eqns=jaxpr.eqns)

  consts = util.concatenate(all_consts)
  jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
                   for jaxpr in jaxprs]
  return closed_jaxprs, consts, all_out_trees

def _abstractify(x):
  return raise_to_shaped(core.get_aval(x))

def _disable_jit_impl(prim, interp, *args, **kwargs):
  if jax.api._jit_is_disabled():
    return interp(*args, **kwargs)
  else:
    return xla.apply_primitive(prim, *args, **kwargs)

def _typecheck_param(prim, param, name, msg_required, pred):
  msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
         f'{msg_required} required:')
  param_str = str(param)
  sep = os.linesep if os.linesep in param_str else ' '
  msg = sep.join([msg, param_str])
  core.typecheck_assert(pred, msg)


### fori_loop and while_loop

def _fori_cond_fun(loop_carry):
  i, upper, _ = loop_carry
  return lax.lt(i, upper)

@cache()
def _fori_body_fun(body_fun):
  def while_body_fun(loop_carry):
    i, upper, x = loop_carry
    return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
  return while_body_fun

@cache()
def _fori_scan_body_fun(body_fun):
  def scanned_fun(loop_carry, _):
    i, upper, x = loop_carry
    return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None
  return scanned_fun

def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.

  The type signature in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  Unlike that Python version, ``fori_loop`` is implemented in terms of a call to
  :func:`jax.lax.while_loop`. See the :func:`jax.lax.while_loop` documentation
  for more information.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.
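
  For example, a minimal sketch summing the integers below ``n`` (illustrative
  usage only; the helper name ``sum_below`` is ours, not part of this API)::

    from jax import lax

    def sum_below(n):
      # carry is a scalar accumulator; i ranges over [0, n)
      return lax.fori_loop(0, n, lambda i, total: total + i, 0)

    sum_below(5)  # == 0 + 1 + 2 + 3 + 4 == 10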
  """
  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
  upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
  if lower_dtype != upper_dtype:
    msg = ("lower and upper arguments to fori_loop must have equal types, "
           "got {} and {}")
    raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))

  # If we can specialize on the trip count, call scan instead of a while_loop
  # to enable efficient reverse-mode differentiation.
  try:
    lower_ = int(lower)
    upper_ = int(upper)
  except TypeError:
    use_scan = False
  else:
    use_scan = False  # TODO(mattjj): re-enable this

  if use_scan:
    (_, _, result), _ = scan(_fori_scan_body_fun(body_fun),
                             (lower, upper, init_val), None,
                             length=upper_ - lower_)
  else:
    _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
                              (lower, upper, init_val))
  return result

def while_loop(cond_fun: Callable[[T], bool],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of ``body_fun``, of type ``a``.
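
  For example, a minimal sketch that doubles a value until it reaches a bound
  (illustrative usage only; ``double_until`` is a hypothetical helper)::

    from jax import lax

    def double_until(x, bound):
      # carry is the scalar x; the loop stops once cond_fun returns False
      return lax.while_loop(lambda v: v < bound, lambda v: 2 * v, x)

    double_until(1., 10.)  # 1. -> 2. -> 4. -> 8. -> 16.; stops at 16.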
  """
  if jax.api._jit_is_disabled():
    try:
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
      # transformation on it), so we fall back to the primitive version.
      pass

  def _create_jaxpr(init_val):
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
      msg = "cond_fun must return a boolean scalar, but got pytree {}."
      raise TypeError(msg.format(cond_tree))
    if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), np.bool_):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

  # The body input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
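  # For example (hypothetical values): init_val = 0 (a weakly-typed Python
  # scalar) with a body returning int32 yields mismatched carry avals on the
  # first pass; the init is then promoted and the jaxprs re-traced so input
  # and output avals agree exactly.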
  init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
  if changed:
    new_init_val, = tree_unflatten(in_tree, new_init_vals)
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
  cond_jaxpr, cond_consts, body_consts, body_tree = rest

  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  outs = while_p.bind(*itertools.chain(cond_consts, body_consts, init_vals),
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)

def _while_loop_abstract_eval(*args, **kwargs):
  return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)

def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
                                 cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
  cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
  batched = bool(cond_jaxpr.out_avals[0].shape)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.

  init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)

  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = xb.parameter(cond_c, 0, c.get_shape(init_carry))
  cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
  x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
  pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, cond_c), cond_jaxpr.consts),
                            extend_name_stack(name_stack, 'cond'), *(x + z))
  if batched:
    scalar = ShapedArray((), np.bool_)
    or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
    pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, np.array(False))], or_,
                       list(range(cond_jaxpr.out_avals[0].ndim)))

  body_c = xb.make_computation_builder("body_computation")
  body_carry = xb.parameter(body_c, 0, c.get_shape(init_carry))
  body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
  x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
  new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, body_c), body_jaxpr.consts),
                            extend_name_stack(name_stack, 'body'), *(y + z))
  if batched:
    body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
                                   _map(partial(xb.constant, body_c), cond_jaxpr.consts),
                                   extend_name_stack(name_stack, 'body_pred'), *(x + z))
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z, body_jaxpr.out_avals)
    assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z)  # no broadcast
  new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))

  ans = xops.While(cond_c.build(pred), body_c.build(new_carry), init_carry)
  ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
  _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return xops.Tuple(c, z)

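# _pred_bcast_select chooses between old and new carry values when the loop
# predicate is batched. For example (hypothetical shapes): a predicate of shape
# (3,) selecting between carries of shape (3, 4) is broadcast to (3, 4) before
# the element-wise select.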
def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
  pred_shape = c.get_shape(pred).dimensions()
  x_shape = c.get_shape(x).dimensions()
  y_shape = c.get_shape(y).dimensions()
  assert x_shape == y_shape
  if x_y_aval is core.abstract_unit:
    return x
  elif x_y_aval is core.abstract_token:
    return xops.AfterAll(c, [x, y])
  else:
    assert pred_shape == x_shape[:len(pred_shape)] == y_shape[:len(pred_shape)]
    bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
    return xops.Select(bcast_pred, x, y)

def _while_loop_batching_rule(args, dims, axis_name,
                              cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])

  # Fixpoint computation of which carry elements are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry element to batched. We need at most len(carry)
  # iterations, but we need one last iteration to prepare the jaxpr based on
  # the final carry_bat.
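  # For example (hypothetical): with carry (i, x) where only x is mapped on
  # entry but the body computes i from x, the first pass finds that i's output
  # is batched, so the next pass re-traces with i batched on input as well.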
  carry_bat = init_bat
  for _ in range(1 + len(carry_bat)):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat, axis_name=axis_name)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat,
        instantiate=bool(cond_jaxpr.out_avals[0].shape),
        axis_name=axis_name)
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = _map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims

def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])

  # Fixpoint computation of which carry tangents are nonzero.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  nonzeros = cconst_nz + body_nonzeros
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts

  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
      [num_carry], [len(init_dot)])

  newvar = core.gensym([cond_jaxpr.jaxpr])
  invars_aug = (
      cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
  cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
                                    invars_aug,
                                    cond_jaxpr.jaxpr.outvars,
                                    cond_jaxpr.jaxpr.eqns)
  cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)

  out = while_p.bind(
      *(cconst + bconst + bconst_dot + init + init_dot),
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + len(bconst_dot),
      body_jaxpr=body_jvp_rearranged)

  out_carry, out_carry_dot = split_list(out, [num_carry])
  out_tangents_iter = iter(out_carry_dot)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_carry, nonzeros_out)]
  return out_carry, out_tangents

def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
  """An implementation of partial evaluation for while.

  As long as some parts of the carry (and hence some outputs) are known and the
  output of `cond_jaxpr` is known, we use a portion of the loop body to compute
  the known outputs of the `while_loop`. For the unknown outputs we generate a
  jaxpr to run the whole while, including recomputing the known parts.

  This means that we don't actually save any computation by partial
  evaluation if there are unknown outputs.

  What this achieves is that we can give a proper error for reverse
  differentiation of `while`, because in that use of partial evaluation the
  primal inputs are considered "known", and only the tangent computation is
  unknown (see issue #2129).
  """
  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
  # Fixpoint computation of unknown carry. Each iteration promotes
  # at least one carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk = partial_eval_jaxpr(  # type: ignore
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk = partial_eval_jaxpr(  # type: ignore
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)

  if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while. Prepare the inputs, as constants (if known), or
  # as core.unit.
  in_consts = [core.unit if uk else t.pval.get_known()
               for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
                                tracers)]
  # There should be no residuals for the cond_jaxpr_known.
  assert 1 == len(cond_jaxpr_known.out_avals)
  # We ignore the residuals from the body_jaxpr_known, so the type of inputs matches
  # the type of outputs; residuals are at the end.
  if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
    # TODO(necula): this is not quite enough; we should drop the residual computations also
    body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
  out_known = while_p.bind(
      *in_consts,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_known,
      body_nconsts=body_nconsts,
      body_jaxpr=body_jaxpr_known)

  # Run the whole while_loop to get all the outputs, then merge with known ones.
  out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
  out_tracers: Sequence[pe.Tracer] = [
      out_unknown if uk
      else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
      for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]

  return out_tracers

def _while_transpose_error(*_, **kwargs):
  raise ValueError("Reverse-mode differentiation does not work for "
                   "lax.while_loop or lax.fori_loop. "
                   "Try using lax.scan instead.")

while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.initial_style_translations[while_p] = _while_loop_translation_rule
ad.primitive_transposes[while_p] = _while_transpose_error
batching.initial_style_batchers[while_p] = _while_loop_batching_rule


### cond and switch

def switch(index, branches: Sequence[Callable], operand):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, operand):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](operand)

  Arguments:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operand: Operand (A) input to whichever branch is applied.
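
  For example, a minimal sketch (illustrative usage only; assumes
  ``jax.numpy`` is imported as ``jnp``)::

    import jax.numpy as jnp
    from jax import lax

    # index 1 selects the second branch, so this applies jnp.cos to 2.0
    lax.switch(1, [jnp.sin, jnp.cos, jnp.tanh], 2.0)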
  """
  if len(np.shape(index)) != 0:
    raise TypeError(
        f"Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err

  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  branches = tuple(branches)

  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    return branches[0](operand)

  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  if (jax.api._jit_is_disabled() and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](operand)

  ops, ops_tree = tree_flatten((operand,))
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals)

  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
                          out_trees[0], jaxprs[0].out_avals,
                          out_tree, jaxpr.out_avals)

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(out_trees[0], out)

def cond(*args, **kwargs):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  ``cond()`` has equivalent semantics to this Python implementation::

    def cond(pred, true_fun, false_fun, operand):
      if pred:
        return true_fun(operand)
      else:
        return false_fun(operand)

  ``pred`` must be a scalar type.

  Functions ``true_fun``/``false_fun`` may not need to refer to an ``operand``
  to compute their result, but one must still be provided to the ``cond`` call
  and be accepted by both the branch functions, e.g.::

    jax.lax.cond(
        get_predicate_value(),
        lambda _: 23,
        lambda _: 42,
        operand=None)

  Arguments:
    pred: Boolean scalar type, indicating which branch function to apply.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operand: Operand (A) input to either branch depending on ``pred``. The type
      can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(operand)`` or ``false_fun(operand)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.
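
  For example, a minimal runnable sketch (illustrative usage only)::

    from jax import lax

    lax.cond(True, lambda x: x + 1, lambda x: x - 1, 5)   # == 6
    lax.cond(False, lambda x: x + 1, lambda x: x - 1, 5)  # == 4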
  """

  # detect an attempt to call the former, deprecated cond
  try:
    ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
  except TypeError:
    pass
  else:
    return _cond_with_per_branch_args(*ba.args)

  return _cond(*args, **kwargs)

def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
  if len(np.shape(pred)) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of shape {np.shape(pred)}.")

  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    msg = ("Pred type must be either boolean or number, got {}.")
    raise TypeError(msg.format(pred)) from err

  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      msg = ("Pred type must be either boolean or number, got {}.")
      raise TypeError(msg.format(pred_dtype))

  if jax.api._jit_is_disabled() and isinstance(core.get_aval(pred), ConcreteArray):
    if pred:
      return true_fun(operand)
    else:
      return false_fun(operand)

  ops, ops_tree = tree_flatten((operand,))
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals)
  true_jaxpr, false_jaxpr = jaxprs
  out_tree, false_out_tree = out_trees

  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)

  # A False predicate converts to index 0 and True to index 1, so the branches
  # tuple below is ordered (false_jaxpr, true_jaxpr).
  index = lax.convert_element_type(pred, np.int32)

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops,
      branches=(false_jaxpr, true_jaxpr), linear=linear)
  return tree_unflatten(out_tree, out)

def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  ``pred`` must be a scalar type; collection types (list, tuple) are not
  supported.
  """
  return _cond(pred,
               lambda op: true_fun(op[0]),
               lambda op: false_fun(op[1]),
               (true_operand, false_operand))

def _cond_abstract_eval(*args, **kwargs):
  return _map(raise_to_shaped, kwargs["branches"][0].out_avals)

def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
                           index, *args, branches, linear):
  del linear  # Unused.

  def make_computation(name, jaxpr, op_shape):
    c = xb.make_computation_builder(name + '_comp')
    op = xb.parameter(c, 0, op_shape)
    ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
    outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
                             _map(partial(xb.constant, c), jaxpr.consts),
                             extend_name_stack(name_stack, name + '_fun'), *ops)
    return c.build(xops.Tuple(c, outs))

  op = xops.Tuple(c, args)
  op_shape = c.get_shape(op)
  branch_computations = [
      make_computation(f'branch_{i}', jaxpr, op_shape)
      for i, jaxpr in enumerate(branches)]
  return xops.Conditional(index, branch_computations, [op] * len(branches))

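# _select_tree lowers a multi-way selection to a balanced binary tree of
# two-way lax.select calls. For example (hypothetical values), with four branch
# values [b0, b1, b2, b3] and index i it produces roughly:
#   select(i < 2, select(i < 1, b0, b1), select(i - 2 < 1, b2, b3))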
def _select_tree(indices, branch_vals):
  assert len(branch_vals) > 0
  if len(branch_vals) == 1:
    return branch_vals[0]
  mid = len(branch_vals) // 2
  mid = np.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
  return lax.select(lax.lt(indices, mid),
                    _select_tree(indices, branch_vals[:mid]),
                    _select_tree(indices - mid, branch_vals[mid:]))

def _cond_index_bcast_and_select_tree(indices, branch_vals):
  if all(core.get_aval(x) is core.abstract_unit for x in branch_vals):
    return branch_vals[0]
  else:
    bcast_indices = lax.broadcast_in_dim(
        indices, np.shape(branch_vals[0]), list(range(np.ndim(indices))))
    return _select_tree(bcast_indices, branch_vals)

def _cond_batching_rule(args, dims, axis_name, branches, linear):
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  orig_bat = [d is not batching.not_mapped for d in dims]
  del dims
  index, *ops = args
  index_bat, *bat = orig_bat

  branches_out_bat = [batching.batch_jaxpr(jaxpr, size, bat, False, axis_name)[1]
                      for jaxpr in branches]
  out_bat = [any(bat) for bat in zip(*branches_out_bat)]

  branches_batched = tuple(batching.batch_jaxpr(jaxpr, size, bat, out_bat, axis_name)[0]
                           for jaxpr in branches)

  if index_bat:
    # When the branch index itself is batched, different batch elements may
    # take different branches, so we run every branch and select among the
    # results element-wise.
    branch_outs = []
    for jaxpr in branches_batched:
      out = core.jaxpr_as_fun(jaxpr)(*ops)
      out = [batching.broadcast(x, size, 0) if not b else x
             for x, b in zip(out, out_bat)]
      branch_outs.append(out)
    return [_cond_index_bcast_and_select_tree(index, outs)
            for outs in zip(*branch_outs)], [0] * len(branch_outs[0])
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(
        index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims

def _cond_jvp(primals, tangents, branches, linear):
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]

  index_nz, *ops_nz = nonzeros
  assert index_nz is False

  branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
                     for jaxpr in branches]
  out_nz = [any(nz) for nz in zip(*branches_out_nz)]

  branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
                       for jaxpr in branches)

  index, *ops = primals
  _, *ops_dot = tangents
  ops_dot = _prune_zeros(ops_dot)

  ops_lin = tuple(linear)
  linear_jvp = ops_lin + (True,) * len(ops_dot)
  out = cond_p.bind(
      index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)
  out_primals, out_tangents = split_list(out, [len(out_nz)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nz)]
  return out_primals, out_tangents

def _cond_partial_eval(trace, *tracers, branches, linear):
  unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = unknowns

  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    params = dict(branches=branches, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks = partial_eval_jaxpr(branch_jaxpr, ops_uk, instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  branches_1, branches_2, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_1, branch_jaxpr_2, _ = partial_eval_jaxpr(
        branch_jaxpr, ops_uk, instantiate=out_uks)
    branch_num_res = len(branch_jaxpr_1.out_avals) - len(out_uks)

    # move residuals to the front
    move = [False] * len(ops_uk) + [True] * branch_num_res
    branch_jaxpr_2 = pe.move_binders_to_front(branch_jaxpr_2, move)

    # TODO(frostig,mattjj): pe.partial_eval_jaxpr should raise to shaped avals
    res_avals = _map(
        raise_to_shaped, branch_jaxpr_2.in_avals[:branch_num_res])

    branches_1.append(branch_jaxpr_1)
    branches_2.append(branch_jaxpr_2)
    branch_res_avals.append(res_avals)

  branches_1 = tuple(branches_1)
  branches_2 = tuple(branches_2)

  for jaxpr in branches_2[1:]:
    assert len(jaxpr.out_avals) == len(branches_2[0].out_avals)

  num_outs = len(branches_2[0].out_avals)

  all_res_avals, res_avals_per_branch = _merge_branch_residuals(
      branch_res_avals)

  branches_1 = _join_cond_outputs(
      branches_1, all_res_avals, res_avals_per_branch, num_outs)
  branches_2 = _join_cond_pe_staged_jaxpr_inputs(
      branches_2, all_res_avals, res_avals_per_branch)

  # TODO(frostig,mattjj): reinstate this assertion once pe.partial_eval_jaxpr
  # raises to shaped avals
  # for j in branches_1[1:]:
  #   assert j.out_avals == branches_1[0].out_avals
  num_res = len(all_res_avals)

  _, in_consts = unzip2([t.pval for t in tracers])
  out_consts_res = cond_p.bind(*in_consts, branches=branches_1, linear=linear)
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  # TODO(frostig,mattjj): remove raised_to_shaped of avals once
  # pe.partial_eval_jaxpr handles it
  out_avals = _map(raise_to_shaped, branches_2[0].out_avals)
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uks)]

  index_tracer = trace.instantiate_const(tracers[0])

  ops_tracers = [trace.instantiate_const(t) if uk
                 else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns[1:], tracers[1:])]

  res_tracers = _map(trace.new_instantiated_const, res)

  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]

  linear_2 = (False,) * num_res + linear
  params = dict(branches=branches_2, linear=linear_2)
  eqn = pe.new_eqn_recipe(
      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
      source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  return out_tracers

# When partially evaluating conditionals, each branch produces residuals
# depending on the computation carried out by the branch, and a corresponding
# staged jaxpr that accepts those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in a
# conditional. The residual-consuming jaxprs are assembled together in a jaxpr
# conditional. The following helper functions ensure that both collections of
# jaxprs (those evaluated and those staged) are valid for joint use under their
# respective conditionals.
#
# In particular, the residuals derived from each original branch may have
# distinct types. Because the branches of conditionals must have identical type
# signatures, we join residuals together across branches into a common format.

# In order to set up a type signature that all branches can conform to, it would
# suffice to concatenate all branches' residuals. But concatenation can result
# in redundant inputs and outputs, and might lead to memory allocation that
# scales unnecessarily with the branch count. This function finds common
# residual types across branches for reuse, so as to avoid redundant
# allocation. It returns a list L of types (avals) representing the collection
# of residuals merged according to type, and, for each branch, a lookup table to
# match its residuals to their positions/types in L. Example input/output:
#
# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]]
# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]]
# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
def _merge_branch_residuals(branch_res_avals):
  def enumerate_equal(xs):
    counts = {v: itertools.count() for v in set(xs)}
    return [(x, next(counts[x])) for x in xs]
  branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
  all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
  indices = {v: i for i, v in enumerate(all_tagged_avals)}
  branch_indices = [
      [indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
  all_avals = [x for x, _ in all_tagged_avals]
  return all_avals, branch_indices

# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
                       num_non_res_outputs):
  def augment_jaxpr(jaxpr, res_indices):
    @lu.wrap_init
    def f_aug(*args):
      outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
      outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
      aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
      aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
      return outs + list(aug_residuals)

    return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
  all_res_vars = _map(newvar, all_res_avals)

  def augment_jaxpr(jaxpr, res_indices):
    num_res = len(res_indices)
    res_vars = jaxpr.jaxpr.invars[:num_res]
    non_res_vars = jaxpr.jaxpr.invars[num_res:]

    aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
    aug_invars = aug_res_vars + non_res_vars
    jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
                           jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns)
    jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
    return jaxpr_aug

  return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

def _ordered_unique(xs):
  d = collections.OrderedDict((x, None) for x in xs)
  return list(d.keys())

def _transpose_cond_jaxpr(jaxpr, num_res):
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
  primal_avals = _map(raise_to_shaped, primal_avals)

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
    cts_in = ad.backward_pass(
        jaxpr.jaxpr, jaxpr.consts, primals, cts_out)
    _, cts_in = split_list(cts_in, [num_res])
    return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)

  return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)

def _cond_transpose(cts, *args, branches, linear):
  index, *ops = args
  in_avals = _map(raise_to_shaped, branches[0].in_avals)
  num_res = len(ops) - sum(linear)

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
  lin_in_avals = [raise_to_shaped(a, weak_type=False)
                  for a, l in zip(in_avals, linear) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
  linear_trans = (False,) * num_res + (True,) * len(cts)

  out = cond_p.bind(
      index, *res, *cts, branches=branches_trans, linear=linear_trans)
  assert all(_map(core.typecheck, lin_in_avals, out))

  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out

def _avals_short(avals):
  to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
  return ' '.join(_map(to_str, avals))

def _cond_typecheck(*avals, branches, linear):
  tc = partial(_typecheck_param, 'cond')
  tc(branches, 'branches', 'tuple of ClosedJaxpr',
     type(branches) is tuple and
     all(type(x) is core.ClosedJaxpr for x in branches))
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))

  core.typecheck_assert(
      len(branches) > 0,
      'cond requires at least one branch function')
  core.typecheck_assert(
      len(linear) + 1 == len(avals),
      f'cond given {len(linear)} linear flags for '
      f'{len(avals) - 1} non-predicate operands')

  jaxpr0 = branches[0]
  jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
  jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)

  for i, jaxpr in enumerate(branches[1:]):
    core.typecheck_assert(
        len(jaxpr0.in_avals) == len(jaxpr.in_avals),
        f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
        f'branch {i+1} takes {len(jaxpr.in_avals)}')
    core.typecheck_assert(
        len(jaxpr0.out_avals) == len(jaxpr.out_avals),
        f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
        f'branch {i+1} outputs {len(jaxpr.out_avals)}')
    core.typecheck_assert(
        all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)),
        f'cond branches 0 and {i+1} have mismatching input types: '
        f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
    core.typecheck_assert(
        all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)),
        f'cond branches 0 and {i+1} have mismatching output types: '
        f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

  core.typecheck_assert(
      len(avals) == 1 + len(jaxpr0.in_avals),
      f'cond called with {len(avals) - 1} non-predicate operands, '
      f'but branches take {len(jaxpr0.in_avals)} inputs')

  index_aval, *op_avals = avals
  core.typecheck_assert(
      index_aval.dtype == np.int32,
      f'cond called with index of type {index_aval.dtype} instead of int32')
  core.typecheck_assert(
      all(_map(core.typecompat, jaxpr0.in_avals, op_avals)),
      f'cond branches take input types {jaxpr0_in_avals_str}, '
      f'called with operands of type {_avals_short(op_avals)}')

def cond_bind(*args, branches, linear):
  if not core.skip_checks:
    avals = _map(core.get_aval, args)
    _cond_typecheck(*avals, branches=branches, linear=linear)
    for jaxpr in branches:
      core.check_jaxpr(jaxpr.jaxpr)
  return core.Primitive.bind(cond_p, *args, branches=branches, linear=linear)

cond_p = lax.Primitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.initial_style_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule
core.custom_typechecks[cond_p] = _cond_typecheck


### scan

def scan(f, init, xs, length=None, reverse=False, unroll=1):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When ``a`` is an array type or None, and ``b`` is an array type, the semantics
  of ``scan`` are given roughly by this Python implementation::

    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays. (None is actually an empty pytree.)

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
  across all iterations (and not just be consistent up to NumPy rank/shape
  broadcasting and dtype promotion rules, for example). In other words, the type
  ``c`` in the type signature above represents an array with a fixed shape and
  dtype (or a nested tuple/list/dict container data structure with a fixed
  structure and arrays with fixed shape and dtype at the leaves).

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.
    unroll: optional positive int specifying, in the underlying operation of the
      scan primitive, how many scan iterations to unroll within a single
      iteration of a loop.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
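
  For example, a minimal cumulative-sum sketch (illustrative usage only;
  assumes ``jax.numpy`` is imported as ``jnp``)::

    import jax.numpy as jnp
    from jax import lax

    def cumsum(xs):
      def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (new carry c, output slice b)
      return lax.scan(step, 0., xs)

    cumsum(jnp.array([1., 2., 3.]))  # total == 6.0, ys == [1., 3., 6.]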
  """
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
        msg.format(', '.join(str(x) for x in xs_flat
                             if not hasattr(x, 'shape')))) from err

  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  if jax.api._jit_is_disabled():
    carry = init
    ys = []
    maybe_reversed = reversed if reverse else lambda x: x
    for i in maybe_reversed(range(length)):
      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
                            else jax.numpy.stack((y, *ys)))
    ys = tree_multimap(stack, *maybe_reversed(ys))
    return carry, ys

  x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
  x_dtypes = [x.dtype for x in xs_flat]
  x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))

  def _create_jaxpr(init):
    init_flat, init_tree = tree_flatten(init)
    in_flat, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(_map(_abstractify, init_flat))
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
    out_tree_children = out_tree.children()
    if len(out_tree_children) != 2:
      msg = "scan body output must be a pair, got {}."
      raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
    carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves]
    return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children

  # The carry input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
|
|
|
|
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
|
|
|
|
new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
|
|
|
|
if changed:
|
|
|
|
new_init = tree_unflatten(init_tree, new_init_flat)
|
|
|
|
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
|
|
|
|
in_flat, jaxpr, consts, out_tree, out_tree_children = rest
|
|
|
|
|
|
|
|
_check_tree_and_avals("scan carry output and input",
|
|
|
|
# Extract the subtree and avals for the first element of the return tuple
|
|
|
|
out_tree_children[0], carry_avals_out,
|
|
|
|
init_tree, carry_avals)
|
|
|
|
|
|
|
|
out = scan_p.bind(*itertools.chain(consts, in_flat),
|
|
|
|
reverse=reverse, length=length, jaxpr=jaxpr,
|
|
|
|
num_consts=len(consts), num_carry=len(init_flat),
|
|
|
|
linear=(False,) * (len(consts) + len(in_flat)),
|
|
|
|
unroll=unroll)
|
|
|
|
return tree_unflatten(out_tree, out)

def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  consts, init, xs = split_list(args, [num_consts, num_carry])

  carry = init
  ys = []

  for i in range(length):
    i_ = length - i - 1 if reverse else i
    x = _map(partial(_index_array, i_), x_avals, xs)
    out = f_impl(*consts, *carry, *x)
    carry, y = split_list(out, [num_carry])
    ys.append(y)

  ys = list(reversed(ys)) if reverse else ys
  ys = list(zip(*ys))
  ys = _map(_stack, y_avals, ys)
  return (*carry, *ys)

def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
                    f_impl, x_avals, y_avals):
  consts, init, xs = split_list(args, [num_consts, num_carry])

  def cond_fun(vals):
    i, *_ = vals
    return i < length

  def body_fun(vals):
    [i], carry, ys = split_list(vals, [1, num_carry])
    i_ = length - i - 1 if reverse else i
    x = _map(partial(_dynamic_index_array, i_), x_avals, xs)
    out_flat = f_impl(*consts, *carry, *x)
    carry_out, y_updates = split_list(out_flat, [num_carry])
    ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
    return [i + 1] + carry_out + ys_out

  ys_init = _map(partial(_empty_array, length), y_avals)
  if length == 0:
    return init + ys_init
  else:
    init_val = [lax._const(length, 0)] + init + ys_init
    _, *outs = while_loop(cond_fun, body_fun, init_val)
    return outs

def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  consts, init, xs = split_list(args, [num_consts, num_carry])

  num_blocks, rem = divmod(length, block_length)
  assert rem == 0

  partition = partial(_partition_leading, num_blocks, block_length)
  xs_block = _map(partition, x_avals, xs)

  prepend_aval = partial(_prepend_dim_to_aval, block_length)
  x_block_avals = _map(prepend_aval, x_avals)
  y_block_avals = _map(prepend_aval, y_avals)

  f_impl_block = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  outs = _scan_impl_loop(
      *consts, *init, *xs_block, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)

  carry, ys_blocks = split_list(outs, [num_carry])
  combine = partial(_combine_leading, num_blocks, block_length)
  ys = _map(combine, y_avals, ys_blocks)
  return (*carry, *ys)

def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
               unroll):
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
  f_impl = core.jaxpr_as_fun(jaxpr)

  if unroll == 1:
    return _scan_impl_loop(
        *args, reverse=reverse, length=length, num_consts=num_consts,
        num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
        y_avals=y_avals)

  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, rem = divmod(length, unroll)
  length_div = num_blocks * unroll

  if rem > 0:
    if reverse:
      split = partial(_split_leading_dim, rem)
      xs_rem, xs = unzip2(_map(split, x_avals, xs))
    else:
      split = partial(_split_leading_dim, length_div)
      xs, xs_rem = unzip2(_map(split, x_avals, xs))

  outs = _scan_impl_block_unrolled(
      *consts, *init, *xs, reverse=reverse, length=length_div,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  carry, ys = split_list(outs, [num_carry])

  if rem > 0:
    outs = _scan_impl_unrolled(
        *consts, *carry, *xs_rem, reverse=reverse, length=rem,
        num_consts=num_consts, num_carry=num_carry, linear=linear,
        f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
    carry, ys_rem = split_list(outs, [num_carry])
    if reverse:
      ys = _map(_concatenate, y_avals, ys_rem, ys)
    else:
      ys = _map(_concatenate, y_avals, ys, ys_rem)

  return (*carry, *ys)
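
# For illustration (not executed anywhere): with length=10 and unroll=4,
# _scan_impl above runs _scan_impl_block_unrolled over num_blocks=2 blocks of
# 4 unrolled iterations each (a while loop whose body is an unrolled inner
# loop), then handles the rem=2 leftover iterations with _scan_impl_unrolled
# and concatenates the two stacked outputs.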

def _stack(aval, vals):
  if aval is core.abstract_unit:
    return core.unit
  else:
    vals = [lax.expand_dims(x, (0,)) for x in vals]
    return lax.concatenate(vals, 0)

def _concatenate(aval, x1, x2):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.concatenate([x1, x2], 0)

def _split_leading_dim(i, aval, x):
  if aval is core.abstract_unit:
    return (core.unit, core.unit)
  else:
    assert x.ndim >= 1
    return (lax.slice_in_dim(x, 0, i),
            lax.slice_in_dim(x, i, x.shape[0]))

def _dynamic_index_array(i, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.dynamic_index_in_dim(x, i, keepdims=False)

def _index_array(i, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.index_in_dim(x, i, keepdims=False)

def _empty_array(sz, aval):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.full((sz,) + aval.shape, 0, aval.dtype)

def _update_array(i, aval, xs, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.dynamic_update_index_in_dim(xs, x, i, 0)

def _partition_leading(sz0, sz1, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    assert x.ndim >= 1
    assert x.shape[0] == sz0 * sz1
    return lax.reshape(x, (sz0, sz1, *x.shape[1:]))

def _combine_leading(sz0, sz1, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    assert x.ndim >= 2
    assert x.shape[0] == sz0
    assert x.shape[1] == sz1
    return lax.collapse(x, 0, 2)

def _prepend_dim_to_aval(sz, aval):
  if aval is core.abstract_unit:
    return aval
  elif isinstance(aval, ShapedArray):
    return ShapedArray((sz, *aval.shape), aval.dtype)
  else:
    raise TypeError(f'Prepending dim {sz} to aval {aval}')

def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                        linear, unroll):
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
              if aval is not core.abstract_unit else aval for aval in y_avals]
  return carry_avals + ys_avals

def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
              linear, unroll):
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

  # Fixpoint computation of which carry are not ad.zero: either
  # non-zero from init, or the carry out is non-zero. Each iteration promotes
  # at least one carry to non-zero. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_nz.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    nonzeros = const_nz + carry_nz + xs_nz
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
    carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
    if carry_nz_out == carry_nz:
      break
    else:
      carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
  else:
    assert False, "Fixpoint not reached"

  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  all_tangents = split_list(tangents, [num_consts, num_carry])
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
      [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])

  consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
                           + init_linear + [True] * len(init_dot)
                           + xs_linear + [True] * len(xs_dot))

  out_flat = scan_p.bind(
      *(consts + consts_dot + init + init_dot + xs + xs_dot),
      reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts + len(consts_dot),
      num_carry=num_carry + len(init_dot),
      linear=jaxpr_jvp_linear, unroll=unroll)

  carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
  primals_out = carry + ys
  tangents_out_iter = iter(carry_dot + ys_dot)
  tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(primals_out, nonzeros_out)]
  return primals_out, tangents_out
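
# For illustration: in the fixpoint above, if the carry is a pair (a, b) where
# only ``a`` has a nonzero input tangent but the body computes ``b_new = a``,
# the first pass reports ``b``'s output tangent as nonzero, so the loop
# re-runs ad.jvp_jaxpr with carry_nz = [True, True] before settling.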

def _prune_zeros(ts):
  return [t for t in ts if type(t) is not ad_util.Zero]

def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                       jaxpr, linear, unroll):
  if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace:  # type: ignore
    params = dict(reverse=reverse, length=length, num_consts=num_consts,
                  num_carry=num_carry, jaxpr=jaxpr, linear=linear,
                  unroll=unroll)
    return trace.default_process_primitive(scan_p, tracers, params)

  num_ys = len(jaxpr.out_avals) - num_carry

  unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  if config.omnistaging_enabled:
    partial_eval_jaxpr = pe.partial_eval_jaxpr
  else:
    partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)

  # Fixpoint computation of which carry are unknown (not a constant): either
  # unknown from init, or the carry out is unknown. Each iteration promotes
  # at least one carry to unknown. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_uk.
  carry_uk = init_uk
  for _ in range(1 + len(carry_uk)):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

  # The residuals are treated as extensive outputs of jaxpr_1 (and extensive
  # inputs to jaxpr_2), but residuals that are loop-invariant can be hoisted.
  # TODO(mattjj): hoist other loop-invariant values here too (instantiate=False)
  invariant_pvals = [pe.PartialVal.known(core.unit if uk else t.pval[1])
                     for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])]
  other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
  in_pvals_1 = invariant_pvals + other_pvals
  jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
      lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
      instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
  jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ())
  num_consts_1 = num_consts + len(consts_1)
  # any now-known residuals are intensive, so we want to revise jaxpr_2 to take
  # those inputs as constants rather than as extensive inputs
  _, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
  intensive_residuals = [const for pv, const in res_pvals if pv is None]
  move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
  jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
  num_consts_2 = num_consts + len(intensive_residuals)

  # As another optimization, for any extensive inputs that are just forwarded
  # to extensive outputs, to avoid a copy (looping over dynamic-update-slice)
  # we'd rather just forward the input tracer. That means pruning some
  # extensive outputs from the jaxpr here, and updating out_flat below.
  extensive_invars = jaxpr_1_opt.jaxpr.invars[num_consts_1 + num_carry:]
  extensive_outvars = jaxpr_1_opt.jaxpr.outvars[num_carry:]
  extensive_avals = [core.unmapped_aval(length, 0, core.raise_to_shaped(v.aval))
                     for v in extensive_outvars]
  fwd_extensive = [num_consts + num_carry + extensive_invars.index(v)
                   if v in extensive_invars else None for v in extensive_outvars]
  jaxpr_1_opt.jaxpr.outvars = (
      jaxpr_1_opt.jaxpr.outvars[:num_carry] +
      [v for i, v in zip(fwd_extensive, extensive_outvars) if i is None])

  in_consts = (list(consts_1) + [core.unit] * num_consts +
               [core.unit if uk else t.pval[1]
                for uk, t in zip(unknowns[num_consts:], tracers[num_consts:])])
  linear_1 = ([False] * len(consts_1) + [True] * num_consts +
              [lin or uk for uk, lin
               in zip(unknowns[num_consts:], linear[num_consts:])])
  out_flat = scan_p.bind(
      *in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
      num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1),
      unroll=unroll)

  # Propagate the forwarded extensive outputs using fwd_extensive. Any
  # numpy.ndarray inputs should be converted to JAX DeviceArrays.
  out_carry, out_extensive = split_list(out_flat, [num_carry])
  out_extensive_iter = iter(out_extensive)
  out_extensive = [next(out_extensive_iter) if i is None
                   else _maybe_device_put(tracers[i].pval[1]) if tracers[i].is_known()
                   else tracers[i] for i in fwd_extensive]
  assert all(a == core.raise_to_shaped(core.get_aval(out))
             for a, out in zip(extensive_avals, out_extensive))
  out_flat = out_carry + out_extensive

  out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
  extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]

  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  out_consts = out_carry + ys
  int_res_tracers = _map(trace.new_instantiated_const, intensive_residuals)
  ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  linear_2 = ([False] * len(int_res_tracers) +
              [lin or not uk for uk, lin in zip(unknowns, linear)] +
              [False] * len(ext_res_tracers))
  eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers,
                          out_tracers, scan_p,
                          dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
                               num_consts=num_consts_2,
                               num_carry=num_carry, linear=tuple(linear_2),
                               unroll=unroll),
                          source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  return out_tracers

def _maybe_device_put(x):
  if isinstance(x, np.ndarray):
    return lax._device_put_raw(x)
  else:
    return x

def _promote_aval_rank(sz, aval):
  if aval is core.abstract_unit:
    return core.abstract_unit
  else:
    return ShapedArray((sz,) + aval.shape, aval.dtype)

def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
                    linear, unroll):
  # we've only implemented transposing scans with specific lin/nonlin patterns
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_ires = len(consts_lin) - sum(consts_lin)
  num_eres = len(xs_lin) - sum(xs_lin)
  if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
    raise NotImplementedError
  if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
    raise NotImplementedError
  if not all(init_lin):
    pass  # TODO(mattjj): error check https://github.com/google/jax/issues/1963

  consts, _, xs = split_list(args, [num_consts, num_carry])
  ires, _ = split_list(consts, [num_ires])
  _, eres = split_list(xs, [sum(xs_lin)])
  assert not any(ad.is_undefined_primal(r) for r in ires)
  assert not any(ad.is_undefined_primal(r) for r in eres)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])

  # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_scan_jaxpr(
      num_ires, num_consts - num_ires, num_eres, jaxpr)
  linear_trans = ([False] * num_ires +
                  [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
                  [False] * num_eres)

  outs = scan_p.bind(
      *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
      length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
      unroll=unroll)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
  return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres

# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
#                         -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
  num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
  res1_avals, c_avals, a_avals, res2_avals = split_list(
      jaxpr.in_avals, [num_res1, num_c, num_a])
  num_b = len(jaxpr.out_avals)
  b_avals = list(jaxpr.out_avals)

  @lu.wrap_init
  def transposed(*res1_cbar_bbar_res2):
    res1, c_bar, b_bar, res2 = split_list(
        res1_cbar_bbar_res2, [num_res1, num_c, num_b])
    primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
               [ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
    _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
    a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
    c_bar = _map(ad.instantiate_zeros_aval, c_avals,
                 _map(ad.add_tangents, c_bar, new_c_bar))
    return c_bar + a_bar
  return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)

def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  if config.omnistaging_enabled:
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  else:
    pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
    out_avals, _ = unzip2(pvals_out)
  return core.ClosedJaxpr(jaxpr, consts)

def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_consts,
                        num_carry, linear, unroll):
  num_ys = len(jaxpr.out_avals) - num_carry
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry])

  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_batched.
  carry_batched = init_batched
  for _ in range(1 + len(carry_batched)):
    batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, batched,
        instantiate=carry_batched + [False] * num_ys,
        axis_name=axis_name)
    carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
    if carry_batched_out == carry_batched:
      break
    else:
      carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, consts_bdims)]
  new_init = [batching.broadcast(x, size, 0) if now_batched and not was_batched
              else batching.moveaxis(x, d, 0) if now_batched else x
              for x, d, was_batched, now_batched in
              zip(init, init_bdims, init_batched, carry_batched)]
  new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1
            else x for x, d in zip(xs, xs_bdims)]
  new_args = new_consts + new_init + new_xs

  outs = scan_p.bind(
      *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
      num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims

def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
                       jaxpr, num_consts, num_carry, linear, unroll):
  dynamic_length, = masking.shape_as_value((length,))
  masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
  consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
  max_length, = {x.shape[0] for x in xs}
  const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  out_vals = scan_p.bind(
      *itertools.chain([dynamic_length] + consts, [0], init, xs),
      reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
      num_consts=1 + num_consts, num_carry=1 + num_carry,
      linear=tuple([False] + const_linear + [False] + init_linear + xs_linear),
      unroll=unroll)
  return out_vals[1:]

def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  fun = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    out = fun(*(consts + carry + xs))
    new_carry, ys = split_list(out, [num_carry])
    new_carry = [lax.select(i < dynamic_length, new_c, c)
                 for new_c, c in zip(new_carry, carry)]
    return [i + 1] + new_carry + ys

  aval = ShapedArray((), dtypes.int_)
  const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)

def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
                    jaxpr, linear, unroll):
  tc = partial(_typecheck_param, 'scan')
  tc(reverse, 'reverse', 'bool', type(reverse) is bool)
  tc(num_consts, 'num_consts', 'non-negative int',
     type(num_consts) is int and num_consts >= 0)
  tc(num_carry, 'num_carry', 'non-negative int',
     type(num_carry) is int and num_carry >= 0)
  tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr)
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))
  tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0)

  length_types = (int, masking.Poly) if bind_time else (int,)
  tc(length, 'length', 'non-negative int',
     type(length) in length_types and length >= 0)

  core.typecheck_assert(
      len(linear) == len(avals),
      f'scan param linear has length {len(linear)} for {len(avals)} operands')

  const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
  const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry])
  x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)

  core.typecheck_assert(
      all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)),
      f'scan carry input and output types mismatch: '
      f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
  core.typecheck_assert(
      all(_map(core.typecompat, const_avals_jaxpr, const_avals)),
      f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
      f'called with consts of type\n{_avals_short(const_avals)}')
  core.typecheck_assert(
      all(_map(core.typecompat, init_avals_jaxpr, init_avals)),
      f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
      f'called with initial carry of type\n{_avals_short(init_avals)}')
  core.typecheck_assert(
      all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)),
      f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
      f'called with sequence of type\n{_avals_short(x_avals)}')

def scan_bind(*args, **params):
  if not core.skip_checks:
    avals = _map(core.get_aval, args)
    _scan_typecheck(True, *avals, **params)
    core.check_jaxpr(params['jaxpr'].jaxpr)
  return core.Primitive.bind(scan_p, *args, **params)

scan_p = core.Primitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(_scan_impl)
# scan_p.def_impl(partial(xla.apply_primitive, scan_p))  # TODO(mattjj): re-enable
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl)
batching.initial_style_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)


def map(f, xs):
  """Map a function over leading array axes.

  Like Python's builtin map, except inputs and outputs are in the form of
  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
  need to apply a function element by element for reduced memory usage or
  heterogeneous computation with other control flow primitives.

  When ``xs`` is an array type, the semantics of ``map`` are given by this
  Python implementation::

    def map(f, xs):
      return np.stack([f(x) for x in xs])

  Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
  the same advantages over a Python loop apply: ``xs`` may be an arbitrary
  nested pytree type, and the mapped computation is compiled only once.

  Args:
    f: a Python function to apply element-wise over the first axis or axes of
      ``xs``.
    xs: values over which to map along the leading axis.

  Returns:
    Mapped values.
  """
  g = lambda _, x: ((), f(x))
  _, ys = scan(g, (), xs)
  return ys
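
# Example usage (an illustrative sketch, not part of the library): applying a
# per-example function sequentially over a batch, assuming ``jnp`` is
# ``jax.numpy``:
#
#   losses = map(lambda x: jnp.sum(x ** 2), batch)  # batch has shape [N, ...]
#
# Unlike ``jax.vmap``, the applications run one at a time under a single
# compiled ``scan`` step, trading throughput for peak memory.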


def _concat_masking_rule(padded_vals, logical_shapes, dimension):
  result = lax.concatenate(padded_vals, dimension)  # fragmented
  offset = 0
  for padded_val, logical_shape in zip(padded_vals, logical_shapes):
    result = _memcpy(dimension, logical_shape[dimension], padded_val,
                     result, offset)
    offset = offset + logical_shape[dimension]
  return result

def _memcpy(axis, num, src, dst, offset):
  def body(i, dst):
    update = lax.dynamic_index_in_dim(src, i, axis)
    return lax.dynamic_update_index_in_dim(dst, update, i + offset, axis)
  return fori_loop(0, num, body, dst)

masking.masking_rules[lax.concatenate_p] = _concat_masking_rule  # type: ignore


def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
  """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).

  Corresponding `tree` and `avals` must match in the sense that the number of
  leaves in `tree` must be equal to the length of `avals`. `what` will be
  prepended to details of the mismatch in TypeError.
  """
  if tree1 != tree2:
    raise TypeError(
        f"{what} must have same type structure, got {tree1} and {tree2}.")
  if not all(_map(core.typematch, avals1, avals2)):
    raise TypeError(
        f"{what} must have identical types, got\n"
        f"{tree_unflatten(tree1, avals1)}\nand\n"
        f"{tree_unflatten(tree2, avals2)}.")


def _check_tree(func_name, expected_name, actual_tree, expected_tree):
  if actual_tree != expected_tree:
    raise TypeError(
        f"{func_name}() output pytree structure must match {expected_name}, "
        f"got {actual_tree} and {expected_tree}.")


def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
  """Promote weakly-typed in_vals to be compatible with out_avals.

  Args:
    in_vals : flattened list of input values.
    in_avals : corresponding list of avals.
    out_avals : list of target output avals.
  Returns:
    in_vals_new : flattened list of modified in_vals with no weak types.
    changed : bool; true if in_vals required modification.
  """
  if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
    # Calling function is responsible for catching this.
    return in_vals, False
  weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
                     if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
  if not weak_mismatches:
    return in_vals, False
  for i in weak_mismatches:
    new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
    in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
  return in_vals, True
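
# For illustration: scan(f, 0, xs) with a float32 carry output first traces
# with a weakly-typed integer init; the helper above converts the 0 to
# float32 so the carry input and output avals match exactly on the retrace.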

def _stop_gradient_fun(f):
  """Create a version of f() that stops all gradients."""
  def wrapper(*args, **kwargs):
    args_flat, in_args_tree = tree_flatten((args, kwargs))
    args_avals = tuple(_map(_abstractify, args_flat))
    g = lambda a, b: f(*a, **b)
    jaxpr, consts, out_tree = _initial_style_jaxpr(g, in_args_tree, args_avals)
    all_args = _map(lax.stop_gradient, (*consts, *args_flat))
    out = core.jaxpr_as_fun(jaxpr)(*all_args)
    return tree_unflatten(out_tree, out)
  return wrapper
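
# For illustration: wrappers produced here block all gradient flow, e.g.
#
#   g = _stop_gradient_fun(lambda x: 3.0 * x)
#   jax.grad(g)(1.0)  # ~> 0.0
#
# which is what lets ``solve`` in custom_root below call ``f`` without being
# differentiated through it.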


_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')


def _split_root_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _RootTuple(*params_list[:-1]), params_list[-1]


def custom_root(f, initial_guess, solve, tangent_solve):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that ``f(solution) = 0``. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)

  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree)

  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, _stop_gradient_fun(f)), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree)

  def linearize_and_solve(x, b):
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
  _check_tree("tangent_solve", "x", out_tree, in_tree)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  out_flat = _custom_root(
      const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
  return tree_unflatten(out_tree, out_flat)
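
# Example usage (an illustrative sketch, not part of the library): solving
# x**2 - 2 == 0 with an explicit Newton iteration and a scalar tangent solve.
#
#   def f(x):
#     return x ** 2 - 2.0
#
#   def newton_solve(f, x):
#     for _ in range(10):           # fixed iteration count; the derivative
#       x = x - f(x) / (2.0 * x)    # of x**2 - 2 is written out by hand
#     return x
#
#   sqrt2 = custom_root(f, 1.0, newton_solve,
#                       tangent_solve=lambda g, y: y / g(1.0))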


@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
  params, initial_guess = _split_root_args(args, const_lengths)
  solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
  return solution


@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  params, _ = _split_root_args(primals, const_lengths)
  solution = _custom_root(const_lengths, jaxprs, *primals)

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxprs.f)
  linearize_and_solve = partial(
      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
  f_at_solution = lambda *params: f(*itertools.chain(params, solution))
  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  solution_dot = _map(
      operator.neg, linearize_and_solve(*itertools.chain(solution, rhs)))

  return solution, solution_dot


class _LinearSolveTuple(collections.namedtuple(
    '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):

  def transpose(self):
    return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)


def _split_linear_solve_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _LinearSolveTuple(*params_list[:-1]), params_list[-1]


def _transpose_one_output(linear_fun, primals):
  transpose_fun = jax.linear_transpose(linear_fun, primals)
  def transposed_fun(x):
    (y,) = transpose_fun(x)
    return y
  return transposed_fun


def _flatten(args):
  return [x for arg in args for x in arg]


def _check_shapes(func_name, expected_name, actual, expected):
  actual_shapes = _map(np.shape, tree_leaves(actual))
  expected_shapes = _map(np.shape, tree_leaves(expected))
  if actual_shapes != expected_shapes:
    raise ValueError(
        f"{func_name}() output shapes must match {expected_name}, "
        f"got {actual_shapes} and {expected_shapes}")


def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much
  faster or more numerically stable; differentiating through the solve
  operation may not even be possible (e.g., if ``solve`` uses
  ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right-hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for solution to the linear
      equation, i.e., ``matvec(solve(matvec, x)) == x`` for all ``x`` of the
      same form as ``b``. This function need not be differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.
  """
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))

  tree, = treedef_children(in_args_tree)

  def _shape_checked(fun, name):
    def f(x):
      y = fun(x)
      _check_shapes(name, "b", y, b_flat)
      return y
    return f

  matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(matvec, "matvec"), in_args_tree, b_avals)
  _check_tree("matvec", "b", out_tree, tree)

  solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals)
  _check_tree("solve", "b", out_tree, tree)

  if transpose_solve is None:
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      vecmat = _transpose_one_output(matvec, b)
      vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals)
      assert out_tree == tree

    tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
        _shape_checked(partial(transpose_solve, vecmat), "transpose_solve"),
        in_args_tree, b_avals)
    _check_tree("transpose_solve", "b", out_tree, tree)

  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = linear_solve_p.bind(
      *(_flatten(all_consts) + b_flat),
      const_lengths=const_lengths, jaxprs=jaxprs)
  return tree_unflatten(tree, out_flat)
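
# Example usage (an illustrative sketch, not part of the library): solving
# A x = b with the matrix closed over by both callables, assuming
# ``jax.numpy`` is imported as ``jnp``.
#
#   def solve_with_implicit_grads(A, b):
#     def matvec(x):
#       return jnp.dot(A, x)
#     def solve(mv, b):
#       del mv  # this sketch uses the closed-over A directly
#       return jnp.linalg.solve(A, b)
#     return custom_linear_solve(matvec, b, solve)
#
# Reverse-mode differentiation additionally needs ``transpose_solve`` (or
# ``symmetric=True`` for a symmetric A).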


def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
  return _map(raise_to_shaped, args[sum(const_lengths):])


def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
  params, b = _split_linear_solve_args(args, const_lengths)
  x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
  return x


def _tangent_linear_map(func, params, params_dot, *x):
  """Compute the tangent of a linear map.

  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
  this function computes ``∂A @ x``.
  """
  assert any(type(p) is not ad_util.Zero for p in params_dot)
  zeros = _map(ad_util.Zero.from_value, x)
  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
      params + list(x), params_dot + zeros)
  return out_tangent


def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
  # A x - b = 0
  # ∂A x + A ∂x - ∂b = 0
  # ∂x = A^{-1} (∂b - ∂A x)

  kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs)
  x = linear_solve_p.bind(*primals, **kwargs)

  params, _ = _split_linear_solve_args(primals, const_lengths)
  params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

  if all(type(p) is ad_util.Zero for p in params_dot.matvec):
    # no need to evaluate matvec_tangents
    rhs = b_dot
  else:
    matvec_tangents = _tangent_linear_map(
        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x)
    rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))

  x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)

  return x, x_dot


def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  assert all(ad.is_undefined_primal(x) for x in b)
  cotangent_b = linear_solve_p.bind(
      *(_flatten(params.transpose()) + cotangent),
      const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
  return [None] * sum(const_lengths) + cotangent_b


def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
  orig_bat = [d is not batching.not_mapped for d in dims]
  size, = {
      a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
  }

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, size, solve_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, size, vecmat_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
      x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, size, matvec_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
          solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
                       orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    assert False, "Fixpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
                                     solve_jaxpr_batched, solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(
      *(new_params + new_b),
      const_lengths=const_lengths,
      jaxprs=batched_jaxprs)
  out_dims = [0 if batched else batching.not_mapped for batched in b_bat]
  return outs, out_dims


linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.initial_style_translations[linear_solve_p] = \
    xla.lower_fun_initial_style(_custom_linear_solve_impl)
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.initial_style_batchers[linear_solve_p] = _linear_solve_batching_rule


def _interleave(a, b, axis):
  """Given two arrays of static shape, interleave them along the given axis."""
  assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
  a_pad = [(0, 0, 0)] * a.ndim
  b_pad = [(0, 0, 0)] * b.ndim
  a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
  b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
  return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
                 lax.pad(b, lax._const(b, 0), b_pad))
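
# For illustration: along ``axis``, _interleave([a0, a1], [b0, b1], axis)
# yields [a0, b0, a1, b1]; if ``a`` has one extra element, it ends the result.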
|
|
|
|
|
|
|
|
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
|
|
|
"""Performs a scan with an associative binary operation, in parallel.
|
|
|
|
|
|
|
|
For an introduction to associative scans, see [BLE1990]_.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
fn: A Python callable implementing an associative binary operation with
|
|
|
|
signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
|
|
|
|
must satisfy the equation
|
|
|
|
``fn(a, fn(b, c)) == fn(fn(a, b), c)``.
|
|
|
|
|
|
|
|
The inputs and result are (possibly nested Python tree structures of)
|
|
|
|
array(s) matching ``elems``. Each array has a dimension in place
|
|
|
|
of the ``axis`` dimension. `fn` should be applied elementwise over
|
|
|
|
the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
|
|
|
|
elementwise function.)
|
|
|
|
|
|
|
|
The result ``r`` has the same shape (and structure) as the two inputs
|
|
|
|
``a`` and ``b``.
|
|
|
|
elems: A (possibly nested Python tree structure of) array(s), each with
|
|
|
|
an ``axis`` dimension of size ``num_elems``.
|
|
|
|
reverse: A boolean stating if the scan should be reversed with respect to
|
|
|
|
the ``axis`` dimension.
|
|
|
|
axis: an integer identifying the axis over which the scan should occur.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A (possibly nested Python tree structure of) array(s) of the same shape
|
|
|
|
and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
|
|
|
|
result of recursively applying ``fn`` to combine the first ``k`` elements
|
|
|
|
of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
|
|
|
|
the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
|
|
|
|
|
|
|
|
Example 1: partial sums of an array of numbers:
|
|
|
|
|
|
|
|
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
|
|
|
|
[ 0, 1, 3, 6]
|
|
|
|
|
|
|
|
Example 2: partial products of an array of matrices
|
|
|
|
|
|
|
|
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
|
|
|
|
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
|
|
|
|
>>> partial_prods.shape
|
|
|
|
(4, 2, 2)
|
|
|
|
|
|
|
|
Example 3: reversed partial sums of an array of numbers
|
|
|
|
|
|
|
|
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
|
|
|
|
[ 6, 6, 5, 3]
|
|
|
|
|
|
|
|
.. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
|
|
|
|
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
|
|
|
|
University.
|
|
|
|
"""
|
|
|
|
  elems_flat, tree = tree_flatten(elems)

  if reverse:
    elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]

  def combine(a_flat, b_flat):
    # Lower `fn` to operate on flattened sequences of elems.
    a = tree_unflatten(tree, a_flat)
    b = tree_unflatten(tree, b_flat)
    c = fn(a, b)
    c_flat, _ = tree_flatten(c)
    return c_flat

  # Check that all inputs have a consistent size `num_elems` along `axis`.
  axis = lax._canonicalize_axis(axis, elems_flat[0].ndim)
  num_elems = int(elems_flat[0].shape[axis])
  if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
    raise ValueError('Array inputs to associative_scan must have the same '
                     'size along the scanned axis. (saw: {})'
                     .format([elem.shape for elem in elems_flat]))

  # Summary of algorithm:
  #
  # Consider elements of `_scan(elems)` at odd indices. That's the same as
  # first combining successive pairs of elements of `elems` and performing a
  # scan on that half-sized tensor. We perform the latter scan by recursion.
  #
  # Now consider the even elements of `_scan(elems)`. These can be computed
  # from the odd elements of `_scan(elems)` by combining each odd element of
  # `_scan(elems)` with the matching even element in the original `elems`.
  #
  # We return the odd and even elements interleaved.
  #
  # For the base case of the recursion we return the first element
  # of `elems`, followed by the combination of the first two elements computed
  # as a (small, two-down-to-one) reduction step.
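  # Worked sketch (illustrative, with `fn` = addition on [a, b, c, d, e]):
  #   reduced_elems = [a+b, c+d]                    (pairwise combine)
  #   odd_elems     = _scan([a+b, c+d])
  #                 = [a+b, a+b+c+d]                (scan values at odd indices)
  #   even_elems    = [a+b+c, a+b+c+d+e]            (odd_elems + later evens)
  #   prepended     = [a, a+b+c, a+b+c+d+e]         (first element is unchanged)
  #   interleaved   = [a, a+b, a+b+c, a+b+c+d, a+b+c+d+e]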
  def _scan(elems):
    """Perform a scan on `elems`."""
    num_elems = elems[0].shape[axis]

    if num_elems < 2:
      return elems

    # Combine adjacent pairs of elements.
    reduced_elems = combine(
      [lax.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems],
      [lax.slice_in_dim(elem, 1, None, stride=2, axis=axis) for elem in elems])

    # Recursively compute scan for partially reduced tensors.
    odd_elems = _scan(reduced_elems)

    if num_elems % 2 == 0:
      even_elems = combine(
        [lax.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
        [lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
    else:
      even_elems = combine(
        odd_elems,
        [lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [
      lax.concatenate([lax.slice_in_dim(elem, 0, 1, axis=axis), result],
                      dimension=axis)
      for (elem, result) in zip(elems, even_elems)]
    return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))

  scans = _scan(elems_flat)

  if reverse:
    scans = [lax.rev(scanned, [axis]) for scanned in scans]

  return tree_unflatten(tree, scans)

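# Usage sketch (illustrative, comment only; `jnp` stands for `jax.numpy`):
# `associative_scan` also accepts pytrees, e.g. a running (max, argmax) pair,
# provided `fn` operates elementwise over the scan axis:
#
#   def max_with_arg(a, b):
#     a_val, a_idx = a
#     b_val, b_idx = b
#     take_b = b_val > a_val
#     return (jnp.where(take_b, b_val, a_val),
#             jnp.where(take_b, b_idx, a_idx))
#
#   vals = jnp.asarray([3., 1., 4., 1., 5.])
#   running_max, running_argmax = associative_scan(
#       max_with_arg, (vals, jnp.arange(5)))
#   # -> [3., 3., 4., 4., 5.], [0, 0, 2, 2, 4]
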
# Cumulative reductions.

def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative sum along `axis`."""
  return cumsum_p.bind(operand, axis=int(axis), reverse=bool(reverse))

def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative product along `axis`."""
  return cumprod_p.bind(operand, axis=int(axis), reverse=bool(reverse))

def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative maximum along `axis`."""
  return cummax_p.bind(operand, axis=int(axis), reverse=bool(reverse))

def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative minimum along `axis`."""
  return cummin_p.bind(operand, axis=int(axis), reverse=bool(reverse))

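# Usage sketch (illustrative, comment only; `jnp` stands for `jax.numpy`):
#
#   cumsum(jnp.asarray([[1, 2], [3, 4]]), axis=1)  # -> [[1, 3], [3, 7]]
#   cumsum(jnp.asarray([1, 2, 3]), reverse=True)   # -> [6, 5, 3]
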
def _cumred_shape_rule(x, *, axis: int, reverse: bool):
  if axis < 0 or axis >= x.ndim:
    raise ValueError(
        "axis {} is out of bounds for array of shape {}".format(axis, x.shape))
  return x.shape

def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
  # cumsum is linear in its input; as a matrix it is lower-triangular ones, so
  # its transpose is upper-triangular ones, i.e. a cumsum in the opposite
  # direction.
  return [cumsum(t, axis=axis, reverse=not reverse)]

def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
                                 axis: int, reverse: bool):
  # On TPU, an implementation using reduce_window is handled specially by the
  # compiler and is efficient. On other backends, it is O(n^2).
  n = x.shape[axis]
  if n == 0:
    return x
  # With a window of length n sliding over an input padded with n - 1 identity
  # elements (on the low side for a forward scan), output position i reduces
  # exactly the first i + 1 elements.
  padding = [(0, 0)] * x.ndim
  padding[axis] = (0, n - 1) if reverse else (n - 1, 0)
  strides = [1] * x.ndim
  window_dims = [1] * x.ndim
  window_dims[axis] = n
  return window_reduce(x, window_dims, strides, padding)

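# Window arithmetic sketch (illustrative): for x of shape (4,) with
# reverse=False, padding is [(3, 0)], window_dims is [4], and strides is [1].
# The padded input has length 7, giving 4 window positions that reduce
# x[0:1], x[0:2], x[0:3], and x[0:4] respectively, i.e. a cumulative reduction.
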
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
                       reverse: bool):
  operand, = batched_args
  bdim, = batch_dims
  # Shift `axis` past the inserted batch dimension if necessary.
  axis = axis if axis < bdim else axis + 1
  return prim.bind(operand, axis=axis, reverse=reverse), bdim

def _cumred_dtype_rule(name, operand, *args, **kw):
  if not dtypes.issubdtype(operand.dtype, np.number):
    raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
                    "of number.".format(name, np.dtype(operand.dtype).name))
  return dtypes.canonicalize_dtype(operand.dtype)

cumsum_p = lax.standard_primitive(
  _cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"),
  'cumsum')
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
  partial(_cumred_tpu_translation_rule, lax._reduce_window_sum),
  multiple_results=False)
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)

def _cumulative_reduction_primitive(name, reduce_window_fn):
  reducer_p = lax.standard_primitive(
    _cumred_shape_rule, partial(_cumred_dtype_rule, name),
    name)
  xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
    partial(_cumred_tpu_translation_rule, reduce_window_fn),
    multiple_results=False)
  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule,
                                                   reducer_p)
  return reducer_p

cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax._reduce_window_min)

xla.translations[cumsum_p] = xla.lower_fun(
  partial(associative_scan, lax.add), multiple_results=False)
xla.translations[cumprod_p] = xla.lower_fun(
  partial(associative_scan, lax.mul), multiple_results=False)
xla.translations[cummin_p] = xla.lower_fun(
  partial(associative_scan, lax.min), multiple_results=False)
xla.translations[cummax_p] = xla.lower_fun(
  partial(associative_scan, lax.max), multiple_results=False)

def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  return api.jvp(partial(associative_scan, combine_fn, axis=axis,
                         reverse=reverse),
                 primals, tangents)

ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)

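# Differentiation sketch (illustrative, comment only; `jnp` stands for
# `jax.numpy`): with the JVP rules above, derivatives of the non-linear
# cumulative reductions are taken through the prefix-scan implementation, e.g.
#
#   jax.grad(lambda x: cumprod(x).sum())(jnp.asarray([1., 2., 3.]))
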
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global _initial_style_open_jaxpr, _initial_style_jaxpr, \
      _initial_style_jaxprs_with_common_consts

  from jax.util import unzip4

  @cache()
  def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    with core.initial_style_staging():  # type: ignore
      jaxpr, out_pvals, consts = pe.trace_to_jaxpr(  # type: ignore
        wrapped_fun, in_pvals, instantiate=True, stage_out=False)  # type: ignore
    return jaxpr, out_pvals, consts, out_tree

  @cache()
  def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
    jaxpr, out_pvals, consts, out_tree = _initial_style_open_jaxpr(
        fun, in_tree, in_avals)
    closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
    return closed_jaxpr, consts, out_tree()

  @cache()
  def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
                                               in_tree, in_avals):
    # When staging the branches of a conditional into jaxprs, constants are
    # extracted from each branch and converted to jaxpr arguments. To use the
    # staged jaxprs as the branches to a conditional *primitive*, we need
    # their (input) signatures to match. This function "joins" the staged
    # jaxprs: for each one, it makes another that accepts *all* constants, but
    # only uses those that it needs (dropping the rest).

    jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
        _initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs])

    newvar = core.gensym(jaxprs, suffix='_')
    all_const_avals = tuple(
        tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
        for consts in all_consts)
    unused_const_vars = tuple(
        tuple(newvar(aval) for aval in const_avals)
        for const_avals in all_const_avals)

    def pad_jaxpr_constvars(i, jaxpr):
      # Give jaxpr i the full list of constvars: fresh (unused) variables
      # stand in for the other branches' constants.
      prefix = util.concatenate(unused_const_vars[:i])
      suffix = util.concatenate(unused_const_vars[i + 1:])
      constvars = prefix + jaxpr.constvars + suffix
      return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
                        outvars=jaxpr.outvars, eqns=jaxpr.eqns)

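    # E.g., with branch constants [c0] and [c1, c2], both padded jaxprs take
    # constvars (c0, c1, c2); branch 0 ignores c1 and c2, branch 1 ignores c0.
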
    def type_and_const_convert_jaxpr(jaxpr, out_pvals):
      return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

    jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
    closed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)

    return (tuple(closed_jaxprs),
            tuple(util.concatenate(all_consts)),
            tuple(out_tree() for out_tree in all_out_trees))