rocm_jax/jax/lax/lax_control_flow.py

1818 lines
78 KiB
Python

# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Control flow primitives.
"""
import collections
import functools
import itertools
import operator
import threading
from typing import Callable, Sequence
import numpy as onp
import jax
from jax import core
from jax import dtypes
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.core import get_aval
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_leaves,
tree_map, tree_multimap)
from jax import ad_util
xops = xla_client.ops
_map = safe_map
zip = safe_zip
_reduce = functools.reduce
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
with core.initial_style_staging():
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
wrapped_fun, in_pvals, instantiate=True, stage_out=False)
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
(), const_avals + in_avals, out_avals)
return typed_jaxpr, consts, out_tree()
def _abstractify(x):
return raise_to_shaped(core.get_aval(x))
def typecheck(aval, x):
aval = raise_to_shaped(aval).strip_weak_type()
try:
return aval == core.lattice_join(aval, core.get_aval(x)).strip_weak_type()
except TypeError:
return False
def typematch(aval1, aval2):
return (raise_to_shaped(aval1).strip_weak_type() ==
raise_to_shaped(aval2).strip_weak_type())
def _disable_jit_impl(prim, interp, *args, **kwargs):
if jax.api._jit_is_disabled():
return interp(*args, **kwargs)
else:
return xla.apply_primitive(prim, *args, **kwargs)
### fori_loop and while_loop
def _fori_cond_fun(loop_carry):
i, upper, _ = loop_carry
return lax.lt(i, upper)
@cache()
def _fori_body_fun(body_fun):
def while_body_fun(loop_carry):
i, upper, x = loop_carry
return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
return while_body_fun
@cache()
def _fori_scan_body_fun(body_fun):
def scanned_fun(loop_carry, _):
i, upper, x = loop_carry
return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None
return scanned_fun
def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to ``while_loop``.
The type signature in brief is
.. code-block:: haskell
fori_loop :: Int -> Int -> ((int, a) -> a) -> a -> a
The semantics of ``fori_loop`` are given by this Python implementation::
def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
Unlike that Python version, ``fori_loop`` is implemented in terms of a call to
``while_loop``. See the docstring for ``while_loop`` for more information.
Also unlike the Python analogue, the loop-carried value ``val`` must hold a
fixed shape and dtype across all iterations (and not just be consistent up to
NumPy rank/shape broadcasting and dtype promotion rules, for example). In
other words, the type ``a`` in the type signature above represents an array
with a fixed shape and dtype (or a nested tuple/list/dict container data
structure with a fixed structure and arrays with fixed shape and dtype at the
leaves).
Args:
lower: an integer representing the loop index lower bound (inclusive)
upper: an integer representing the loop index upper bound (exclusive)
body_fun: function of type ``(int, a) -> a``.
init_val: initial loop carry value of type ``a``.
Returns:
Loop value from the final iteration, of type ``a``.
"""
# TODO(phawkins): perhaps do more type checking here, better error messages.
lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
if lower_dtype != upper_dtype:
msg = ("lower and upper arguments to fori_loop must have equal types, "
"got {} and {}")
raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))
# If we can specialize on the trip count, call scan instead of a while_loop
# to enable efficient reverse-mode differentiation.
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
use_scan = False
else:
use_scan = False # TODO(mattjj): re-enable this
if use_scan:
(_, _, result), _ = scan(_fori_scan_body_fun(body_fun),
(lower, upper, init_val), None,
length=upper_ - lower_)
else:
_, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
(lower, upper, init_val))
return result
def while_loop(cond_fun, body_fun, init_val):
"""Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.
The type signature in brief is
.. code-block:: haskell
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
The semantics of ``while_loop`` are given by this Python implementation::
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
to a single XLA While HLO. That makes it useful for reducing compilation times
for jit-compiled functions, since native Python loop constructs in an ``@jit``
function are unrolled, leading to large XLA computations.
Also unlike the Python analogue, the loop-carried value ``val`` must hold a
fixed shape and dtype across all iterations (and not just be consistent up to
NumPy rank/shape broadcasting and dtype promotion rules, for example). In
other words, the type ``a`` in the type signature above represents an array
with a fixed shape and dtype (or a nested tuple/list/dict container data
structure with a fixed structure and arrays with fixed shape and dtype at the
leaves).
Another difference from using Python-native loop constructs is that
``while_loop`` is not reverse-mode differentiable because XLA computations
require static bounds on memory requirements.
Args:
cond_fun: function of type ``a -> Bool``.
body_fun: function of type ``a -> a``.
init_val: value of type ``a``, a type that can be a scalar, array, or any
pytree (nested Python tuple/list/dict) thereof, representing the initial
loop carry value.
Returns:
The output from the final iteration of body_fun, of type ``a``.
"""
if jax.api._jit_is_disabled():
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(_abstractify, init_vals))
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), onp.bool_):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
in_tree_children = in_tree.children()
assert len(in_tree_children) == 1
_check_tree_and_avals("body_fun output and input",
# Extract the subtree and avals for the first element of the return tuple
body_tree, body_jaxpr.out_avals,
in_tree_children[0], init_avals)
outs = while_p.bind(*itertools.chain(cond_consts, body_consts, init_vals),
cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
return tree_unflatten(body_tree, outs)
def _while_loop_abstract_eval(*args, **kwargs):
return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)
def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
batched = bool(cond_jaxpr.out_avals[0].shape)
# Since jaxprs don't have tuples and have multiple return values, but we need
# the HLO While loop to take a single tuple input and output a single boolean
# (for the cond computation) or a single tuple output (for the body
# computation), we build XLA computations that handle the tuple munging before
# generating a Call into the computations formed from the jaxprs.
init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)
cond_c = xb.make_computation_builder("cond_computation")
cond_carry = xb.parameter(cond_c, 0, c.GetShape(init_carry))
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, cond_c),
cond_jaxpr.literals),
extend_name_stack(name_stack, 'cond'), *(x + z))
if batched:
scalar = ShapedArray((), onp.bool_)
or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, onp.array(False))], or_,
list(range(cond_jaxpr.out_avals[0].ndim)))
body_c = xb.make_computation_builder("body_computation")
body_carry = xb.parameter(body_c, 0, c.GetShape(init_carry))
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, body_c), body_jaxpr.literals),
extend_name_stack(name_stack, 'body'), *(y + z))
if batched:
body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, body_c), cond_jaxpr.literals),
extend_name_stack(name_stack, 'body_pred'), *(x + z))
new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z) # no broadcast
new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))
ans = xops.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
_, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
return xops.Tuple(c, z)
def _pred_bcast_select(c, pred, x, y):
pred_shape = c.GetShape(pred).dimensions()
x_shape = c.GetShape(x).dimensions()
y_shape = c.GetShape(y).dimensions()
assert x_shape == y_shape
assert pred_shape == x_shape[:len(pred_shape)] == y_shape[:len(pred_shape)]
bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
return xops.Select(bcast_pred, x, y)
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
body_nconsts, body_jaxpr):
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
orig_batched = [d is not batching.not_mapped for d in dims]
cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
# Fixpoint computation of which carry are batched: either
# batched from init, or the carry out is batched. Each iteration promotes
# at least one carry to batched. We need at most len(carry) iterations,
# but we need one last iteration to prepare the jaxpr based on the final
# carry_bat.
carry_bat = init_bat
for _ in range(1 + len(carry_bat)):
batched = bconst_bat + carry_bat
body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
body_jaxpr, size, batched, instantiate=carry_bat)
cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
if carry_bat_out == carry_bat:
break
else:
carry_bat = _map(operator.or_, carry_bat, carry_bat_out)
else:
assert False, "Fixpoint not reached"
consts, init = split_list(args, [cond_nconsts + body_nconsts])
const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(consts, const_dims)]
new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
else batching.moveaxis(x, d, 0) if now_bat else x
for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]
outs = while_p.bind(*(new_consts + new_init),
cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
return outs, out_bdims
def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
body_jaxpr):
nonzeros = [t is not ad_util.zero for t in tangents]
cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])
carry_nz = init_nz
for _ in range(1 + len(carry_nz)):
body_nonzeros = bconst_nz + carry_nz
body_jvp, nonzeros_out = ad.jvp_jaxpr(
body_jaxpr, body_nonzeros, instantiate=carry_nz)
if nonzeros_out == carry_nz:
break
carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
else:
assert False, "Fixpoint not reached"
nonzeros = cconst_nz + body_nonzeros
tangents = [ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
for x, t, nz in zip(primals, tangents, nonzeros)]
cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
_, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
bconst_dot = _prune_zeros(bconst_dot)
init_dot = _prune_zeros(init_dot)
num_carry = len(primals) - cond_nconsts - body_nconsts
body_jvp_rearranged = ad.rearrange_binders(
body_jvp,
[body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
[num_carry], [len(init_dot)])
newvar = core.gensym('')
invars_aug = (
cond_jaxpr.jaxpr.invars + [newvar(get_aval(x)) for x in init_dot])
cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
invars_aug,
cond_jaxpr.jaxpr.outvars,
cond_jaxpr.jaxpr.eqns)
in_avals_aug = (cond_jaxpr.in_avals[:cond_nconsts] +
body_jvp_rearranged.in_avals[body_nconsts + len(bconst_dot):])
cond_jaxpr_augmented = core.TypedJaxpr(cond_jaxpr_augmented,
cond_jaxpr.literals,
in_avals_aug,
cond_jaxpr.out_avals)
out = while_p.bind(
*(cconst + bconst + bconst_dot + init + init_dot),
cond_nconsts=cond_nconsts,
cond_jaxpr=cond_jaxpr_augmented,
body_nconsts=len(bconst) + len(bconst_dot),
body_jaxpr=body_jvp_rearranged)
out_carry, out_carry_dot = split_list(out, [num_carry])
out_tangents_iter = iter(out_carry_dot)
out_tangents = [next(out_tangents_iter) if nz else ad_util.zero
for nz in nonzeros_out]
return out_carry, out_tangents
def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
cond_jaxpr: pe.TypedJaxpr, body_nconsts: int,
body_jaxpr: pe.TypedJaxpr) -> Sequence[pe.Tracer]:
"""An implementation of partial evaluation for while.
As long as some carry (and hence output) are known and the output
of `cond_jaxpr` is known, we use a portion of the loop body to compute the known
outputs of the `while_loop`. For the unknown outputs we generate Jaxpr to run
the whole while, including recomputing the known parts.
This means that we don't actually save any computation by partial
evaluation if there are unknown outputs.
What this achieves is that we can give a proper error for reverse
differentiation of `while`, because in that use of partial evaluation the
primal inputs are considered "known", and only the tangent computation is
unknown (see issue #2129).
"""
unknowns = [not t.pval.is_known() for t in tracers]
params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)
cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
# Fixpoint computation of unknown carry. Each iteration promotes
# at least one carry to unknown. We need one last iteration to prepare the jaxpr.
carry_uk = carry_init_uk
for _ in range(1 + len(carry_uk)):
body_jaxpr_known, _, carry_out_uk = pe.partial_eval_jaxpr(
body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk,
trace_type=trace.master.trace_type)
if carry_out_uk == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
else:
assert False, "Fixpoint not reached"
cond_jaxpr_known, _, cond_uk = pe.partial_eval_jaxpr(
cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False,
trace_type=trace.master.trace_type)
if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
# If conditional is unknown, or all inputs are known, or all are unknown,
# just do the default processing.
return trace.default_process_primitive(while_p, tracers, params)
# Run the known part of the while. Prepare the inputs, as constants (if known), or
# as core.unit.
in_consts = [ core.unit if uk else t.pval.get_known()
for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
tracers)]
# There should be no residuals for the cond_jaxpr_known
assert 1 == len(cond_jaxpr_known.out_avals)
# We ignore the residuals from the body_jaxpr_known, so the type of inputs matches
# the type of outputs; residuals are at the end
if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
# TODO(necula): this is not quite enough; we should drop the residual computations also
body_jaxpr_known.out_avals = body_jaxpr_known.out_avals[:len(body_jaxpr.out_avals)]
body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
out_known = while_p.bind(
*in_consts,
cond_nconsts=cond_nconsts,
cond_jaxpr=cond_jaxpr_known,
body_nconsts=body_nconsts,
body_jaxpr=body_jaxpr_known)
# Run the whole while_loop to get all the outputs, then merge with known ones
out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
out_tracers: Sequence[pe.Tracer] = [
out_unknown if uk
else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]
return out_tracers
def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")
while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.initial_style_translations[while_p] = _while_loop_translation_rule
ad.primitive_transposes[while_p] = _while_transpose_error
batching.primitive_batchers[while_p] = _while_loop_batching_rule
### cond
def cond(pred, true_operand, true_fun, false_operand, false_fun):
"""Conditionally apply ``true_fun`` or ``false_fun``.
Has equivalent semantics to this Python implementation::
def cond(pred, true_operand, true_fun, false_operand, false_fun):
if pred:
return true_fun(true_operand)
else:
return false_fun(false_operand)
Pred has to be a scalar type, collection types (list, tuple) are not supported
"""
if len(onp.shape(pred)) != 0:
raise TypeError("Pred must be a scalar, got {} of shape {}.".format(pred, onp.shape(pred)))
try:
pred_dtype = dtypes.result_type(pred)
except TypeError as err:
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred)) from err
if pred_dtype.kind != 'b':
if pred_dtype.kind in 'iuf':
pred = pred != 0
else:
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred_dtype))
if jax.api._jit_is_disabled():
if pred:
return true_fun(true_operand)
else:
return false_fun(false_operand)
true_ops, true_tree = tree_flatten((true_operand,))
true_avals = tuple(_map(_abstractify, true_ops))
true_jaxpr, true_consts, true_out_tree = _initial_style_jaxpr(true_fun, true_tree, true_avals)
false_ops, false_tree = tree_flatten((false_operand,))
false_avals = tuple(_map(_abstractify, false_ops))
false_jaxpr, false_consts, false_out_tree = _initial_style_jaxpr(false_fun, false_tree, false_avals)
_check_tree_and_avals("true_fun and false_fun output",
true_out_tree, true_jaxpr.out_avals,
false_out_tree, false_jaxpr.out_avals)
linear = (False,) * (len(true_consts) + len(true_ops) + len(false_consts) +
len(false_ops))
out = cond_p.bind(
*itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr, linear=linear)
return tree_unflatten(true_out_tree, out)
def _cond_abstract_eval(*args, **kwargs):
return _map(raise_to_shaped, kwargs["true_jaxpr"].out_avals)
def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
pred, *args, true_jaxpr, false_jaxpr, linear):
del linear # Unused.
true_ops, false_ops = split_list(args, [len(true_jaxpr.in_avals)])
def make_computation(name, jaxpr, op_shape):
c = xb.make_computation_builder(name + '_comp')
op = xb.parameter(c, 0, op_shape)
ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
_map(partial(xb.constant, c), jaxpr.literals),
extend_name_stack(name_stack, name + '_fun'), *ops)
return c.Build(xops.Tuple(c, outs))
true_op = xops.Tuple(c, true_ops)
true_c = make_computation('true', true_jaxpr, c.GetShape(true_op))
false_op = xops.Tuple(c, false_ops)
false_c = make_computation('false', false_jaxpr, c.GetShape(false_op))
return xops.Conditional(pred, true_op, true_c, false_op, false_c)
def _cond_pred_bcast_select(pred, x, y):
if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:
return x
else:
bcast_pred = lax.broadcast_in_dim(pred, onp.shape(x), list(range(onp.ndim(pred))))
return lax.select(bcast_pred, x, y)
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, linear):
# TODO: maybe avoid moving arg axes to front if we're promoting to select?
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(args, dims)]
orig_bat = [d is not batching.not_mapped for d in dims]
del dims
(pred,), true_ops, false_ops = split_list(args, [1, len(true_jaxpr.in_avals)])
(pred_bat,), t_bat, f_bat = split_list(orig_bat, [1, len(true_jaxpr.in_avals)])
_, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, t_bat, False)
_, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, f_bat, False)
out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]
true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, t_bat, out_bat)
false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, f_bat, out_bat)
if pred_bat:
true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*true_ops)
false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*false_ops)
true_out = [batching.broadcast(x, size, 0) if not b else x
for x, b in zip(true_out, out_bat)]
false_out = [batching.broadcast(x, size, 0) if not b else x
for x, b in zip(false_out, out_bat)]
return [_cond_pred_bcast_select(pred, t, f)
for t, f in zip(true_out, false_out)], [0] * len(true_out)
else:
out_dims = [0 if b else batching.not_mapped for b in out_bat]
out = cond_p.bind(
*itertools.chain([pred], true_ops, false_ops),
true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched, linear=linear)
return out, out_dims
def _cond_jvp(primals, tangents, true_jaxpr, false_jaxpr, linear):
nonzeros = [t is not ad_util.zero for t in tangents]
(pred_nz,), t_nz, f_nz = split_list(nonzeros, [1, len(true_jaxpr.in_avals)])
assert pred_nz is False
_, true_out_nz = ad.jvp_jaxpr(true_jaxpr, t_nz, instantiate=False)
_, false_out_nz = ad.jvp_jaxpr(false_jaxpr, f_nz, instantiate=False)
out_nz = [a or b for a, b in zip(true_out_nz, false_out_nz)]
true_jvp, _ = ad.jvp_jaxpr(true_jaxpr, t_nz, instantiate=out_nz)
false_jvp, _ = ad.jvp_jaxpr(false_jaxpr, f_nz, instantiate=out_nz)
(pred,), tops, fops = split_list(primals, [1, len(true_jaxpr.in_avals)])
_, tops_dot, fops_dot = split_list(tangents, [1, len(true_jaxpr.in_avals)])
tops_dot = _prune_zeros(tops_dot)
fops_dot = _prune_zeros(fops_dot)
tops_lin, fops_lin = _map(tuple, split_list(linear, [len(tops)]))
linear_jvp = (tops_lin + (True,) * len(tops_dot) +
fops_lin + (True,) * len(fops_dot))
out = cond_p.bind(
*itertools.chain([pred], tops, tops_dot, fops, fops_dot),
true_jaxpr=true_jvp, false_jaxpr=false_jvp, linear=linear_jvp)
out_primals, out_tangents = split_list(out, [len(out_nz)])
out_tangents_iter = iter(out_tangents)
out_tangents = [
next(out_tangents_iter) if nz else ad_util.zero for nz in out_nz]
return out_primals, out_tangents
def _cond_partial_eval(trace, *tracers, true_jaxpr, false_jaxpr, linear):
unknowns = [t.pval[0] is not None for t in tracers]
(pred_uk,), t_uk, f_uk = split_list(unknowns, [1, len(true_jaxpr.in_avals)])
if pred_uk:
# When the predicate is unknown, we stage out the whole cond.
params = dict(true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr, linear=linear)
return trace.default_process_primitive(cond_p, tracers, params)
_, _, t_out_uks = pe.partial_eval_jaxpr(true_jaxpr, t_uk, instantiate=False,
trace_type=trace.master.trace_type)
_, _, f_out_uks = pe.partial_eval_jaxpr(false_jaxpr, f_uk, instantiate=False,
trace_type=trace.master.trace_type)
out_uks = [a or b for a, b in zip(t_out_uks, f_out_uks)]
true_jaxpr_1, true_jaxpr_2, _ = pe.partial_eval_jaxpr(true_jaxpr, t_uk,
instantiate=out_uks,
trace_type=trace.master.trace_type)
false_jaxpr_1, false_jaxpr_2, _ = pe.partial_eval_jaxpr(false_jaxpr, f_uk,
instantiate=out_uks,
trace_type=trace.master.trace_type)
num_t_res = len(true_jaxpr_1.out_avals) - len(out_uks)
num_f_res = len(false_jaxpr_1.out_avals) - len(out_uks)
# Move the residuals to front
move = [False] * len(true_jaxpr.in_avals) + [True] * num_t_res
true_jaxpr_2 = pe.move_binders_to_front(true_jaxpr_2, move)
move = [False] * len(false_jaxpr.in_avals) + [True] * num_f_res
false_jaxpr_2 = pe.move_binders_to_front(false_jaxpr_2, move)
# TODO(frostig,mattjj): pe.partial_eval_jaxpr should raise to shaped avals
t_res_avals = _map(raise_to_shaped, true_jaxpr_2.in_avals[:num_t_res])
f_res_avals = _map(raise_to_shaped, false_jaxpr_2.in_avals[:num_f_res])
assert len(true_jaxpr_2.out_avals) == len(false_jaxpr_2.out_avals)
num_outs = len(true_jaxpr_2.out_avals)
true_jaxpr_1 = _join_cond_outputs(
true_jaxpr_1, num_outs, f_res_avals, zeros_on_left=False)
false_jaxpr_1 = _join_cond_outputs(
false_jaxpr_1, num_outs, t_res_avals, zeros_on_left=True)
# TODO(frostig,mattjj): reinstate this assertion once pe.partial_eval_jaxpr
# raises to shaped avals
# assert true_jaxpr_1.out_avals == false_jaxpr_1.out_avals
num_res = num_t_res + num_f_res
_, in_consts = unzip2([t.pval for t in tracers])
out_consts_res = cond_p.bind(
*in_consts, true_jaxpr=true_jaxpr_1, false_jaxpr=false_jaxpr_1,
linear=linear)
out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])
# TODO(frostig,mattjj): remove raised_to_shaped of avals once
# pe.partial_eval_jaxpr handles it
out_avals = _map(raise_to_shaped, true_jaxpr_2.out_avals)
out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uks)]
pred_tracer = trace.instantiate_const(tracers[0])
ops_tracers = [trace.instantiate_const(t) if uk
else trace.new_instantiated_literal(core.unit)
for uk, t in zip(unknowns[1:], tracers[1:])]
true_ops_tracers, false_ops_tracers = split_list(
ops_tracers, [len(true_jaxpr.in_avals)])
res_tracers = _map(trace.new_instantiated_const, res)
true_res_tracers, false_res_tracers = split_list(res_tracers, [num_t_res])
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
for pv, const in zip(out_pvs, out_consts)]
tops_lin, fops_lin = _map(tuple, split_list(linear, [len(true_jaxpr.in_avals)]))
linear_2 = ((False,) * num_t_res + tops_lin + (False,) * num_f_res + fops_lin)
params = dict(true_jaxpr=true_jaxpr_2, false_jaxpr=false_jaxpr_2,
linear=linear_2)
eqn = pe.new_eqn_recipe([pred_tracer] +
true_res_tracers + true_ops_tracers +
false_res_tracers + false_ops_tracers,
out_tracers,
cond_p, params)
for t in out_tracers: t.recipe = eqn
return out_tracers
def _join_cond_outputs(jaxpr, num_prefix, zeros_avals, zeros_on_left):
@lu.wrap_init
def f_aug(*args):
prefix_and_rest = core.jaxpr_as_fun(jaxpr)(*args)
prefix, rest = split_list(prefix_and_rest, [num_prefix])
zeros = [ad_util.zeros_like_aval(a) for a in zeros_avals]
if zeros_on_left:
return prefix + zeros + rest
else:
return prefix + rest + zeros
return _make_typed_jaxpr(f_aug, jaxpr.in_avals)
def _transpose_cond_jaxpr(jaxpr, num_res):
num_non_res = len(jaxpr.in_avals) - num_res
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = _map(raise_to_shaped, primal_avals)
@lu.wrap_init
def transposed(*args):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
cts_in = ad.backward_pass(
jaxpr.jaxpr, jaxpr.literals, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
return _make_typed_jaxpr(transposed, res_avals + jaxpr.out_avals)
def _cond_transpose(cts, *args, true_jaxpr, false_jaxpr, linear):
(pred,), tops, fops = split_list(args, [1, len(true_jaxpr.in_avals)])
tops_lin, fops_lin = split_list(linear, [len(true_jaxpr.in_avals)])
in_avals = _map(raise_to_shaped, true_jaxpr.in_avals + false_jaxpr.in_avals)
num_t_res = len(tops) - sum(tops_lin)
num_f_res = len(fops) - sum(fops_lin)
t_jaxpr_trans = _transpose_cond_jaxpr(true_jaxpr, num_t_res)
f_jaxpr_trans = _transpose_cond_jaxpr(false_jaxpr, num_f_res)
lin_in_avals = _map(raise_to_shaped, [a for a, l in zip(in_avals, linear) if l])
assert t_jaxpr_trans.out_avals + f_jaxpr_trans.out_avals == lin_in_avals
t_jaxpr_trans_ = _join_cond_outputs(
t_jaxpr_trans, 0, f_jaxpr_trans.out_avals, zeros_on_left=False)
f_jaxpr_trans_ = _join_cond_outputs(
f_jaxpr_trans, 0, t_jaxpr_trans.out_avals, zeros_on_left=True)
assert t_jaxpr_trans_.out_avals == f_jaxpr_trans_.out_avals == lin_in_avals
t_res, _ = split_list(tops, [num_t_res])
f_res, _ = split_list(fops, [num_f_res])
linear_trans = ((False,) * num_t_res + (True,) * len(cts) +
(False,) * num_f_res + (True,) * len(cts))
cts = _map(ad.instantiate_zeros_aval, true_jaxpr.out_avals, cts)
out = cond_p.bind(
pred, *itertools.chain(t_res, cts, f_res, cts),
true_jaxpr=t_jaxpr_trans_, false_jaxpr=f_jaxpr_trans_,
linear=linear_trans)
assert all(_map(typecheck, lin_in_avals, out))
out_iter = iter(out)
out = [next(out_iter) if l else None for l in linear]
assert next(out_iter, None) is None
return [None] + out
def cond_bind(*args, true_jaxpr, false_jaxpr, linear):
if not core.skip_checks:
assert len(linear) + 1 == len(args)
assert len(args) == 1 + len(true_jaxpr.in_avals) + len(false_jaxpr.in_avals)
(pred,), tops, fops = split_list(args, [1, len(true_jaxpr.in_avals)])
assert all(_map(typecheck, true_jaxpr.in_avals, tops))
assert all(_map(typecheck, false_jaxpr.in_avals, fops))
core.check_jaxpr(true_jaxpr.jaxpr)
core.check_jaxpr(false_jaxpr.jaxpr)
return core.Primitive.bind(cond_p, *args, true_jaxpr=true_jaxpr,
false_jaxpr=false_jaxpr, linear=linear)
cond_p = lax.Primitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.primitive_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule
### scan
def scan(f, init, xs, length=None, reverse=False):
"""Scan a function over leading array axes while carrying along state.
The type signature in brief is
.. code-block:: haskell
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
where we use [t] here to denote the type t with an additional leading axis.
That is, if t is an array type then [t] represents the type with an additional
leading axis, and if t is a pytree (container) type with array leaves then [t]
represents the type with the same pytree structure and corresponding leaves
each with an additional leading axis.
When ``a`` is an array type or None, and ``b`` is an array type, the semantics
of ``scan`` are given roughly by this Python implementation::
def scan(f, init, xs, length=None):
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, np.stack(ys)
Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
types, and so multiple arrays can be scanned over at once and produce multiple
output arrays. (None is actually an empty pytree.)
Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
a single XLA While HLO. That makes it useful for reducing compilation times
for jit-compiled functions, since native Python loop constructs in an ``@jit``
function are unrolled, leading to large XLA computations.
Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
across all iterations (and not just be consistent up to NumPy rank/shape
broadcasting and dtype promotion rules, for example). In other words, the type
``c`` in the type signature above represents an array with a fixed shape and
dtype (or a nested tuple/list/dict container data structure with a fixed
structure and arrays with fixed shape and dtype at the leaves).
Args:
f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
that ``f`` accepts two arguments where the first is a value of the loop
carry and the second is a slice of ``xs`` along its leading axis, and that
``f`` returns a pair where the first element represents a new value for
the loop carry and the second represents a slice of the output.
init: an initial loop carry value of type ``c``, which can be a scalar,
array, or any pytree (nested Python tuple/list/dict) thereof, representing
the initial loop carry value. This value must have the same structure as
the first element of the pair returned by ``f``.
xs: the value of type ``[a]`` over which to scan along the leading axis,
where ``[a]`` can be an array or any pytree (nested Python
tuple/list/dict) thereof with consistent leading axis sizes.
length: optional integer specifying the number of loop iterations, which
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
be used to perform scans where no input ``xs`` are needed).
reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse, equivalent to reversing the leading
axes of the arrays in both ``xs`` and in ``ys``.
Returns:
A pair of type ``(c, [b])`` where the first element represents the final
loop carry value and the second element represents the stacked outputs of
the second output of ``f`` when scanned over the leading axis of the inputs.
"""
init_flat, init_tree = tree_flatten(init)
xs_flat, xs_tree = tree_flatten(xs)
in_flat, in_tree = tree_flatten((init, xs))
try:
lengths = [x.shape[0] for x in xs_flat]
except AttributeError as err:
msg = "scan got value with no leading axis to scan over: {}."
raise ValueError(
msg.format(', '.join(str(x) for x in xs_flat
if not hasattr(x, 'shape')))) from err
if length is not None:
length = int(length)
if not all(length == l for l in lengths):
msg = ("scan got `length` argument of {} which disagrees with "
"leading axis sizes {}.")
raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
else:
unique_lengths = set(lengths)
if len(unique_lengths) > 1:
msg = "scan got values with different leading axis sizes: {}."
raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
elif len(unique_lengths) == 0:
msg = "scan got no values to scan over and `length` not provided."
raise ValueError(msg)
else:
length, = unique_lengths
if jax.api._jit_is_disabled():
carry = init
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(length)):
xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
ys.append(y)
stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
else jax.numpy.stack((y, *ys)))
ys = tree_multimap(stack, *maybe_reversed(ys))
return carry, ys
carry_avals = tuple(_map(_abstractify, init_flat))
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
x_dtypes = [x.dtype for x in xs_flat]
x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."
raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
_check_tree_and_avals("scan carry output and input",
# Extract the subtree and avals for the first element of the return tuple
out_tree_children[0], jaxpr.out_avals[:out_tree_children[0].num_leaves],
init_tree, carry_avals)
out = scan_p.bind(*itertools.chain(consts, in_flat),
reverse=reverse, length=length, jaxpr=jaxpr,
num_consts=len(consts), num_carry=len(init_flat),
linear=(False,) * (len(consts) + len(in_flat)))
return tree_unflatten(out_tree, out)
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
consts, init, xs = split_list(args, [num_consts, num_carry])
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
def cond_fun(vals):
i, *_ = vals
return i < length
def body_fun(vals):
[i], carry, ys = split_list(vals, [1, num_carry])
i_ = length - i - 1 if reverse else i
x = _map(partial(_index_array, i_), x_avals, xs)
out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
carry_out, y_updates = split_list(out_flat, [num_carry])
ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
return [i + 1] + carry_out + ys_out
ys_init = _map(partial(_empty_array, length), y_avals)
if length == 0:
return init + ys_init
else:
init_val = [lax._const(length, 0)] + init + ys_init
_, *outs = while_loop(cond_fun, body_fun, init_val)
return outs
def _index_array(i, aval, x):
if aval is core.abstract_unit:
return core.unit
else:
return lax.dynamic_index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
if aval is core.abstract_unit:
return core.unit
else:
return lax.full((sz,) + aval.shape, 0, aval.dtype)
def _update_array(i, aval, xs, x):
if aval is core.abstract_unit:
return core.unit
else:
return lax.dynamic_update_index_in_dim(xs, x, i, 0)
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
if aval is not core.abstract_unit else aval for aval in y_avals]
return carry_avals + ys_avals
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
linear):
num_xs = len(jaxpr.in_avals) - num_carry - num_consts
num_ys = len(jaxpr.out_avals) - num_carry
nonzeros = [t is not ad_util.zero for t in tangents]
const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])
# Fixpoint computation of which carry are not ad.zero: either
# non-zero from init, or the carry out is non-zero. Each iteration promotes
# at least one carry to non-zero. We need at most len(carry) iterations,
# but we need one last iteration to prepare the jaxpr based on the final
# carry_nz.
carry_nz = init_nz
for _ in range(1 + len(carry_nz)):
nonzeros = const_nz + carry_nz + xs_nz
jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
carry_nz_out, ys_nz = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
if carry_nz_out == carry_nz:
break
else:
carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
else:
assert False, "Fixpoint not reached"
tangents = [ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
for x, t, nz in zip(primals, tangents, nonzeros)]
consts, init, xs = split_list(primals, [num_consts, num_carry])
all_tangents = split_list(tangents, [num_consts, num_carry])
consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)
jaxpr_jvp_rearranged = ad.rearrange_binders(
jaxpr_jvp,
[num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
[num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])
consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
+ init_linear + [True] * len(init_dot)
+ xs_linear + [True] * len(xs_dot))
out_flat = scan_p.bind(
*(consts + consts_dot + init + init_dot + xs + xs_dot),
reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
num_consts=num_consts+len(consts_dot), num_carry=num_carry+len(init_dot),
linear=jaxpr_jvp_linear)
carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
primals_out = carry + ys
tangents_out_iter = iter(carry_dot + ys_dot)
tangents_out = [next(tangents_out_iter) if nz else ad_util.zero
for nz in nonzeros_out]
return primals_out, tangents_out
def _prune_zeros(ts):
return [t for t in ts if t is not ad_util.zero]
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear):
if trace.master.trace_type is pe.StagingJaxprTrace:
params = {"reverse": reverse, "length": length, "num_consts": num_consts,
"num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
return trace.default_process_primitive(scan_p, tracers, params)
num_xs = len(jaxpr.in_avals) - num_carry - num_consts
num_ys = len(jaxpr.out_avals) - num_carry
unknowns = [t.pval[0] is not None for t in tracers]
const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])
# Fixpoint computation of which carry are unknown (not a constant): either
# unknown from init, or the carry out is unknown. Each iteration promotes
# at least one carry to unknown. We need at most len(carry) iterations,
# but we need one last iteration to prepare the jaxpr based on the final
# carry_uk.
carry_uk = init_uk
for _ in range(1 + len(carry_uk)):
unknowns = const_uk + carry_uk + xs_uk
jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys,
trace_type=trace.master.trace_type)
carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
else:
assert False, "Fixpoint not reached"
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
# The residuals are treated as extensive outputs of jaxpr_1 (and extensive
# inputs to jaxpr_2), but residuals that are loop-invariant can be hoisted.
# TODO(mattjj): hoist other loop-invariant values here too (instantiate=False)
invariant_pvals = [pe.PartialVal.known(core.unit if uk else t.pval[1])
for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])]
other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
in_pvals_1 = invariant_pvals + other_pvals
untyped_jaxpr_1, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
const_avals_1 = [raise_to_shaped(core.get_aval(c)) for c in consts_1]
in_avals_1 = [core.abstract_unit] * num_consts + jaxpr_1.in_avals[num_consts:]
out_avals_1 = [core.abstract_unit if pv is None else pv for pv, c in out_pvals_1]
# TODO(cjfj): Explain the need for the code below.
for var in untyped_jaxpr_1.invars[:num_consts]:
var.aval = core.abstract_unit
jaxpr_1_opt = pe.TypedJaxpr(pe.convert_constvars_jaxpr(untyped_jaxpr_1),
(), const_avals_1 + in_avals_1, out_avals_1)
num_consts_1 = num_consts + len(consts_1)
# any now-known residuals are intensive, so we want to revise jaxpr_2 to take
# those inputs as constants rather than as extensive inputs
_, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
intensive_residuals = [const for pv, const in res_pvals if pv is None]
move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
num_consts_2 = num_consts + len(intensive_residuals)
in_consts = (list(consts_1) + [core.unit] * num_consts +
[core.unit if uk else t.pval[1]
for uk, t in zip(unknowns[num_consts:], tracers[num_consts:])])
linear_1 = ([False] * len(consts_1) + [True] * num_consts +
[lin or uk for uk, lin
in zip(unknowns[num_consts:], linear[num_consts:])])
out_flat = scan_p.bind(
*in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1))
out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]
new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
for uk, t in zip(unknowns, tracers)]
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
out_avals = carry_avals + ys_avals
out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]
out_consts = out_carry + ys
int_res_tracers = _map(trace.new_instantiated_const, intensive_residuals)
ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
for pv, const in zip(out_pvs, out_consts)]
linear_2 = ([False] * len(int_res_tracers) +
[lin or not uk for uk, lin in zip(unknowns, linear)] +
[False] * len(ext_res_tracers))
eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers,
out_tracers, scan_p,
dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
num_consts=num_consts_2,
num_carry=num_carry, linear=tuple(linear_2)))
for t in out_tracers: t.recipe = eqn
return out_tracers
def _promote_aval_rank(sz, aval):
if aval is core.abstract_unit:
return core.abstract_unit
else:
return ShapedArray((sz,) + aval.shape, aval.dtype)
def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear):
# we've only implemented transposing scans with specific lin/nonlin patterns
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
num_ires = len(consts_lin) - sum(consts_lin)
num_eres = len(xs_lin) - sum(xs_lin)
if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
raise NotImplementedError
if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
raise NotImplementedError
if not all(init_lin):
pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963
consts, _, xs = split_list(args, [num_consts, num_carry])
ires, _ = split_list(consts, [num_ires])
_, eres = split_list(xs, [sum(xs_lin)])
assert not any(ad.is_undefined_primal(r) for r in ires)
assert not any(ad.is_undefined_primal(r) for r in eres)
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
ct_carry, ct_ys = split_list(cts, [num_carry])
ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])
# jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
# jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
jaxpr_trans = _transpose_scan_jaxpr(
num_ires, num_consts - num_ires, num_eres, jaxpr)
linear_trans = ([False] * num_ires +
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
[False] * num_eres)
outs = scan_p.bind(
*(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans))
ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
# -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
res1_avals, c_avals, a_avals, res2_avals = split_list(
jaxpr.in_avals, [num_res1, num_c, num_a])
num_b = len(jaxpr.out_avals)
b_avals = list(jaxpr.out_avals)
@lu.wrap_init
def transposed(*res1_cbar_bbar_res2):
res1, c_bar, b_bar, res2 = split_list(
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.literals, primals,
b_bar)
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
c_bar = _map(ad.instantiate_zeros_aval, c_avals,
_map(ad.add_tangents, c_bar, new_c_bar))
return c_bar + a_bar
return _make_typed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
def _make_typed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
out_avals, _ = unzip2(pvals_out)
return core.TypedJaxpr(jaxpr, consts, in_avals, _map(raise_to_shaped, out_avals))
def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
num_carry, linear):
num_ys = len(jaxpr.out_avals) - num_carry
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
orig_batched = [d is not batching.not_mapped for d in dims]
const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry])
# Fixpoint computation of which carry are batched: either
# batched from init, or the carry out is batched. Each iteration promotes
# at least one carry to batched. We need at most len(carry) iterations,
# but we need one last iteration to prepare the jaxpr based on the final
# carry_batched.
carry_batched = init_batched
for _ in range(1 + len(carry_batched)):
batched = const_batched + carry_batched + xs_batched
jaxpr_batched, batched_out = batching.batch_jaxpr(
jaxpr, size, batched, instantiate=carry_batched + [False] * num_ys)
carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
if carry_batched_out == carry_batched:
break
else:
carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
else:
assert False, "Fixpoint not reached"
consts, init, xs = split_list(args, [num_consts, num_carry])
consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(consts, consts_bdims)]
new_init = [batching.broadcast(x, size, 0) if now_batched and not was_batched
else batching.moveaxis(x, d, 0) if now_batched else x
for x, d, was_batched, now_batched in
zip(init, init_bdims, init_batched, carry_batched)]
new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1
else x for x, d in zip(xs, xs_bdims)]
new_args = new_consts + new_init + new_xs
outs = scan_p.bind(*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
num_consts=num_consts, num_carry=num_carry, linear=linear)
carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
return outs, carry_bdims + ys_bdims
def _scan_shape_rule(shapes, reverse, length, jaxpr,
num_consts, num_carry, linear):
const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals]
return init_shexprs + ys_shapes
def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, reverse, length,
jaxpr, num_consts, num_carry, linear):
out_shape = _scan_shape_rule(shape_exprs, reverse, length, jaxpr,
num_consts, num_carry, linear)
dynamic_length = length.evaluate(shape_envs.logical)
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
max_length, = {x.shape[0] for x in xs}
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
out_vals = scan_p.bind(
*itertools.chain([dynamic_length] + consts, [0], init, xs),
reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
num_consts=1 + num_consts, num_carry=1 + num_carry,
linear=tuple([False] + const_linear + [False] + init_linear + xs_linear))
return out_vals[1:], out_shape
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
fun = core.jaxpr_as_fun(jaxpr)
@lu.wrap_init
def masked(*args):
[dynamic_length], consts, [i], carry, xs = split_list(
args, [1, num_consts, 1, num_carry])
out = fun(*(consts + carry + xs))
new_carry, ys = split_list(out, [num_carry])
new_carry = [lax.select(i < dynamic_length, new_c, c)
for new_c, c in zip(new_carry, carry)]
return [i + 1] + new_carry + ys
aval = ShapedArray((), dtypes.int_)
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
def scan_bind(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
if not core.skip_checks:
assert len(linear) == len(args)
consts, init, xs = split_list(args, [num_consts, num_carry])
consts_avals, init_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
xs_avals = _map(partial(_promote_aval_rank, length), x_avals)
assert all(_map(typecheck, consts_avals, consts)), (consts, consts_avals)
assert all(_map(typecheck, init_avals, init))
# assert all(_map(typecheck, xs_avals, xs))
carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
assert all(_map(typematch, init_avals, carry_avals))
core.check_jaxpr(jaxpr.jaxpr)
return core.Primitive.bind(scan_p, *args, reverse=reverse, length=length,
jaxpr=jaxpr, num_consts=num_consts,
num_carry=num_carry, linear=linear)
scan_p = core.Primitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(_scan_impl)
# scan_p.def_impl(partial(xla.apply_primitive, scan_p)) # TODO(mattjj): re-enable
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = \
xla.lower_fun_initial_style(_scan_impl)
batching.primitive_batchers[scan_p] = _scan_batching_rule
masking.shape_parameterized_primitive_rules[scan_p] = _scan_masking_rule
def map(f, xs):
"""Map a function over leading array axes.
Like Python's builtin map, except inputs and outputs are in the form of
stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
need to apply a function element by element for reduced memory usage or
heterogeneous computation with other control flow primitives.
When ``xs`` is an array type, the semantics of ``map`` are given by this
Python implementation::
def map(f, xs):
return np.stack([f(x) for x in xs])
Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
the same advantages over a Python loop apply: ``xs`` may be an arbitrary
nested pytree type, and the mapped computation is compiled only once.
Args:
f: a Python function to apply element-wise over the first axis or axes of
``xs``.
xs: values over which to map along the leading axis.
Returns:
Mapped values.
"""
g = lambda _, x: ((), f(x))
_, ys = scan(g, (), xs)
return ys
def _concat_masking_rule(padded_vals, logical_shapes, dimension):
result = lax.concatenate(padded_vals, dimension) # fragmented
offset = 0
for padded_val, logical_shape in zip(padded_vals, logical_shapes):
result = _memcpy(dimension, logical_shape[dimension], padded_val,
result, offset)
offset = offset + logical_shape[dimension]
return result
def _memcpy(axis, num, src, dst, offset):
def body(i, dst):
update = lax.dynamic_index_in_dim(src, i, axis)
return lax.dynamic_update_index_in_dim(dst, update, i + offset, axis)
return fori_loop(0, num, body, dst)
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule
def _check_tree(func_name, expected_name, actual_tree, expected_tree):
if actual_tree != expected_tree:
raise TypeError(
"{}() output pytree structure must match {}, got {} and {}."
.format(func_name, expected_name, actual_tree, expected_tree))
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
Corresponding `tree` and `avals` must match in the sense that the number of
leaves in `tree` must be equal to the length of `avals`. `what` will be
prepended to details of the mismatch in TypeError.
"""
if tree1 != tree2:
msg = ("{} must have same type structure, got {} and {}.")
raise TypeError(msg.format(what, tree1, tree2))
if not all(safe_map(typematch, avals1, avals2)):
msg = ("{} must have identical types, "
"got\n{}\nand\n{}.")
raise TypeError(msg.format(what, tree_unflatten(tree1, avals1),
tree_unflatten(tree2, avals2)))
def _stop_gradient_fun(f):
"""Create a version of f() that stops all gradients."""
def wrapper(*args, **kwargs):
args_flat, in_args_tree = tree_flatten((args, kwargs))
args_avals = tuple(_map(_abstractify, args_flat))
g = lambda a, b: f(*a, **b)
jaxpr, consts, out_tree = _initial_style_jaxpr(g, in_args_tree, args_avals)
out = core.jaxpr_as_fun(jaxpr)(*lax.stop_gradient(consts + tuple(args_flat)))
return tree_unflatten(out_tree, out)
return wrapper
_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')
def _split_root_args(args, const_lengths):
params_list = split_list(args, list(const_lengths))
return _RootTuple(*params_list[:-1]), params_list[-1]
def custom_root(f, initial_guess, solve, tangent_solve):
"""Differentiably solve for a roots of a function.
This is a low-level routine, mostly intended for internal use in JAX.
Gradients of custom_root() are defined with respect to closed-over variables
from the provided function ``f`` via the implicit function theorem:
https://en.wikipedia.org/wiki/Implicit_function_theorem
Args:
f: function for which to find a root. Should accept a single argument,
return a tree of arrays with the same structure as its input.
initial_guess: initial guess for a zero of f.
solve: function to solve for the roots of f. Should take two positional
arguments, f and initial_guess, and return a solution with the same
structure as initial_guess such that func(solution) = 0. In other words,
the following is assumed to be true (but not checked)::
solution = solve(f, initial_guess)
error = f(solution)
assert all(error == 0)
tangent_solve: function to solve the tangent system. Should take two
positional arguments, a linear function ``g`` (the function ``f``
linearized at its root) and a tree of array(s) ``y`` with the same
structure as initial_guess, and return a solution ``x`` such that
``g(x)=y``:
- For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
- For vector ``y``, you could use a linear solve with the Jacobian, if
dimensionality of ``y`` is not too large:
``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.
Returns:
The result of calling solve(f, initial_guess) with gradients defined via
implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
"""
guess_flat, in_args_tree = tree_flatten((initial_guess,))
guess_avals = tuple(_map(_abstractify, guess_flat))
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
f, in_args_tree, guess_avals)
in_tree, = treedef_children(in_args_tree)
_check_tree("f", "initial_guess", out_tree, in_tree)
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
partial(solve, _stop_gradient_fun(f)), in_args_tree, guess_avals)
_check_tree("solve", "initial_guess", solution_tree, in_tree)
def linearize_and_solve(x, b):
unchecked_zeros, f_jvp = jax.linearize(f, x)
return tangent_solve(f_jvp, b)
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
_check_tree("tangent_solve", "x", out_tree, in_tree)
all_consts = [f_consts, solve_consts, l_and_s_consts]
const_lengths = _RootTuple(*_map(len, all_consts))
jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)
out_flat = _custom_root(
const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
return tree_unflatten(out_tree, out_flat)
@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
params, initial_guess = _split_root_args(args, const_lengths)
solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
return solution
@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
params, _ = _split_root_args(primals, const_lengths)
solution = _custom_root(const_lengths, jaxprs, *primals)
params_dot, _ = _split_root_args(tangents, const_lengths)
# F(m, u) = 0 # system of equations in u, parameterized by m
# # solution is u*(m) defined in a neighborhood
# F(m, u*(m)) = 0 # satisfied in a neighborhood
#
# ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0 # implied by line above
# ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m)) # rearrange
#
# ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]] # jvp
f = core.jaxpr_as_fun(jaxprs.f)
linearize_and_solve = partial(
core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
f_at_solution = lambda *params: f(*itertools.chain(params, solution))
_, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
params.f, params_dot.f)
solution_dot = _map(
operator.neg, linearize_and_solve(*itertools.chain(solution, rhs)))
return solution, solution_dot
class _LinearSolveTuple(collections.namedtuple(
'_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):
def transpose(self):
return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)
def _split_linear_solve_args(args, const_lengths):
params_list = split_list(args, list(const_lengths))
return _LinearSolveTuple(*params_list[:-1]), params_list[-1]
def _transpose_function(linear_fun, primals):
"""Transpose a linear function."""
# TODO(shoyer): can we use something more direct than the vjp machinery?
# It's particularly awkward that we need the second argument to give
# particular values of the primals, which are entirely arbitrary.
_, vjp_fun = jax.vjp(linear_fun, primals)
def transposed_fun(x):
(y,) = vjp_fun(x)
return y
return transposed_fun
def _flatten(args):
return [x for arg in args for x in arg]
def _check_shapes(func_name, expected_name, actual, expected, tree):
actual_shapes = _map(onp.shape, actual)
expected_shapes = _map(onp.shape, expected)
if actual_shapes != expected_shapes:
actual_shape_tree = tree_unflatten(tree, actual_shapes)
act_shape_tree = tree_unflatten(tree, actual_shapes)
raise ValueError('{}() output shapes must match {}, got {} and {}'
.format(func_name, expected_name,
tree_unflatten(tree, actual_shapes),
tree_unflatten(tree, expected_shapes)))
def custom_linear_solve(
matvec, b, solve, transpose_solve=None, symmetric=False):
"""Perform a matrix-free linear solve with implicitly defined gradients.
This function allows for overriding or defining gradients for a linear
solve directly via implicit differentiation at the solution, rather than by
differentiating *through* the solve operation. This can sometimes be much faster
or more numerically stable, or differentiating through the solve operation
may not even be implemented (e.g., if ``solve`` uses ``lax.while_loop``).
Required invariant::
x = solve(matvec, b) # solve the linear equation
assert matvec(x) == b # not checked
Args:
matvec: linear function to invert. Must be differentiable.
b: constant right handle side of the equation. May be any nested structure
of arrays.
solve: higher level function that solves for solution to the linear
equation, i.e., ``solve(matvec, x)) == x`` for all ``x`` of the same form
as ``b``. This function need not be differentiable.
transpose_solve: higher level function for solving the transpose linear
equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is
the transpose of the linear map ``matvec`` (computed automatically with
autodiff). Required for backwards mode automatic differentiation, unless
``symmetric=True``, in which case ``solve`` provides the default value.
symmetric: bool indicating if it is safe to assume the linear map
corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
Returns:
Result of ``solve(matvec, b)``, with gradients defined assuming that the
solution ``x`` satisfies the linear equation ``matvec(x) == b``.
"""
if transpose_solve is None and symmetric:
transpose_solve = solve
b_flat, in_args_tree = tree_flatten((b,))
b_avals = tuple(_map(_abstractify, b_flat))
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
matvec, in_args_tree, b_avals)
tree, = treedef_children(in_args_tree)
_check_tree("matvec", "b", out_tree, tree)
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
partial(solve, matvec), in_args_tree, b_avals)
_check_tree("solve", "b", out_tree, tree)
if transpose_solve is None:
vecmat_jaxpr = tr_solve_jaxpr = None
vecmat_consts = tr_solve_consts = []
else:
if symmetric:
vecmat = matvec
vecmat_jaxpr = matvec_jaxpr
vecmat_consts = matvec_consts
else:
vecmat = _transpose_function(matvec, b)
vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
vecmat, in_args_tree, b_avals)
assert out_tree == tree
tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
partial(transpose_solve, vecmat), in_args_tree, b_avals)
_check_tree("transpose_solve", "b", out_tree, tree)
all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
const_lengths = _LinearSolveTuple(*_map(len, all_consts))
jaxprs = _LinearSolveTuple(
matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
out_flat = linear_solve_p.bind(
*(_flatten(all_consts) + b_flat),
const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
return tree_unflatten(tree, out_flat)
def _linear_solve_abstract_eval(*args, **kwargs):
return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])
def _custom_linear_solve_impl(*args, **kwargs):
const_lengths, jaxprs, tree = split_dict(
kwargs, ['const_lengths', 'jaxprs', 'tree'])
params, b = _split_linear_solve_args(args, const_lengths)
x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
_check_shapes('solve', 'b', x, b, tree)
return x
def _tangent_linear_map(func, params, params_dot, *x):
"""Compute the tangent of a linear map.
Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
this function computes ``∂A @ x``.
"""
assert any(p is not ad_util.zero for p in params_dot)
zeros = [ad_util.zero] * len(x)
_, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
params + list(x), params_dot + zeros)
return out_tangent
def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
# A x - b = 0
# ∂A x + A ∂x - ∂b = 0
# ∂x = A^{-1} (∂b - ∂A x)
kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
x = linear_solve_p.bind(*primals, **kwargs)
params, _ = _split_linear_solve_args(primals, const_lengths)
params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)
if all(p is ad_util.zero for p in params_dot.matvec):
# no need to evaluate matvec_tangents
rhs = b_dot
else:
matvec_tangents = _tangent_linear_map(
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x)
_check_shapes("matvec", "b", matvec_tangents, x, tree)
rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
return x, x_dot
def _linear_solve_transpose_rule(cotangent, *primals, **kwargs):
const_lengths, jaxprs, tree = split_dict(
kwargs, ['const_lengths', 'jaxprs', 'tree'])
if jaxprs.transpose_solve is None:
raise TypeError('transpose_solve required for backwards mode automatic '
'differentiation of custom_linear_solve')
params, b = _split_linear_solve_args(primals, const_lengths)
assert all(ad.is_undefined_primal(x) for x in b)
cotangent_b = linear_solve_p.bind(
*(_flatten(params.transpose()) + cotangent),
const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose(),
tree=tree)
return [None] * sum(const_lengths) + cotangent_b
def _linear_solve_batching_rule(args, dims, **kwargs):
const_lengths, jaxprs, tree = split_dict(kwargs,
["const_lengths", "jaxprs", "tree"])
orig_bat = [d is not batching.not_mapped for d in dims]
size, = {
a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
}
params, b = _split_linear_solve_args(args, const_lengths)
params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)
(matvec, vecmat, solve, solve_t) = jaxprs
(matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat
# Fixpoint computation of which parts of x and b are batched; we need to
# ensure this is consistent between all four jaxprs
b_bat = orig_b_bat
x_bat = [False] * len(solve.out_avals)
for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
# Apply vecmat and solve -> new batched parts of x
solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
solve, size, solve_bat + b_bat, instantiate=x_bat)
if vecmat is None:
vecmat_jaxpr_batched = None
x_bat_out = solve_x_bat
else:
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
vecmat, size, vecmat_bat + b_bat, instantiate=x_bat)
x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
# Apply matvec and solve_t -> new batched parts of b
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
matvec, size, matvec_bat + x_bat_out, instantiate=b_bat)
if solve_t is None:
solve_t_jaxpr_batched = None
b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
else:
solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat)
b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
orig_b_bat)
if x_bat_out == x_bat and b_bat_out == b_bat:
break
else:
x_bat = x_bat_out
b_bat = b_bat_out
else:
assert False, "Fixedpoint not reached"
batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
solve_jaxpr_batched, solve_t_jaxpr_batched)
# Move batched axes to the front
new_params = [
batching.moveaxis(x, d, 0)
if d is not batching.not_mapped and d != 0 else x
for x, d in zip(_flatten(params), _flatten(params_dims))
]
# Broadcast out b if necessary
new_b = [
batching.broadcast(x, size, 0) if now_bat and not was_bat else
batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
]
outs = linear_solve_p.bind(
*(new_params + new_b),
const_lengths=const_lengths,
jaxprs=batched_jaxprs,
tree=tree)
out_dims = [0 if batched else batching.not_mapped for batched in b_bat]
return outs, out_dims
linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.initial_style_translations[linear_solve_p] = \
xla.lower_fun_initial_style(_custom_linear_solve_impl)
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.primitive_batchers[linear_solve_p] = _linear_solve_batching_rule