# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Control flow primitives.
"""


import collections
import functools
import itertools
import operator
import threading
from typing import Callable, Sequence

import numpy as onp

import jax
from jax import core
from jax import dtypes
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.core import get_aval
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
                      split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           treedef_children, treedef_tuple, tree_leaves,
                           tree_map, tree_multimap)
from jax import ad_util

xops = xla_client.ops

_map = safe_map
zip = safe_zip
_reduce = functools.reduce
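

# Helper used by the control-flow functions below: traces a Python callable
# into a TypedJaxpr (with its constants hoisted into extra leading inputs) and
# returns the jaxpr, the constant values, and the output pytree structure.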
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
  in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  with core.initial_style_staging():
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      wrapped_fun, in_pvals, instantiate=True, stage_out=False)
  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
  typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr),
                                (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts, out_tree()

def _abstractify(x):
  return raise_to_shaped(core.get_aval(x))

def typecheck(aval, x):
  aval = raise_to_shaped(aval).strip_weak_type()
  try:
    return aval == core.lattice_join(aval, core.get_aval(x)).strip_weak_type()
  except TypeError:
    return False

def typematch(aval1, aval2):
  return (raise_to_shaped(aval1).strip_weak_type() ==
          raise_to_shaped(aval2).strip_weak_type())

def _disable_jit_impl(prim, interp, *args, **kwargs):
  if jax.api._jit_is_disabled():
    return interp(*args, **kwargs)
  else:
    return xla.apply_primitive(prim, *args, **kwargs)


### fori_loop and while_loop

def _fori_cond_fun(loop_carry):
  i, upper, _ = loop_carry
  return lax.lt(i, upper)

@cache()
def _fori_body_fun(body_fun):
  def while_body_fun(loop_carry):
    i, upper, x = loop_carry
    return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
  return while_body_fun

@cache()
def _fori_scan_body_fun(body_fun):
  def scanned_fun(loop_carry, _):
    i, upper, x = loop_carry
    return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None
  return scanned_fun

def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to ``while_loop``.

  The type signature in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  Unlike that Python version, ``fori_loop`` is implemented in terms of a call to
  ``while_loop``. See the docstring for ``while_loop`` for more information.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).
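
  For example (an illustrative use, summing the integers 0 through 9; any
  ``body_fun`` of type ``(int, a) -> a`` works the same way)::

    total = fori_loop(0, 10, lambda i, acc: acc + i, 0)  # total == 45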

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.
  """
  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
  upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
  if lower_dtype != upper_dtype:
    msg = ("lower and upper arguments to fori_loop must have equal types, "
           "got {} and {}")
    raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))

  # If we can specialize on the trip count, call scan instead of a while_loop
  # to enable efficient reverse-mode differentiation.
  try:
    lower_ = int(lower)
    upper_ = int(upper)
  except TypeError:
    use_scan = False
  else:
    use_scan = False  # TODO(mattjj): re-enable this

  if use_scan:
    (_, _, result), _ = scan(_fori_scan_body_fun(body_fun),
                             (lower, upper, init_val), None,
                             length=upper_ - lower_)
  else:
    _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
                              (lower, upper, init_val))
  return result


def while_loop(cond_fun, body_fun, init_val):
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.
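
  For example (illustrative; any ``cond_fun``/``body_fun`` pair matching the
  types above would do)::

    x = while_loop(lambda x: x < 10, lambda x: x + 1, 0)  # x == 10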

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.
  """
  if jax.api._jit_is_disabled():
    val = init_val
    while cond_fun(val):
      val = body_fun(val)
    return val

  init_vals, in_tree = tree_flatten((init_val,))
  init_avals = tuple(_map(_abstractify, init_vals))
  cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
  body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
  if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
    msg = "cond_fun must return a boolean scalar, but got pytree {}."
    raise TypeError(msg.format(cond_tree))
  if cond_jaxpr.out_avals[0].strip_weak_type() != ShapedArray((), onp.bool_):
    msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
    raise TypeError(msg.format(cond_jaxpr.out_avals))

  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        # Extract the subtree and avals for the first element of the return tuple
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  outs = while_p.bind(*itertools.chain(cond_consts, body_consts, init_vals),
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)

def _while_loop_abstract_eval(*args, **kwargs):
  return _map(raise_to_shaped, kwargs["body_jaxpr"].out_avals)

def _while_loop_translation_rule(c, axis_env, name_stack, avals, backend, *args,
                                 cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
  cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
  batched = bool(cond_jaxpr.out_avals[0].shape)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.

  init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)

  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = xb.parameter(cond_c, 0, c.GetShape(init_carry))
  cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
  x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
  pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, cond_c),
                                 cond_jaxpr.literals),
                            extend_name_stack(name_stack, 'cond'), *(x + z))
  if batched:
    scalar = ShapedArray((), onp.bool_)
    or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
    pred = xops.Reduce(cond_c, [pred], [xb.constant(cond_c, onp.array(False))], or_,
                       list(range(cond_jaxpr.out_avals[0].ndim)))

  body_c = xb.make_computation_builder("body_computation")
  body_carry = xb.parameter(body_c, 0, c.GetShape(init_carry))
  body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
  x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
  new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
                            _map(partial(xb.constant, body_c), body_jaxpr.literals),
                            extend_name_stack(name_stack, 'body'), *(y + z))
  if batched:
    body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
                                   _map(partial(xb.constant, body_c), cond_jaxpr.literals),
                                   extend_name_stack(name_stack, 'body_pred'), *(x + z))
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
    assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z)  # no broadcast
  new_carry = xops.Tuple(body_c, list(itertools.chain(x, y, new_z)))

  ans = xops.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
  ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
  _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return xops.Tuple(c, z)
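
# Helper for the batched while_loop lowering above: broadcasts a (possibly
# lower-rank) predicate against the carry element shapes so that we can select,
# elementwise, between the newly computed carry and the previous one.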
def _pred_bcast_select(c, pred, x, y):
  pred_shape = c.GetShape(pred).dimensions()
  x_shape = c.GetShape(x).dimensions()
  y_shape = c.GetShape(y).dimensions()
  assert x_shape == y_shape
  assert pred_shape == x_shape[:len(pred_shape)] == y_shape[:len(pred_shape)]
  bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
  return xops.Select(bcast_pred, x, y)

def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])

  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_bat.
  carry_bat = init_bat
  for _ in range(1 + len(carry_bat)):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = _map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims

def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  nonzeros = [t is not ad_util.zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])

  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  nonzeros = cconst_nz + body_nonzeros
  tangents = [ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
              for x, t, nz in zip(primals, tangents, nonzeros)]

  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts

  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
      [num_carry], [len(init_dot)])
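
  # The JVP body carries (primal, tangent) pairs, so augment the cond jaxpr
  # with extra (ignored) inputs standing for the tangent part of the carry,
  # making its input signature match the rearranged JVP body's.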
  newvar = core.gensym('')
  invars_aug = (
      cond_jaxpr.jaxpr.invars + [newvar(get_aval(x)) for x in init_dot])
  cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
                                    invars_aug,
                                    cond_jaxpr.jaxpr.outvars,
                                    cond_jaxpr.jaxpr.eqns)
  in_avals_aug = (cond_jaxpr.in_avals[:cond_nconsts] +
                  body_jvp_rearranged.in_avals[body_nconsts + len(bconst_dot):])
  cond_jaxpr_augmented = core.TypedJaxpr(cond_jaxpr_augmented,
                                         cond_jaxpr.literals,
                                         in_avals_aug,
                                         cond_jaxpr.out_avals)

  out = while_p.bind(
      *(cconst + bconst + bconst_dot + init + init_dot),
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + len(bconst_dot),
      body_jaxpr=body_jvp_rearranged)

  out_carry, out_carry_dot = split_list(out, [num_carry])
  out_tangents_iter = iter(out_carry_dot)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.zero
                  for nz in nonzeros_out]
  return out_carry, out_tangents

def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.TypedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.TypedJaxpr) -> Sequence[pe.Tracer]:
  """An implementation of partial evaluation for while.

  As long as some carry elements (and hence the corresponding outputs) are
  known and the output of `cond_jaxpr` is known, we use a portion of the loop
  body to compute the known outputs of the `while_loop`. For the unknown
  outputs we generate a jaxpr that runs the whole while, including recomputing
  the known parts.

  This means that we don't actually save any computation by partial
  evaluation if there are unknown outputs.

  What this achieves is that we can give a proper error for reverse
  differentiation of `while`, because in that use of partial evaluation the
  primal inputs are considered "known", and only the tangent computation is
  unknown (see issue #2129).
  """
  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
  # Fixpoint computation of unknown carry. Each iteration promotes
  # at least one carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk = pe.partial_eval_jaxpr(
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk,
        trace_type=trace.master.trace_type)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk = pe.partial_eval_jaxpr(
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False,
      trace_type=trace.master.trace_type)

  if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while. Prepare the inputs, as constants (if known), or
  # as core.unit.
  in_consts = [core.unit if uk else t.pval.get_known()
               for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
                                tracers)]
  # There should be no residuals for the cond_jaxpr_known
  assert 1 == len(cond_jaxpr_known.out_avals)
  # We ignore the residuals from the body_jaxpr_known, so the type of inputs matches
  # the type of outputs; residuals are at the end
  if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
    # TODO(necula): this is not quite enough; we should drop the residual computations also
    body_jaxpr_known.out_avals = body_jaxpr_known.out_avals[:len(body_jaxpr.out_avals)]
    body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
  out_known = while_p.bind(
      *in_consts,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_known,
      body_nconsts=body_nconsts,
      body_jaxpr=body_jaxpr_known)

  # Run the whole while_loop to get all the outputs, then merge with known ones
  out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
  out_tracers: Sequence[pe.Tracer] = [
      out_unknown if uk
      else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
      for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]

  return out_tracers

def _while_transpose_error(*_, **kwargs):
  raise ValueError("Reverse-mode differentiation does not work for "
                   "lax.while_loop or lax.fori_loop. "
                   "Try using lax.scan instead.")

while_p = lax.Primitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.initial_style_translations[while_p] = _while_loop_translation_rule
ad.primitive_transposes[while_p] = _while_transpose_error
batching.primitive_batchers[while_p] = _while_loop_batching_rule


### cond

def cond(pred, true_operand, true_fun, false_operand, false_fun):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  ``pred`` must be a scalar type; collection types (list, tuple) are not
  supported.
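
  For example (an illustrative call; any pair of unary functions would do)::

    x = 2.0
    out = cond(x > 0, x, lambda x: x + 1., x, lambda x: x - 1.)  # out == 3.0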

  """
  if len(onp.shape(pred)) != 0:
    raise TypeError("Pred must be a scalar, got {} of shape {}.".format(pred, onp.shape(pred)))

  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    msg = ("Pred type must be either boolean or number, got {}.")
    raise TypeError(msg.format(pred)) from err

  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      msg = ("Pred type must be either boolean or number, got {}.")
      raise TypeError(msg.format(pred_dtype))

  if jax.api._jit_is_disabled():
    if pred:
      return true_fun(true_operand)
    else:
      return false_fun(false_operand)

  true_ops, true_tree = tree_flatten((true_operand,))
  true_avals = tuple(_map(_abstractify, true_ops))
  true_jaxpr, true_consts, true_out_tree = _initial_style_jaxpr(true_fun, true_tree, true_avals)
  false_ops, false_tree = tree_flatten((false_operand,))
  false_avals = tuple(_map(_abstractify, false_ops))
  false_jaxpr, false_consts, false_out_tree = _initial_style_jaxpr(false_fun, false_tree, false_avals)
  _check_tree_and_avals("true_fun and false_fun output",
                        true_out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)
  linear = (False,) * (len(true_consts) + len(true_ops) + len(false_consts) +
                       len(false_ops))
  out = cond_p.bind(
      *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
      true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr, linear=linear)
  return tree_unflatten(true_out_tree, out)

def _cond_abstract_eval(*args, **kwargs):
  return _map(raise_to_shaped, kwargs["true_jaxpr"].out_avals)

def _cond_translation_rule(c, axis_env, name_stack, avals, backend,
                           pred, *args, true_jaxpr, false_jaxpr, linear):
  del linear  # Unused.
  true_ops, false_ops = split_list(args, [len(true_jaxpr.in_avals)])

  def make_computation(name, jaxpr, op_shape):
    c = xb.make_computation_builder(name + '_comp')
    op = xb.parameter(c, 0, op_shape)
    ops = [xops.GetTupleElement(op, i) for i in range(len(jaxpr.in_avals))]
    outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
                             _map(partial(xb.constant, c), jaxpr.literals),
                             extend_name_stack(name_stack, name + '_fun'), *ops)
    return c.Build(xops.Tuple(c, outs))

  true_op = xops.Tuple(c, true_ops)
  true_c = make_computation('true', true_jaxpr, c.GetShape(true_op))

  false_op = xops.Tuple(c, false_ops)
  false_c = make_computation('false', false_jaxpr, c.GetShape(false_op))

  return xops.Conditional(pred, true_op, true_c, false_op, false_c)
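
# Helper for the batched cond rule below: broadcasts the batched predicate
# against each per-branch output so that we can select, elementwise, between
# the true- and false-branch results.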
def _cond_pred_bcast_select(pred, x, y):
  if core.get_aval(x) is core.get_aval(y) is core.abstract_unit:
    return x
  else:
    bcast_pred = lax.broadcast_in_dim(pred, onp.shape(x), list(range(onp.ndim(pred))))
    return lax.select(bcast_pred, x, y)

def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, linear):
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  orig_bat = [d is not batching.not_mapped for d in dims]
  del dims
  (pred,), true_ops, false_ops = split_list(args, [1, len(true_jaxpr.in_avals)])
  (pred_bat,), t_bat, f_bat = split_list(orig_bat, [1, len(true_jaxpr.in_avals)])

  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, f_bat, out_bat)

  if pred_bat:
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*true_ops)
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*false_ops)
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(
        *itertools.chain([pred], true_ops, false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched, linear=linear)
    return out, out_dims

def _cond_jvp(primals, tangents, true_jaxpr, false_jaxpr, linear):
  nonzeros = [t is not ad_util.zero for t in tangents]

  (pred_nz,), t_nz, f_nz = split_list(nonzeros, [1, len(true_jaxpr.in_avals)])
  assert pred_nz is False

  _, true_out_nz = ad.jvp_jaxpr(true_jaxpr, t_nz, instantiate=False)
  _, false_out_nz = ad.jvp_jaxpr(false_jaxpr, f_nz, instantiate=False)
  out_nz = [a or b for a, b in zip(true_out_nz, false_out_nz)]

  true_jvp, _ = ad.jvp_jaxpr(true_jaxpr, t_nz, instantiate=out_nz)
  false_jvp, _ = ad.jvp_jaxpr(false_jaxpr, f_nz, instantiate=out_nz)

  (pred,), tops, fops = split_list(primals, [1, len(true_jaxpr.in_avals)])
  _, tops_dot, fops_dot = split_list(tangents, [1, len(true_jaxpr.in_avals)])

  tops_dot = _prune_zeros(tops_dot)
  fops_dot = _prune_zeros(fops_dot)

  tops_lin, fops_lin = _map(tuple, split_list(linear, [len(tops)]))
  linear_jvp = (tops_lin + (True,) * len(tops_dot) +
                fops_lin + (True,) * len(fops_dot))
  out = cond_p.bind(
      *itertools.chain([pred], tops, tops_dot, fops, fops_dot),
      true_jaxpr=true_jvp, false_jaxpr=false_jvp, linear=linear_jvp)
  out_primals, out_tangents = split_list(out, [len(out_nz)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [
      next(out_tangents_iter) if nz else ad_util.zero for nz in out_nz]
  return out_primals, out_tangents
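
# Partial evaluation of cond: the known parts of both branches are evaluated by
# a first cond whose outputs also include each branch's residuals (padded with
# zeros for the branch not taken), and the unknown parts are staged out as a
# second cond equation that consumes those residuals.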
def _cond_partial_eval(trace, *tracers, true_jaxpr, false_jaxpr, linear):
  unknowns = [t.pval[0] is not None for t in tracers]

  (pred_uk,), t_uk, f_uk = split_list(unknowns, [1, len(true_jaxpr.in_avals)])

  if pred_uk:
    # When the predicate is unknown, we stage out the whole cond.
    params = dict(true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  _, _, t_out_uks = pe.partial_eval_jaxpr(true_jaxpr, t_uk, instantiate=False,
                                          trace_type=trace.master.trace_type)
  _, _, f_out_uks = pe.partial_eval_jaxpr(false_jaxpr, f_uk, instantiate=False,
                                          trace_type=trace.master.trace_type)
  out_uks = [a or b for a, b in zip(t_out_uks, f_out_uks)]

  true_jaxpr_1, true_jaxpr_2, _ = pe.partial_eval_jaxpr(true_jaxpr, t_uk,
                                                        instantiate=out_uks,
                                                        trace_type=trace.master.trace_type)
  false_jaxpr_1, false_jaxpr_2, _ = pe.partial_eval_jaxpr(false_jaxpr, f_uk,
                                                          instantiate=out_uks,
                                                          trace_type=trace.master.trace_type)

  num_t_res = len(true_jaxpr_1.out_avals) - len(out_uks)
  num_f_res = len(false_jaxpr_1.out_avals) - len(out_uks)

  # Move the residuals to front
  move = [False] * len(true_jaxpr.in_avals) + [True] * num_t_res
  true_jaxpr_2 = pe.move_binders_to_front(true_jaxpr_2, move)
  move = [False] * len(false_jaxpr.in_avals) + [True] * num_f_res
  false_jaxpr_2 = pe.move_binders_to_front(false_jaxpr_2, move)

  # TODO(frostig,mattjj): pe.partial_eval_jaxpr should raise to shaped avals
  t_res_avals = _map(raise_to_shaped, true_jaxpr_2.in_avals[:num_t_res])
  f_res_avals = _map(raise_to_shaped, false_jaxpr_2.in_avals[:num_f_res])

  assert len(true_jaxpr_2.out_avals) == len(false_jaxpr_2.out_avals)
  num_outs = len(true_jaxpr_2.out_avals)

  true_jaxpr_1 = _join_cond_outputs(
      true_jaxpr_1, num_outs, f_res_avals, zeros_on_left=False)
  false_jaxpr_1 = _join_cond_outputs(
      false_jaxpr_1, num_outs, t_res_avals, zeros_on_left=True)

  # TODO(frostig,mattjj): reinstate this assertion once pe.partial_eval_jaxpr
  # raises to shaped avals
  # assert true_jaxpr_1.out_avals == false_jaxpr_1.out_avals
  num_res = num_t_res + num_f_res

  _, in_consts = unzip2([t.pval for t in tracers])
  out_consts_res = cond_p.bind(
      *in_consts, true_jaxpr=true_jaxpr_1, false_jaxpr=false_jaxpr_1,
      linear=linear)
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  # TODO(frostig,mattjj): remove raised_to_shaped of avals once
  # pe.partial_eval_jaxpr handles it
  out_avals = _map(raise_to_shaped, true_jaxpr_2.out_avals)
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uks)]

  pred_tracer = trace.instantiate_const(tracers[0])

  ops_tracers = [trace.instantiate_const(t) if uk
                 else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns[1:], tracers[1:])]
  true_ops_tracers, false_ops_tracers = split_list(
      ops_tracers, [len(true_jaxpr.in_avals)])

  res_tracers = _map(trace.new_instantiated_const, res)
  true_res_tracers, false_res_tracers = split_list(res_tracers, [num_t_res])

  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]

  tops_lin, fops_lin = _map(tuple, split_list(linear, [len(true_jaxpr.in_avals)]))
  linear_2 = ((False,) * num_t_res + tops_lin + (False,) * num_f_res + fops_lin)
  params = dict(true_jaxpr=true_jaxpr_2, false_jaxpr=false_jaxpr_2,
                linear=linear_2)
  eqn = pe.new_eqn_recipe([pred_tracer] +
                          true_res_tracers + true_ops_tracers +
                          false_res_tracers + false_ops_tracers,
                          out_tracers,
                          cond_p, params)
  for t in out_tracers: t.recipe = eqn
  return out_tracers
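
# Wraps `jaxpr` so that zeros of the given avals are inserted after its first
# `num_prefix` outputs, either before (zeros_on_left=True) or after the
# remaining outputs. This lets the two branches of a staged-out cond produce
# identically typed outputs even though each branch has its own residuals.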
def _join_cond_outputs(jaxpr, num_prefix, zeros_avals, zeros_on_left):
  @lu.wrap_init
  def f_aug(*args):
    prefix_and_rest = core.jaxpr_as_fun(jaxpr)(*args)
    prefix, rest = split_list(prefix_and_rest, [num_prefix])
    zeros = [ad_util.zeros_like_aval(a) for a in zeros_avals]
    if zeros_on_left:
      return prefix + zeros + rest
    else:
      return prefix + rest + zeros

  return _make_typed_jaxpr(f_aug, jaxpr.in_avals)
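
# Transposes one cond branch: residual inputs stay primal, the remaining
# (linear) inputs become undefined primals, and the backward pass turns output
# cotangents into input cotangents.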
def _transpose_cond_jaxpr(jaxpr, num_res):
  num_non_res = len(jaxpr.in_avals) - num_res
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
  primal_avals = _map(raise_to_shaped, primal_avals)

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
    cts_in = ad.backward_pass(
        jaxpr.jaxpr, jaxpr.literals, primals, cts_out)
    _, cts_in = split_list(cts_in, [num_res])
    return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)

  return _make_typed_jaxpr(transposed, res_avals + jaxpr.out_avals)

def _cond_transpose(cts, *args, true_jaxpr, false_jaxpr, linear):
  (pred,), tops, fops = split_list(args, [1, len(true_jaxpr.in_avals)])
  tops_lin, fops_lin = split_list(linear, [len(true_jaxpr.in_avals)])
  in_avals = _map(raise_to_shaped, true_jaxpr.in_avals + false_jaxpr.in_avals)

  num_t_res = len(tops) - sum(tops_lin)
  num_f_res = len(fops) - sum(fops_lin)

  t_jaxpr_trans = _transpose_cond_jaxpr(true_jaxpr, num_t_res)
  f_jaxpr_trans = _transpose_cond_jaxpr(false_jaxpr, num_f_res)
  lin_in_avals = _map(raise_to_shaped, [a for a, l in zip(in_avals, linear) if l])
  assert t_jaxpr_trans.out_avals + f_jaxpr_trans.out_avals == lin_in_avals

  t_jaxpr_trans_ = _join_cond_outputs(
      t_jaxpr_trans, 0, f_jaxpr_trans.out_avals, zeros_on_left=False)
  f_jaxpr_trans_ = _join_cond_outputs(
      f_jaxpr_trans, 0, t_jaxpr_trans.out_avals, zeros_on_left=True)
  assert t_jaxpr_trans_.out_avals == f_jaxpr_trans_.out_avals == lin_in_avals

  t_res, _ = split_list(tops, [num_t_res])
  f_res, _ = split_list(fops, [num_f_res])

  linear_trans = ((False,) * num_t_res + (True,) * len(cts) +
                  (False,) * num_f_res + (True,) * len(cts))

  cts = _map(ad.instantiate_zeros_aval, true_jaxpr.out_avals, cts)

  out = cond_p.bind(
      pred, *itertools.chain(t_res, cts, f_res, cts),
      true_jaxpr=t_jaxpr_trans_, false_jaxpr=f_jaxpr_trans_,
      linear=linear_trans)
  assert all(_map(typecheck, lin_in_avals, out))

  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out

def cond_bind(*args, true_jaxpr, false_jaxpr, linear):
  if not core.skip_checks:
    assert len(linear) + 1 == len(args)
    assert len(args) == 1 + len(true_jaxpr.in_avals) + len(false_jaxpr.in_avals)
    (pred,), tops, fops = split_list(args, [1, len(true_jaxpr.in_avals)])
    assert all(_map(typecheck, true_jaxpr.in_avals, tops))
    assert all(_map(typecheck, false_jaxpr.in_avals, fops))
    core.check_jaxpr(true_jaxpr.jaxpr)
    core.check_jaxpr(false_jaxpr.jaxpr)
  return core.Primitive.bind(cond_p, *args, true_jaxpr=true_jaxpr,
                             false_jaxpr=false_jaxpr, linear=linear)

cond_p = lax.Primitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.primitive_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule


### scan

def scan(f, init, xs, length=None, reverse=False):
  """Scan a function over leading array axes while carrying along state.

  The type signature in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When ``a`` is an array type or None, and ``b`` is an array type, the semantics
  of ``scan`` are given roughly by this Python implementation::

    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays. (None is actually an empty pytree.)

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
  across all iterations (and not just be consistent up to NumPy rank/shape
  broadcasting and dtype promotion rules, for example). In other words, the type
  ``c`` in the type signature above represents an array with a fixed shape and
  dtype (or a nested tuple/list/dict container data structure with a fixed
  structure and arrays with fixed shape and dtype at the leaves).
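
  For example (an illustrative cumulative sum over a 1D array; any ``f`` of
  type ``c -> a -> (c, b)`` works the same way)::

    def cumsum(xs):
      def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry
      total, partial_sums = scan(step, 0., xs)
      return partial_sums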

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.
  """
  init_flat, init_tree = tree_flatten(init)
  xs_flat, xs_tree = tree_flatten(xs)
  in_flat, in_tree = tree_flatten((init, xs))

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
        msg.format(', '.join(str(x) for x in xs_flat
                             if not hasattr(x, 'shape')))) from err

  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  if jax.api._jit_is_disabled():
    carry = init
    ys = []
    maybe_reversed = reversed if reverse else lambda x: x
    for i in maybe_reversed(range(length)):
      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
                            else jax.numpy.stack((y, *ys)))
    ys = tree_multimap(stack, *maybe_reversed(ys))
    return carry, ys

  carry_avals = tuple(_map(_abstractify, init_flat))
  x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
  x_dtypes = [x.dtype for x in xs_flat]
  x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
  out_tree_children = out_tree.children()
  if len(out_tree_children) != 2:
    msg = "scan body output must be a pair, got {}."
    raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
  _check_tree_and_avals("scan carry output and input",
                        # Extract the subtree and avals for the first element of the return tuple
                        out_tree_children[0], jaxpr.out_avals[:out_tree_children[0].num_leaves],
                        init_tree, carry_avals)

  out = scan_p.bind(*itertools.chain(consts, in_flat),
                    reverse=reverse, length=length, jaxpr=jaxpr,
                    num_consts=len(consts), num_carry=len(init_flat),
                    linear=(False,) * (len(consts) + len(in_flat)))
  return tree_unflatten(out_tree, out)
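
# Reference implementation of scan: it is expressed as a while_loop that
# carries an index counter, the loop carry, and preallocated output arrays
# which are updated one slice per iteration.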
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
  consts, init, xs = split_list(args, [num_consts, num_carry])
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])

  def cond_fun(vals):
    i, *_ = vals
    return i < length

  def body_fun(vals):
    [i], carry, ys = split_list(vals, [1, num_carry])
    i_ = length - i - 1 if reverse else i
    x = _map(partial(_index_array, i_), x_avals, xs)
    out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
    carry_out, y_updates = split_list(out_flat, [num_carry])
    ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
    return [i + 1] + carry_out + ys_out

  ys_init = _map(partial(_empty_array, length), y_avals)
  if length == 0:
    return init + ys_init
  else:
    init_val = [lax._const(length, 0)] + init + ys_init
    _, *outs = while_loop(cond_fun, body_fun, init_val)
    return outs

def _index_array(i, aval, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.dynamic_index_in_dim(x, i, keepdims=False)

def _empty_array(sz, aval):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.full((sz,) + aval.shape, 0, aval.dtype)

def _update_array(i, aval, xs, x):
  if aval is core.abstract_unit:
    return core.unit
  else:
    return lax.dynamic_update_index_in_dim(xs, x, i, 0)

def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
              if aval is not core.abstract_unit else aval for aval in y_avals]
  return carry_avals + ys_avals

def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
              linear):
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry
  nonzeros = [t is not ad_util.zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

  # Fixpoint computation of which carry are not ad.zero: either
  # non-zero from init, or the carry out is non-zero. Each iteration promotes
  # at least one carry to non-zero. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_nz.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    nonzeros = const_nz + carry_nz + xs_nz
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
    carry_nz_out, ys_nz = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
    if carry_nz_out == carry_nz:
      break
    else:
      carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
  else:
    assert False, "Fixpoint not reached"

  tangents = [ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
              for x, t, nz in zip(primals, tangents, nonzeros)]

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  all_tangents = split_list(tangents, [num_consts, num_carry])
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
      [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])

  consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
                           + init_linear + [True] * len(init_dot)
                           + xs_linear + [True] * len(xs_dot))

  out_flat = scan_p.bind(
      *(consts + consts_dot + init + init_dot + xs + xs_dot),
      reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts+len(consts_dot), num_carry=num_carry+len(init_dot),
      linear=jaxpr_jvp_linear)

  carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
  primals_out = carry + ys
  tangents_out_iter = iter(carry_dot + ys_dot)
  tangents_out = [next(tangents_out_iter) if nz else ad_util.zero
                  for nz in nonzeros_out]
  return primals_out, tangents_out

def _prune_zeros(ts):
  return [t for t in ts if t is not ad_util.zero]

def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                       jaxpr, linear):
  if trace.master.trace_type is pe.StagingJaxprTrace:
    params = {"reverse": reverse, "length": length, "num_consts": num_consts,
              "num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
    return trace.default_process_primitive(scan_p, tracers, params)

  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry

  unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Fixpoint computation of which carry are unknown (not a constant): either
  # unknown from init, or the carry out is unknown. Each iteration promotes
  # at least one carry to unknown. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_uk.
  carry_uk = init_uk
  for _ in range(1 + len(carry_uk)):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys,
        trace_type=trace.master.trace_type)
    carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

  # The residuals are treated as extensive outputs of jaxpr_1 (and extensive
  # inputs to jaxpr_2), but residuals that are loop-invariant can be hoisted.
  # TODO(mattjj): hoist other loop-invariant values here too (instantiate=False)
  invariant_pvals = [pe.PartialVal.known(core.unit if uk else t.pval[1])
                     for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])]
  other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
  in_pvals_1 = invariant_pvals + other_pvals
  untyped_jaxpr_1, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
      lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
      instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
  const_avals_1 = [raise_to_shaped(core.get_aval(c)) for c in consts_1]
  in_avals_1 = [core.abstract_unit] * num_consts + jaxpr_1.in_avals[num_consts:]
  out_avals_1 = [core.abstract_unit if pv is None else pv for pv, c in out_pvals_1]

  # TODO(cjfj): Explain the need for the code below.
  for var in untyped_jaxpr_1.invars[:num_consts]:
    var.aval = core.abstract_unit

  jaxpr_1_opt = pe.TypedJaxpr(pe.convert_constvars_jaxpr(untyped_jaxpr_1),
                              (), const_avals_1 + in_avals_1, out_avals_1)
  num_consts_1 = num_consts + len(consts_1)
  # any now-known residuals are intensive, so we want to revise jaxpr_2 to take
  # those inputs as constants rather than as extensive inputs
  _, _, res_pvals = split_list(out_pvals_1, [num_carry, num_ys])
  intensive_residuals = [const for pv, const in res_pvals if pv is None]
  move = [False] * len(jaxpr_1.in_avals) + [pv is None for pv, _ in res_pvals]
  jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
  num_consts_2 = num_consts + len(intensive_residuals)

  in_consts = (list(consts_1) + [core.unit] * num_consts +
               [core.unit if uk else t.pval[1]
                for uk, t in zip(unknowns[num_consts:], tracers[num_consts:])])
  linear_1 = ([False] * len(consts_1) + [True] * num_consts +
              [lin or uk for uk, lin
               in zip(unknowns[num_consts:], linear[num_consts:])])
  out_flat = scan_p.bind(
      *in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
      num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1))
  out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
  extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]

  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  out_consts = out_carry + ys
  int_res_tracers = _map(trace.new_instantiated_const, intensive_residuals)
  ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  linear_2 = ([False] * len(int_res_tracers) +
              [lin or not uk for uk, lin in zip(unknowns, linear)] +
              [False] * len(ext_res_tracers))
  eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers,
                          out_tracers, scan_p,
                          dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
                               num_consts=num_consts_2,
                               num_carry=num_carry, linear=tuple(linear_2)))
  for t in out_tracers: t.recipe = eqn
  return out_tracers

def _promote_aval_rank(sz, aval):
  if aval is core.abstract_unit:
    return core.abstract_unit
  else:
    return ShapedArray((sz,) + aval.shape, aval.dtype)

def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear):
  # we've only implemented transposing scans with specific lin/nonlin patterns
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_ires = len(consts_lin) - sum(consts_lin)
  num_eres = len(xs_lin) - sum(xs_lin)
  if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
    raise NotImplementedError
  if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
    raise NotImplementedError
  if not all(init_lin):
    pass  # TODO(mattjj): error check https://github.com/google/jax/issues/1963

  consts, _, xs = split_list(args, [num_consts, num_carry])
  ires, _ = split_list(consts, [num_ires])
  _, eres = split_list(xs, [sum(xs_lin)])
  assert not any(ad.is_undefined_primal(r) for r in ires)
  assert not any(ad.is_undefined_primal(r) for r in eres)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])

  #       jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_scan_jaxpr(
      num_ires, num_consts - num_ires, num_eres, jaxpr)
  linear_trans = ([False] * num_ires +
                  [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
                  [False] * num_eres)

  outs = scan_p.bind(
      *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
      length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans))
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
  return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres

# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
#                         -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
  num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
  res1_avals, c_avals, a_avals, res2_avals = split_list(
      jaxpr.in_avals, [num_res1, num_c, num_a])
  num_b = len(jaxpr.out_avals)
  b_avals = list(jaxpr.out_avals)

  @lu.wrap_init
  def transposed(*res1_cbar_bbar_res2):
    res1, c_bar, b_bar, res2 = split_list(
        res1_cbar_bbar_res2, [num_res1, num_c, num_b])
    primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
               [ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.literals, primals,
                                 b_bar)
    _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
    a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
    c_bar = _map(ad.instantiate_zeros_aval, c_avals,
                 _map(ad.add_tangents, c_bar, new_c_bar))
    return c_bar + a_bar
  return _make_typed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)

def _make_typed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
  jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
  out_avals, _ = unzip2(pvals_out)
  return core.TypedJaxpr(jaxpr, consts, in_avals, _map(raise_to_shaped, out_avals))
|
|
|
|
|
|
def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
                        num_carry, linear):
  num_ys = len(jaxpr.out_avals) - num_carry
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry])

  # Fixpoint computation of which carry elements are batched: a carry is
  # batched either because its init is batched or because its carry out is
  # batched. Each iteration promotes at least one carry to batched, so we need
  # at most len(carry) iterations, plus one last iteration to prepare the
  # jaxpr based on the final carry_batched. (An illustrative vmap-over-scan
  # sketch follows this function.)
  carry_batched = init_batched
  for _ in range(1 + len(carry_batched)):
    batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, size, batched, instantiate=carry_batched + [False] * num_ys)
    carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
    if carry_batched_out == carry_batched:
      break
    else:
      carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, consts_bdims)]
  new_init = [batching.broadcast(x, size, 0) if now_batched and not was_batched
              else batching.moveaxis(x, d, 0) if now_batched else x
              for x, d, was_batched, now_batched in
              zip(init, init_bdims, init_batched, carry_batched)]
  new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1
            else x for x, d in zip(xs, xs_bdims)]
  new_args = new_consts + new_init + new_xs

  outs = scan_p.bind(*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
                     num_consts=num_consts, num_carry=num_carry, linear=linear)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims

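# Hedged usage sketch (illustrative only; the helper names are assumptions, not
# part of this module): vmap over scan exercises the batching rule above, which
# promotes carries to batched via the fixpoint computation.
def _scan_vmap_example():
  import jax.numpy as np
  def cumsum_with_history(init, xs):
    def step(carry, x):
      carry = carry + x
      return carry, carry     # carry out and per-step output
    return scan(step, init, xs)
  inits = np.zeros(4)         # batch of scalar initial carries
  xss = np.ones((4, 3))       # batch of length-3 sequences
  # Returns final carries of shape (4,) and stacked outputs of shape (4, 3).
  return jax.vmap(cumsum_with_history)(inits, xss)

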
def _scan_shape_rule(shapes, reverse, length, jaxpr,
                     num_consts, num_carry, linear):
  const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals]
  return init_shexprs + ys_shapes

def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, reverse, length,
                       jaxpr, num_consts, num_carry, linear):
  out_shape = _scan_shape_rule(shape_exprs, reverse, length, jaxpr,
                               num_consts, num_carry, linear)
  dynamic_length = length.evaluate(shape_envs.logical)
  masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
  consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
  max_length, = {x.shape[0] for x in xs}
  const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  out_vals = scan_p.bind(
      *itertools.chain([dynamic_length] + consts, [0], init, xs),
      reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
      num_consts=1 + num_consts, num_carry=1 + num_carry,
      linear=tuple([False] + const_linear + [False] + init_linear + xs_linear))
  return out_vals[1:], out_shape

def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  fun = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    out = fun(*(consts + carry + xs))
    new_carry, ys = split_list(out, [num_carry])
    new_carry = [lax.select(i < dynamic_length, new_c, c)
                 for new_c, c in zip(new_carry, carry)]
    return [i + 1] + new_carry + ys

  aval = ShapedArray((), dtypes.int_)
  const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)

def scan_bind(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
  if not core.skip_checks:
    assert len(linear) == len(args)
    consts, init, xs = split_list(args, [num_consts, num_carry])
    consts_avals, init_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
    xs_avals = _map(partial(_promote_aval_rank, length), x_avals)
    assert all(_map(typecheck, consts_avals, consts)), (consts, consts_avals)
    assert all(_map(typecheck, init_avals, init))
    # assert all(_map(typecheck, xs_avals, xs))
    carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
    assert all(_map(typematch, init_avals, carry_avals))
    core.check_jaxpr(jaxpr.jaxpr)
  return core.Primitive.bind(scan_p, *args, reverse=reverse, length=length,
                             jaxpr=jaxpr, num_consts=num_consts,
                             num_carry=num_carry, linear=linear)

scan_p = core.Primitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(_scan_impl)
# scan_p.def_impl(partial(xla.apply_primitive, scan_p))  # TODO(mattjj): re-enable
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = \
    xla.lower_fun_initial_style(_scan_impl)
batching.primitive_batchers[scan_p] = _scan_batching_rule
masking.shape_parameterized_primitive_rules[scan_p] = _scan_masking_rule


def map(f, xs):
  """Map a function over leading array axes.

  Like Python's builtin map, except inputs and outputs are in the form of
  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
  need to apply a function element by element for reduced memory usage or
  heterogeneous computation with other control flow primitives.

  When ``xs`` is an array type, the semantics of ``map`` are given by this
  Python implementation::

    def map(f, xs):
      return np.stack([f(x) for x in xs])

  Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
  the same advantages over a Python loop apply: ``xs`` may be an arbitrary
  nested pytree type, and the mapped computation is compiled only once.

  Args:
    f: a Python function to apply element-wise over the first axis or axes of
      ``xs``.
    xs: values over which to map along the leading axis.

  Returns:
    Mapped values.
  """
  g = lambda _, x: ((), f(x))
  _, ys = scan(g, (), xs)
  return ys

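# Hedged usage sketch (illustrative only, not part of the library API): the
# helper below applies ``map`` over the leading axis of a stacked array. The
# name ``_map_usage_example`` is an assumption for the example.
def _map_usage_example():
  import jax.numpy as np
  xs = np.arange(6.).reshape(3, 2)
  # Sums each row of ``xs``; equivalent to np.stack([np.sum(x) for x in xs]).
  return map(np.sum, xs)

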
def _concat_masking_rule(padded_vals, logical_shapes, dimension):
  result = lax.concatenate(padded_vals, dimension)  # fragmented
  offset = 0
  for padded_val, logical_shape in zip(padded_vals, logical_shapes):
    result = _memcpy(dimension, logical_shape[dimension], padded_val,
                     result, offset)
    offset = offset + logical_shape[dimension]
  return result

def _memcpy(axis, num, src, dst, offset):
  def body(i, dst):
    update = lax.dynamic_index_in_dim(src, i, axis)
    return lax.dynamic_update_index_in_dim(dst, update, i + offset, axis)
  return fori_loop(0, num, body, dst)

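# For example, _memcpy(0, 2, src, dst, 1) copies the first two slices of
# ``src`` along axis 0 into ``dst`` positions 1 and 2, one slice per
# fori_loop iteration.
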
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule


def _check_tree(func_name, expected_name, actual_tree, expected_tree):
  if actual_tree != expected_tree:
    raise TypeError(
        "{}() output pytree structure must match {}, got {} and {}."
        .format(func_name, expected_name, actual_tree, expected_tree))


def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
  """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).

  Corresponding `tree` and `avals` must match in the sense that the number of
  leaves in `tree` must be equal to the length of `avals`. `what` will be
  prepended to details of the mismatch in TypeError.
  """
  if tree1 != tree2:
    msg = ("{} must have same type structure, got {} and {}.")
    raise TypeError(msg.format(what, tree1, tree2))
  if not all(safe_map(typematch, avals1, avals2)):
    msg = ("{} must have identical types, "
           "got\n{}\nand\n{}.")
    raise TypeError(msg.format(what, tree_unflatten(tree1, avals1),
                               tree_unflatten(tree2, avals2)))


def _stop_gradient_fun(f):
  """Create a version of f() that stops all gradients."""
  def wrapper(*args, **kwargs):
    args_flat, in_args_tree = tree_flatten((args, kwargs))
    args_avals = tuple(_map(_abstractify, args_flat))
    g = lambda a, b: f(*a, **b)
    jaxpr, consts, out_tree = _initial_style_jaxpr(g, in_args_tree, args_avals)
    out = core.jaxpr_as_fun(jaxpr)(*lax.stop_gradient(consts + tuple(args_flat)))
    return tree_unflatten(out_tree, out)
  return wrapper

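# For example (illustrative, assuming the usual jax.grad API):
#   jax.grad(_stop_gradient_fun(lambda x: x * x))(3.)  # evaluates to 0.0
# because both the flattened arguments and the closed-over constants are
# wrapped in lax.stop_gradient before the jaxpr is evaluated.

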
_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')


def _split_root_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _RootTuple(*params_list[:-1]), params_list[-1]


def custom_root(f, initial_guess, solve, tangent_solve):
  """Differentiably solve for a root of a function.

  This is a low-level routine, mostly intended for internal use in JAX.
  Gradients of custom_root() are defined with respect to closed-over variables
  from the provided function ``f`` via the implicit function theorem:
  https://en.wikipedia.org/wiki/Implicit_function_theorem

  Args:
    f: function for which to find a root. Should accept a single argument and
      return a tree of arrays with the same structure as its input.
    initial_guess: initial guess for a zero of f.
    solve: function to solve for the roots of f. Should take two positional
      arguments, f and initial_guess, and return a solution with the same
      structure as initial_guess such that f(solution) = 0. In other words,
      the following is assumed to be true (but not checked)::

        solution = solve(f, initial_guess)
        error = f(solution)
        assert all(error == 0)

    tangent_solve: function to solve the tangent system. Should take two
      positional arguments, a linear function ``g`` (the function ``f``
      linearized at its root) and a tree of array(s) ``y`` with the same
      structure as initial_guess, and return a solution ``x`` such that
      ``g(x)=y``:

      - For scalar ``y``, use ``lambda g, y: y / g(1.0)``.
      - For vector ``y``, you could use a linear solve with the Jacobian, if
        dimensionality of ``y`` is not too large:
        ``lambda g, y: np.linalg.solve(jacobian(g)(y), y)``.

  Returns:
    The result of calling solve(f, initial_guess) with gradients defined via
    implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(_abstractify, guess_flat))
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals)

  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree)

  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, _stop_gradient_fun(f)), in_args_tree, guess_avals)
  _check_tree("solve", "initial_guess", solution_tree, in_tree)

  def linearize_and_solve(x, b):
    unchecked_zeros, f_jvp = jax.linearize(f, x)
    return tangent_solve(f_jvp, b)

  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
  _check_tree("tangent_solve", "x", out_tree, in_tree)

  all_consts = [f_consts, solve_consts, l_and_s_consts]
  const_lengths = _RootTuple(*_map(len, all_consts))
  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)

  out_flat = _custom_root(
      const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
  return tree_unflatten(out_tree, out_flat)

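# Hedged usage sketch (illustrative only; the helper name, the Newton solver,
# and the constants are assumptions, not part of this module): a differentiable
# square root via custom_root with a scalar tangent solve, following the
# docstring's own recipe.
def _custom_root_usage_example(a=2.):
  f = lambda x: x ** 2 - a                    # the root of f is sqrt(a)
  def newton_solve(f, x0):
    body = lambda _, x: x - f(x) / (2. * x)   # Newton step for x**2 - a
    return fori_loop(0, 10, body, x0)
  scalar_tangent_solve = lambda g, y: y / g(1.0)
  # Gradients with respect to the closed-over ``a`` flow through the implicit
  # function theorem rather than through the Newton iterations.
  return custom_root(f, 1.0, newton_solve, scalar_tangent_solve)

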
@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def _custom_root(const_lengths, jaxprs, *args):
  params, initial_guess = _split_root_args(args, const_lengths)
  solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
  return solution


@_custom_root.defjvp
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  params, _ = _split_root_args(primals, const_lengths)
  solution = _custom_root(const_lengths, jaxprs, *primals)

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  f = core.jaxpr_as_fun(jaxprs.f)
  linearize_and_solve = partial(
      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
  f_at_solution = lambda *params: f(*itertools.chain(params, solution))
  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  solution_dot = _map(
      operator.neg, linearize_and_solve(*itertools.chain(solution, rhs)))

  return solution, solution_dot


class _LinearSolveTuple(collections.namedtuple(
    '_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):

  def transpose(self):
    return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)


def _split_linear_solve_args(args, const_lengths):
  params_list = split_list(args, list(const_lengths))
  return _LinearSolveTuple(*params_list[:-1]), params_list[-1]


def _transpose_function(linear_fun, primals):
  """Transpose a linear function."""
  # TODO(shoyer): can we use something more direct than the vjp machinery?
  # It's particularly awkward that we need the second argument to give
  # particular values of the primals, which are entirely arbitrary.
  _, vjp_fun = jax.vjp(linear_fun, primals)

  def transposed_fun(x):
    (y,) = vjp_fun(x)
    return y

  return transposed_fun

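# For a linear map such as ``lambda x: 2. * x``, the returned function applies
# its transpose; e.g. _transpose_function(lambda x: 2. * x, 1.0)(5.) evaluates
# to 10.0. Because the map is linear, the primal value only pins down shapes
# and dtypes for the vjp.

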
def _flatten(args):
  return [x for arg in args for x in arg]


def _check_shapes(func_name, expected_name, actual, expected, tree):
  actual_shapes = _map(onp.shape, actual)
  expected_shapes = _map(onp.shape, expected)
  if actual_shapes != expected_shapes:
    raise ValueError('{}() output shapes must match {}, got {} and {}'
                     .format(func_name, expected_name,
                             tree_unflatten(tree, actual_shapes),
                             tree_unflatten(tree, expected_shapes)))


def custom_linear_solve(
    matvec, b, solve, transpose_solve=None, symmetric=False):
  """Perform a matrix-free linear solve with implicitly defined gradients.

  This function allows for overriding or defining gradients for a linear
  solve directly via implicit differentiation at the solution, rather than by
  differentiating *through* the solve operation. This can sometimes be much
  faster or more numerically stable; it also applies when differentiating
  through the solve operation is not supported at all (e.g., if ``solve``
  uses ``lax.while_loop``).

  Required invariant::

      x = solve(matvec, b)  # solve the linear equation
      assert matvec(x) == b  # not checked

  Args:
    matvec: linear function to invert. Must be differentiable.
    b: constant right hand side of the equation. May be any nested structure
      of arrays.
    solve: higher level function that solves for solution to the linear
      equation, i.e., ``solve(matvec, matvec(x)) == x`` for all ``x`` of the
      same form as ``b``. This function need not be differentiable.
    transpose_solve: higher level function for solving the transpose linear
      equation, i.e., ``transpose_solve(vecmat, vecmat(x)) == x``, where
      ``vecmat`` is the transpose of the linear map ``matvec`` (computed
      automatically with autodiff). Required for backwards mode automatic
      differentiation, unless ``symmetric=True``, in which case ``solve``
      provides the default value.
    symmetric: bool indicating if it is safe to assume the linear map
      corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.

  Returns:
    Result of ``solve(matvec, b)``, with gradients defined assuming that the
    solution ``x`` satisfies the linear equation ``matvec(x) == b``.
  """
  if transpose_solve is None and symmetric:
    transpose_solve = solve

  b_flat, in_args_tree = tree_flatten((b,))
  b_avals = tuple(_map(_abstractify, b_flat))
  matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
      matvec, in_args_tree, b_avals)

  tree, = treedef_children(in_args_tree)
  _check_tree("matvec", "b", out_tree, tree)

  solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
      partial(solve, matvec), in_args_tree, b_avals)
  _check_tree("solve", "b", out_tree, tree)

  if transpose_solve is None:
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    if symmetric:
      vecmat = matvec
      vecmat_jaxpr = matvec_jaxpr
      vecmat_consts = matvec_consts
    else:
      vecmat = _transpose_function(matvec, b)
      vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
          vecmat, in_args_tree, b_avals)
      assert out_tree == tree

    tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
        partial(transpose_solve, vecmat), in_args_tree, b_avals)
    _check_tree("transpose_solve", "b", out_tree, tree)

  all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
  const_lengths = _LinearSolveTuple(*_map(len, all_consts))
  jaxprs = _LinearSolveTuple(
      matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

  out_flat = linear_solve_p.bind(
      *(_flatten(all_consts) + b_flat),
      const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
  return tree_unflatten(tree, out_flat)

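# Hedged usage sketch (illustrative only; the helper name and the scalar
# "system" are assumptions, not part of this module): solving a * x = b by
# division, reusing the same solver for the transpose because multiplication
# by a scalar is symmetric.
def _custom_linear_solve_usage_example(a, b):
  matvec = lambda x: a * x                    # linear in x, differentiable
  explicit_solve = lambda matvec, b: b / a    # satisfies matvec(solve(...)) == b
  return custom_linear_solve(matvec, b, explicit_solve, symmetric=True)

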
def _linear_solve_abstract_eval(*args, **kwargs):
  return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])


def _custom_linear_solve_impl(*args, **kwargs):
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])
  params, b = _split_linear_solve_args(args, const_lengths)
  x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
  _check_shapes('solve', 'b', x, b, tree)
  return x


def _tangent_linear_map(func, params, params_dot, *x):
  """Compute the tangent of a linear map.

  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
  this function computes ``∂A @ x``.
  """
  assert any(p is not ad_util.zero for p in params_dot)
  zeros = [ad_util.zero] * len(x)
  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
      params + list(x), params_dot + zeros)
  return out_tangent

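# For instance, with ``func(a, x) = a * x`` (so ``A`` is multiplication by
# ``a``), passing params=(a,) and params_dot=(da,) yields ``da * x``: the
# tangent of the map evaluated at the same ``x``.

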
def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
  # A x - b = 0
  # ∂A x + A ∂x - ∂b = 0
  # ∂x = A^{-1} (∂b - ∂A x)

  kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
  x = linear_solve_p.bind(*primals, **kwargs)

  params, _ = _split_linear_solve_args(primals, const_lengths)
  params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

  if all(p is ad_util.zero for p in params_dot.matvec):
    # no need to evaluate matvec_tangents
    rhs = b_dot
  else:
    matvec_tangents = _tangent_linear_map(
        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x)
    _check_shapes("matvec", "b", matvec_tangents, x, tree)
    rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))

  x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)

  return x, x_dot


def _linear_solve_transpose_rule(cotangent, *primals, **kwargs):
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])

  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  assert all(ad.is_undefined_primal(x) for x in b)
  cotangent_b = linear_solve_p.bind(
      *(_flatten(params.transpose()) + cotangent),
      const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose(),
      tree=tree)
  return [None] * sum(const_lengths) + cotangent_b


def _linear_solve_batching_rule(args, dims, **kwargs):
  const_lengths, jaxprs, tree = split_dict(kwargs,
                                           ["const_lengths", "jaxprs", "tree"])
  orig_bat = [d is not batching.not_mapped for d in dims]
  size, = {
      a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
  }

  params, b = _split_linear_solve_args(args, const_lengths)
  params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
  params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

  (matvec, vecmat, solve, solve_t) = jaxprs
  (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

  # Fixpoint computation of which parts of x and b are batched; we need to
  # ensure this is consistent between all four jaxprs.
  b_bat = orig_b_bat
  x_bat = [False] * len(solve.out_avals)
  for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
    # Apply vecmat and solve -> new batched parts of x
    solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
        solve, size, solve_bat + b_bat, instantiate=x_bat)
    if vecmat is None:
      vecmat_jaxpr_batched = None
      x_bat_out = solve_x_bat
    else:
      vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
          vecmat, size, vecmat_bat + b_bat, instantiate=x_bat)
      x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
    # Apply matvec and solve_t -> new batched parts of b
    matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
        matvec, size, matvec_bat + x_bat_out, instantiate=b_bat)
    if solve_t is None:
      solve_t_jaxpr_batched = None
      b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
    else:
      solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
          solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat)
      b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
                       orig_b_bat)
    if x_bat_out == x_bat and b_bat_out == b_bat:
      break
    else:
      x_bat = x_bat_out
      b_bat = b_bat_out
  else:
    assert False, "Fixpoint not reached"

  batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched, vecmat_jaxpr_batched,
                                     solve_jaxpr_batched, solve_t_jaxpr_batched)

  # Move batched axes to the front
  new_params = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(_flatten(params), _flatten(params_dims))
  ]
  # Broadcast out b if necessary
  new_b = [
      batching.broadcast(x, size, 0) if now_bat and not was_bat else
      batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(
      *(new_params + new_b),
      const_lengths=const_lengths,
      jaxprs=batched_jaxprs,
      tree=tree)
  out_dims = [0 if batched else batching.not_mapped for batched in b_bat]
  return outs, out_dims


linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.initial_style_translations[linear_solve_p] = \
    xla.lower_fun_initial_style(_custom_linear_solve_impl)
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.primitive_batchers[linear_solve_p] = _linear_solve_batching_rule