1998 lines
86 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the loop primitives."""
from functools import partial
import itertools
import operator
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
import jax
import weakref
from jax import core
from jax import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax._src.pretty_printer as pp
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map)
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (
cache,
extend_name_stack,
partition_list,
safe_map,
safe_zip,
split_list,
unzip2,
weakref_lru_cache,
)
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_avals_short,
_check_tree_and_avals,
_initial_style_jaxpr,
_make_closed_jaxpr,
_prune_zeros,
_typecheck_param,
allowed_effects,
)
# Module-local shorthands for the helpers imported from jax._src.util.
# NOTE(review): `zip` shadows the builtin for the rest of this module; the
# `safe_*` variants presumably also check that their arguments have equal
# lengths -- confirm against jax._src.util.
_map = safe_map
zip = safe_zip

T = TypeVar('T')
# Loose aliases that are only meaningful as annotations.
Array = Any
BooleanNumeric = Any # A bool, or a Boolean array.
### Helper functions
def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
"""Promote weakly-typed in_vals to be compatible with out_avals.
Args:
in_vals : flattened list of input values.
in_avals : corresponding list of avals.
out_avals : list of target output avals.
Returns:
in_vals_new : flattened list of modified in_vals with no weak types.
changed : bool; true if in_vals required modification.
"""
if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
# Calling function is responsible for catching this.
return in_vals, False
weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
if not weak_mismatches:
return in_vals, False
for i in weak_mismatches:
new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
return in_vals, True
### scan
# Type variables used in `scan`'s signature: the loop carry, the value that is
# scanned over, and the per-iteration outputs that get stacked.
Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')
@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  """Scan a function over leading array axes while carrying along state.

  The `Haskell-like type signature`_ in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When ``a`` is an array type or None, and ``b`` is an array type, the semantics
  of ``scan`` are given roughly by this Python implementation::

    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays. (None is actually an empty pytree.)

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
  across all iterations (and not just be consistent up to NumPy rank/shape
  broadcasting and dtype promotion rules, for example). In other words, the type
  ``c`` in the type signature above represents an array with a fixed shape and
  dtype (or a nested tuple/list/dict container data structure with a fixed
  structure and arrays with fixed shape and dtype at the leaves).

  .. note::
    :py:func:`scan` compiles ``f``, so while it can be combined with
    :py:func:`jit`, it's usually unnecessary.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.
    unroll: optional positive int specifying, in the underlying operation of the
      scan primitive, how many scan iterations to unroll within a single
      iteration of a loop.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.

  Raises:
    TypeError: if ``f`` is not callable, or if the traced body does not return
      a pair.
    ValueError: if a leaf of ``xs`` has no ``shape`` attribute, if ``length``
      disagrees with the leading axis sizes, if the leaves of ``xs`` have
      different leading axis sizes, if there is nothing to scan over and
      ``length`` is not given, or if the scan has zero length while
      ``jax_disable_jit`` is set.
    NotImplementedError: if the traced body has effects outside
      ``allowed_effects``.

  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
  """
  if not callable(f):
    raise TypeError("lax.scan: f argument should be a callable.")
  xs_flat, xs_tree = tree_flatten(xs)

  # Every leaf of xs must expose a leading axis to scan over.
  # NOTE(review): only AttributeError (no `.shape`) is translated here; a leaf
  # whose shape is empty would presumably surface as an IndexError instead --
  # confirm whether that case needs a friendlier message.
  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
      msg.format(', '.join(str(x) for x in xs_flat
                           if not hasattr(x, 'shape')))) from err

  # Settle on a single trip count: either the explicit `length` (which must
  # match every leaf) or the unique leading axis size of the leaves.
  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  # With jit disabled the loop is run directly in Python: `f` is called once
  # per index and the per-step outputs are stacked afterwards. `unroll` is not
  # consulted on this path.
  if config.jax_disable_jit:
    if length == 0:
      raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
    carry = init
    ys = []
    maybe_reversed = reversed if reverse else lambda x: x
    for i in maybe_reversed(range(length)):
      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    stack = lambda *ys: jax.numpy.stack(ys)
    # Undo the iteration order so that ys line up with the index order of xs.
    stacked_y = tree_map(stack, *maybe_reversed(ys))
    return carry, stacked_y

  xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
  # Avals handed to the body for the scanned-over inputs (presumably xs_avals
  # with the leading axis of size `length` removed -- see core.mapped_aval).
  x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]

  def _create_jaxpr(init):
    # Traces `f` against the avals of `init` and the per-iteration avals of xs.
    init_flat, init_tree = tree_flatten(init)
    in_flat, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(_map(_abstractify, init_flat))
    jaxpr, consts, out_tree = _initial_style_jaxpr(
        f, in_tree, (*carry_avals, *x_avals), "scan")
    out_tree_children = out_tree.children()
    if len(out_tree_children) != 2:
      msg = "scan body output must be a pair, got {}."
      raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
    # The leading outputs of the body are the new carry values.
    carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves]
    return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children

  # The carry input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
  if changed:
    new_init = tree_unflatten(init_tree, new_init_flat)
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
  in_flat, jaxpr, consts, out_tree, out_tree_children = rest

  _check_tree_and_avals("scan carry output and input",
                        # Extract the subtree and avals for the first element of the return tuple
                        out_tree_children[0], carry_avals_out,
                        init_tree, carry_avals)
  disallowed_effects = jaxpr.effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `scan`: {disallowed_effects}')

  # Operands are laid out as [consts..., carry..., xs...]; num_consts and
  # num_carry tell the primitive where the boundaries are.
  out = scan_p.bind(*consts, *in_flat,
                    reverse=reverse, length=length, jaxpr=jaxpr,
                    num_consts=len(consts), num_carry=len(init_flat),
                    linear=(False,) * (len(consts) + len(in_flat)),
                    unroll=unroll)
  return tree_unflatten(out_tree, out)
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  """Run ``length`` scan steps as a plain Python loop.

  ``args`` is laid out as ``[consts..., carry..., xs...]``. Each step slices
  index ``i`` out of every ``xs`` entry, calls ``f_impl`` and threads the
  carry through. Returns the final carry followed by the stacked per-step
  outputs, in index order regardless of ``reverse``. ``linear`` is accepted
  but not used here.
  """
  consts, carry, xs = split_list(args, [num_consts, num_carry])
  # Indices 0..length-1, or length-1..0 when scanning in reverse.
  step_indices = range(length - 1, -1, -1) if reverse else range(length)
  collected = []
  for idx in step_indices:
    x_slices = _map(partial(_index_array, idx), x_avals, xs)
    step_out = f_impl(*consts, *carry, *x_slices)
    carry, y_step = split_list(step_out, [num_carry])
    collected.append(y_step)
  if reverse:
    # Outputs were produced last-index-first; put them back in index order.
    collected = list(reversed(collected))
  # Transpose from per-step lists to per-output lists, then stack each one.
  per_output = list(zip(*collected))
  stacked = _map(_stack, y_avals, per_output)
  return (*carry, *stacked)
def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
                    f_impl, x_avals, y_avals):
  """Run the scan as a ``while_loop`` over a counter, the carry and the ys.

  ``args`` is laid out as ``[consts..., carry..., xs...]``. The loop state is
  ``[counter, *carry, *ys]``; every iteration reads one slice of each ``xs``
  entry and writes one slice of each ``ys`` buffer. Returns the final carry
  followed by the filled ``ys`` buffers. ``linear`` is accepted but not used
  here.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])

  def keep_going(state):
    counter = state[0]
    return counter < length

  def step(state):
    [counter], carry, ys = split_list(state, [1, num_carry])
    # Position along the scanned axis touched by this iteration.
    pos = length - counter - 1 if reverse else counter
    x_slices = _map(partial(_dynamic_index_array, pos), x_avals, xs)
    step_out = f_impl(*consts, *carry, *x_slices)
    carry_next, y_updates = split_list(step_out, [num_carry])
    ys_next = _map(partial(_update_array, pos), y_avals, ys, y_updates)
    return [counter + 1] + carry_next + ys_next

  # Output buffers with a leading axis of size `length`, one per y.
  ys_init = _map(partial(_empty_array, length), y_avals)
  if length == 0:
    # Nothing to iterate: the carry passes through untouched.
    return init + ys_init
  loop_state = [lax._const(length, 0)] + init + ys_init
  _, *outs = while_loop(keep_going, step, loop_state)
  return outs
def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  """Run the scan as a loop over blocks of ``block_length`` unrolled steps.

  ``length`` must be an exact multiple of ``block_length``. The ``xs`` are
  reshaped to ``(num_blocks, block_length, ...)``, the outer loop is run by
  ``_scan_impl_loop`` with ``_scan_impl_unrolled`` as its body, and the
  resulting ``ys`` are collapsed back to a single leading axis.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, leftover = divmod(length, block_length)
  assert leftover == 0

  to_blocks = partial(_partition_leading, num_blocks, block_length)
  xs_block = _map(to_blocks, x_avals, xs)

  add_block_axis = partial(_prepend_dim_to_aval, block_length)
  x_block_avals = _map(add_block_axis, x_avals)
  y_block_avals = _map(add_block_axis, y_avals)

  # Body of the outer loop: one fully unrolled block of the original body.
  run_block = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  outs = _scan_impl_loop(
      *consts, *init, *xs_block, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=run_block, x_avals=x_block_avals, y_avals=y_block_avals)

  carry, ys_blocks = split_list(outs, [num_carry])
  from_blocks = partial(_combine_leading, num_blocks, block_length)
  ys = _map(from_blocks, y_avals, ys_blocks)
  return (*carry, *ys)
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
               unroll):
  """Evaluate a scan by dispatching on ``unroll``.

  With ``unroll == 1`` the whole scan is one ``_scan_impl_loop``. Otherwise
  the largest multiple of ``unroll`` steps is run block-unrolled, and any
  remaining ``length % unroll`` steps are run fully unrolled afterwards. In
  reverse mode the remainder is taken from the front of ``xs`` rather than
  the back. Returns the final carry followed by the stacked ys.
  """
  _consts_avals, _carry_avals, x_avals = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  _carry_out_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  f_impl = core.jaxpr_as_fun(jaxpr)
  # Keyword arguments shared by every helper called below.
  shared = dict(reverse=reverse, num_consts=num_consts, num_carry=num_carry,
                linear=linear, f_impl=f_impl, x_avals=x_avals,
                y_avals=y_avals)

  if unroll == 1:
    return _scan_impl_loop(*args, length=length, **shared)

  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, rem = divmod(length, unroll)
  length_div = num_blocks * unroll
  has_remainder = rem > 0

  if has_remainder:
    if reverse:
      cut_front = partial(_split_leading_dim, rem)
      xs_rem, xs = unzip2(_map(cut_front, x_avals, xs))
    else:
      cut_back = partial(_split_leading_dim, length_div)
      xs, xs_rem = unzip2(_map(cut_back, x_avals, xs))

  block_outs = _scan_impl_block_unrolled(
      *consts, *init, *xs, length=length_div, block_length=unroll, **shared)
  carry, ys = split_list(block_outs, [num_carry])

  if has_remainder:
    rem_outs = _scan_impl_unrolled(
        *consts, *carry, *xs_rem, length=rem, **shared)
    carry, ys_rem = split_list(rem_outs, [num_carry])
    # Stitch the remainder onto the side of ys it was split from.
    if reverse:
      ys = _map(_concatenate, y_avals, ys_rem, ys)
    else:
      ys = _map(_concatenate, y_avals, ys, ys_rem)

  return (*carry, *ys)
def _stack(aval, vals):
vals = [lax.expand_dims(x, (0,)) for x in vals]
return lax.concatenate(vals, 0)
def _concatenate(aval, x1, x2):
return lax.concatenate([x1, x2], 0)
def _split_leading_dim(i, aval, x):
assert x.ndim >= 1
return (slicing.slice_in_dim(x, 0, i),
slicing.slice_in_dim(x, i, x.shape[0]))
def _dynamic_index_array(i, aval, x):
return slicing.dynamic_index_in_dim(x, i, keepdims=False)
def _index_array(i, aval, x):
return slicing.index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
  """Build a placeholder array of shape ``(sz, *aval.shape)`` and ``aval.dtype``.

  The array is a broadcast of ``lax.empty(aval.dtype)``; the loop
  implementation uses it as the output buffer that is filled slice by slice.
  """
  filler = lax.empty(aval.dtype)
  full_shape = (sz, *aval.shape)
  return lax.broadcast(filler, full_shape)
def _update_array(i, aval, xs, x):
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
def _partition_leading(sz0, sz1, aval, x):
assert x.ndim >= 1
assert x.shape[0] == sz0 * sz1
return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
def _combine_leading(sz0, sz1, aval, x):
assert x.ndim >= 2
assert x.shape[0] == sz0
assert x.shape[1] == sz1
return lax.collapse(x, 0, 2)
def _prepend_dim_to_aval(sz, aval):
  """Return ``core.unmapped_aval`` of ``aval`` for size ``sz`` at axis 0.

  No axis name is attached. Presumably the result is ``aval`` with a new
  leading axis of size ``sz`` -- see ``core.unmapped_aval`` to confirm.
  """
  leading_axis = 0
  return core.unmapped_aval(sz, core.no_axis_name, leading_axis, aval)
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                        linear, unroll):
  """Abstract-evaluation rule for scan.

  The output avals are the body's carry avals unchanged, followed by the
  body's per-iteration output avals each given a leading ``length`` axis.
  Returns ``(out_avals, effects)`` with the body jaxpr's effects.
  """
  carry_avals, per_step_avals = split_list(jaxpr.out_avals, [num_carry])
  stacked_avals = [_prepend_dim_to_aval(length, a) for a in per_step_avals]
  out_avals = carry_avals + stacked_avals
  return out_avals, jaxpr.effects
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
              linear, unroll):
  """JVP rule for scan.

  Builds the JVP of the body jaxpr, rearranges its binders so that primal and
  tangent values are grouped as consts / carry / xs, and binds a single new
  scan that computes primal and tangent outputs together.

  Args:
    primals: flat primal operands, laid out as ``[consts..., carry..., xs...]``.
    tangents: matching tangents; entries may be ``ad_util.Zero``.
    reverse, length, jaxpr, num_consts, num_carry, linear, unroll: the params
      of the scan being differentiated.

  Returns:
    A pair ``(primals_out, tangents_out)`` of flat lists, where output
    tangents that are known to be zero are ``ad_util.Zero`` instances.
  """
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

  # Fixpoint computation of which carry are not ad.zero: either
  # non-zero from init, or the carry out is non-zero. Each iteration promotes
  # at least one carry to non-zero. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_nz.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    nonzeros = const_nz + carry_nz + xs_nz
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
    carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
    if carry_nz_out == carry_nz:
      break
    else:
      carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
  else:
    assert False, "Fixpoint not reached"

  # Carry tangents promoted to non-zero by the fixpoint must be materialized.
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  all_tangents = split_list(tangents, [num_consts, num_carry])
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

  # Regroup the JVP jaxpr's binders to match the operand order used in the
  # bind below: consts, consts_dot, carry, carry_dot, xs, xs_dot (and
  # similarly carry, carry_dot, ys, ys_dot for the outputs).
  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
      [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])

  # Every tangent operand is marked linear; primal operands keep their flag.
  consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
                           + init_linear + [True] * len(init_dot)
                           + xs_linear + [True] * len(xs_dot))

  out_flat = scan_p.bind(
      *(consts + consts_dot + init + init_dot + xs + xs_dot),
      reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts + len(consts_dot),
      num_carry=num_carry + len(init_dot),
      linear=jaxpr_jvp_linear, unroll=unroll)

  carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
  primals_out = carry + ys
  # Re-insert symbolic zeros for the output tangents that were not computed.
  tangents_out_iter = iter(carry_dot + ys_dot)
  tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(primals_out, nonzeros_out)]
  return primals_out, tangents_out
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                       jaxpr, linear, unroll):
  """Partial-evaluation rule for scan.

  Splits the scan into a "known" scan, which is bound immediately, and an
  "unknown" scan, which is staged out as an equation recipe on fresh output
  tracers. Residuals computed by the known part are passed to the unknown
  part either as loop-invariant ("intensive") constants or as per-iteration
  ("extensive") scanned inputs.

  Args:
    trace: the partial-eval trace the tracers belong to.
    *tracers: input tracers, laid out as ``[consts..., carry..., xs...]``.
    reverse, length, num_consts, num_carry, jaxpr, linear, unroll: the params
      of the scan being partially evaluated.

  Returns:
    A flat list of outputs, one per scan output, merging known output values
    with output tracers for the unknown ones.
  """
  num_ys = len(jaxpr.out_avals) - num_carry
  unknowns = [not t.pval.is_known() for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Fixpoint computation of which carry elements are unknown. Each iteration
  # promotes at least one carry to unknown. We need at most len(carry)
  # iterations, but we need one last iteration to prepare the jaxpr based on the
  # final carry_uk.
  carry_uk = init_uk
  for _ in range(1 + len(carry_uk)):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out, ys_uk = split_list(out_uk, [num_carry])
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  num_res = len(res_avals)
  del res_avals, carry_uk_out

  # Instantiate those inputs which must be treated as unknown from the fixpoint.
  tracers = [trace.instantiate_const(t) if uk else t
             for t, uk in zip(tracers, unknowns)]

  # The residual inputs and outputs of the jaxprs produced haven't yet been
  # adapted to the scan calling convention; in particular, jaxpr_known has its
  # residual outputs all at the end, meaning they're extensive outputs (which is
  # fully general but may be wasteful for residuals which are loop-invariant)
  # while jaxpr_unknown has its corresponding residual inputs at the front (just
  # as a convention with partial_eval_jaxpr_nounits), making them constant
  # inputs. To make them consistent, we move the residual inputs on
  # jaxpr_unknown to the end, even though we may move some back in the sequel.
  jaxpr_unknown = pe.move_binders_to_back(
      jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))

  # At this point, all residuals are treated as extensive outputs of jaxpr_known
  # (and extensive inputs to jaxpr_unknown). But residuals that are loop-
  # invariant can be hoisted out of the scan, rather than letting them get
  # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
  # want to broadcast the matrix!). So, outside the loop we perform a partial
  # evaluation with known 'const' inputs (but all other inputs unknown).
  const_pvals = [pe.PartialVal.known(t.pval.get_known())
                 for t in tracers[:num_consts] if t.pval.is_known()]
  other_pvals = [pe.PartialVal.unknown(aval)
                 for aval in jaxpr_known.in_avals[len(const_pvals):]]
  with source_info_util.reset_name_stack():
    jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
        lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals,
        instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
  jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
  # The above trace_to_jaxpr_nounits call computed loop-invariant residuals
  # (known values in invar_pvals_out) and also computed loop-invariant values
  # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
  # previous consts). We need to collect the computed intensive residuals, and
  # move corresponding intensive residual binders in jaxpr_unknown to the front.
  res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
  intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
  jaxpr_unknown = pe.move_binders_to_front(
      jaxpr_unknown,
      [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
  del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
  # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
  # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.

  # As another optimization, for any extensive inputs that are just forwarded to
  # extensive outputs, to avoid a copy (which would be looping over
  # dynamic-update-slice) we'd rather forward the input tracer/value. That means
  # pruning some outputs from jaxpr_known here, and updating `out_flat` below.
  fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
  # Prune fwds_known to include only extensive input to extensive output.
  fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
                in_idx is not None and
                in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
                else None for out_idx, in_idx in enumerate(fwds_known)]
  # Drop any extensive output we can instead get by forwarding an input.
  # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
  # Unpacking into `()` makes this line fail unless the consts are empty.
  jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
  # Note: this reassigns `outvars` on the jaxpr object itself.
  jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known)
                          if i is None]
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
  del jaxpr_known_
  # We use `fwds_known` below when forming the output of scanning jaxpr_known.

  # Run the known part of the scan (if it has any outputs or effects).
  known_inputs = (list(jaxpr_known_consts) +
                  [t.pval.get_known() for t in tracers[num_consts:]
                   if t.pval.is_known()])
  if not jaxpr_known.out_avals and not jaxpr_known.effects:
    out_known = []
  else:
    linear_known = [False] * len(known_inputs)  # conservative!
    out_known = scan_p.bind(
        *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
        num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk),
        linear=tuple(linear_known), unroll=unroll)
    del linear_known
  # Complete the known output by filling in forwarded values using fwds_known.
  out_known_iter = iter(out_known)
  out_known = [next(out_known_iter) if f is None
               else _maybe_put(known_inputs[f]) for f in fwds_known]
  # Every output produced by the known scan must have been consumed above.
  assert next(out_known_iter, None) is None
  del known_inputs, out_known_iter

  # Split known outputs from residuals.
  out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
  assert len(intensive_res) + len(extensive_res) == num_res

  # Create input tracers for jaxpr_unknown bind.
  unknown_inputs = [t for t in tracers if not t.pval.is_known()]
  intensive_res = _map(trace.new_instantiated_const, intensive_res)
  extensive_res = _map(trace.new_instantiated_const, extensive_res)
  # Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
  carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
  ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval)
              for y_aval in y_avals]
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in itertools.chain(carry_avals, ys_avals)]
  del carry_avals, y_avals
  # Create equation.
  # Operand order is [intensive_res, unknown_inputs, extensive_res]; residuals
  # are never marked linear.
  linear_unknown = tuple([False] * len(intensive_res) +
                         [l for l, uk in zip(linear, unknowns) if uk] +
                         [False] * len(extensive_res))
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  assert len(out_tracers) == len(jaxpr_unknown.out_avals)
  eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
                          out_tracers, scan_p,
                          dict(reverse=reverse, length=length, unroll=unroll,
                               jaxpr=jaxpr_unknown, linear=linear_unknown,
                               num_consts=len(intensive_res) + sum(const_uk),
                               num_carry=sum(carry_uk)),
                          jaxpr_unknown.effects, source)
  for t in out_tracers: t.recipe = eqn

  # Merge known and unknown outputs into final result.
  return util.merge_lists(out_uk, out_known, out_tracers)
def _maybe_put(x):
if isinstance(x, np.ndarray):
return jax.device_put(x, jax.devices('cpu')[0])
else:
return x
def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,
                    num_carry, jaxpr, linear, unroll):
  """Transpose rule for scan.

  Binds a scan that runs in the opposite direction over the transposed body,
  carrying cotangents for the linear consts and for the carry, consuming the
  cotangents of ys and producing the cotangents of the linear xs.

  Only scans whose ``linear`` flags follow a specific pattern are supported:
  consts must be ``[non-linear..., linear...]`` and xs must be
  ``[linear..., non-linear...]``.

  Args:
    reduce_axes: forwarded to ``ad.backward_pass`` via the transposed body.
    cts: output cotangents, laid out as ``[carry..., ys...]``.
    *args: the scan operands; non-linear residuals must be defined primals.
    reverse, length, num_consts, num_carry, jaxpr, linear, unroll: the params
      of the scan being transposed.

  Returns:
    A flat list of input cotangents aligned with ``args``, with ``None`` for
    the non-linear residual operands.

  Raises:
    NotImplementedError: if ``linear`` does not follow the supported pattern.
  """
  # we've only implemented transposing scans with specific lin/nonlin patterns
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  # Number of non-linear consts ("intensive residuals") and of non-linear xs
  # ("extensive residuals").
  num_ires = len(consts_lin) - sum(consts_lin)
  num_eres = len(xs_lin) - sum(xs_lin)
  if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
    raise NotImplementedError
  if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
    raise NotImplementedError
  if not all(init_lin):
    pass  # TODO(mattjj): error check https://github.com/google/jax/issues/1963

  consts, _, xs = split_list(args, [num_consts, num_carry])
  ires, _ = split_list(consts, [num_ires])
  _, eres = split_list(xs, [sum(xs_lin)])
  assert not any(ad.is_undefined_primal(r) for r in ires)
  assert not any(ad.is_undefined_primal(r) for r in eres)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  # Materialize symbolic-zero cotangents so they can be fed to the new scan.
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  # Cotangents of the linear consts start at zero and are carried through the
  # transposed scan.
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])

  #       jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_scan_jaxpr(
      num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes)
  linear_trans = ([False] * num_ires +
                  [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
                  [False] * num_eres)

  # The transposed scan runs in the opposite direction to the original.
  outs = scan_p.bind(
      *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
      length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
      unroll=unroll)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
  return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
#                         -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
  """Build the closed jaxpr of the transposed scan body.

  Args:
    num_res1: number of leading residual inputs of ``jaxpr``.
    num_c: number of inputs that follow ``res1`` and are transposed together
      with the ``a`` inputs (their cotangents are accumulated).
    num_res2: number of trailing residual inputs of ``jaxpr``.
    jaxpr: the closed jaxpr to transpose, with inputs ``[res1, c, a, res2]``.
    reduce_axes: forwarded to ``ad.backward_pass``.

  Returns:
    A closed jaxpr taking ``[res1, CT c, CT b, res2]`` and returning
    ``[CT c, CT a]``.
  """
  num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
  # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
  # if an axis isn't reduced
  res1_avals, c_avals, a_avals, res2_avals = split_list(
      jaxpr.in_avals, [num_res1, num_c, num_a])
  num_b = len(jaxpr.out_avals)
  b_avals = list(jaxpr.out_avals)

  @lu.wrap_init
  def transposed(*res1_cbar_bbar_res2):
    res1, c_bar, b_bar, res2 = split_list(
        res1_cbar_bbar_res2, [num_res1, num_c, num_b])
    # The c and a inputs are what we differentiate with respect to, so they
    # are passed to the backward pass as undefined primals.
    primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
               [ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
                                 primals, b_bar)
    _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
    a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
    # Accumulate this step's contribution onto the incoming c cotangents.
    c_bar = _map(ad.instantiate_zeros_aval, c_avals,
                 _map(ad.add_tangents, c_bar, new_c_bar))
    return c_bar + a_bar
  return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
def _scan_batching_rule(axis_size, axis_name, main_type, args, dims, reverse, length,
                        jaxpr, num_consts, num_carry, linear, unroll):
  """Batching rule for scan.

  Batches the body jaxpr and binds a single scan over the batched operands.
  Batched consts and carries are given their batch axis at position 0 and
  batched xs at position 1 (position 0 of xs is the axis being scanned over).

  Args:
    axis_size: size of the batch axis.
    axis_name, main_type: forwarded to ``batching.batch_jaxpr``.
    args: scan operands, laid out as ``[consts..., carry..., xs...]``.
    dims: batch axis of each operand, or ``batching.not_mapped``.
    reverse, length, jaxpr, num_consts, num_carry, linear, unroll: the params
      of the scan being batched.

  Returns:
    A pair ``(outs, out_dims)``: the outputs of the batched scan and the batch
    axis of each output (0 for carries, 1 for ys, or ``not_mapped``).
  """
  num_ys = len(jaxpr.out_avals) - num_carry
  orig_batched = [d is not batching.not_mapped for d in dims]
  const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry])

  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_batched.
  carry_batched = init_batched
  for _ in range(1 + len(carry_batched)):
    batched = const_batched + carry_batched + xs_batched
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, axis_size, batched,
        instantiate=carry_batched + [False] * num_ys,
        axis_name=axis_name,
        main_type=main_type)
    carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
    if carry_batched_out == carry_batched:
      break
    else:
      carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, consts_bdims)]
  # A carry that only became batched through the fixpoint has to be broadcast
  # along a new batch axis; one that was already batched is just moved.
  new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched
              else batching.moveaxis(x, d, 0) if now_batched else x
              for x, d, was_batched, now_batched in
              zip(init, init_bdims, init_batched, carry_batched)]
  new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1
            else x for x, d in zip(xs, xs_bdims)]
  new_args = new_consts + new_init + new_xs

  outs = scan_p.bind(
      *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
      num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll)
  carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
  ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
  return outs, carry_bdims + ys_bdims
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  """Wrap a scan body so that carry updates stop after a dynamic length.

  The returned closed jaxpr takes ``[dynamic_length, *consts, i, *carry, *xs]``
  and returns ``[i + 1, *carry, *ys]``. For each carry entry the body's new
  value is selected while ``i < dynamic_length`` and the incoming value
  otherwise; ``ys`` are returned as computed.
  """
  body = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    stepped_carry, ys = split_list(body(*consts, *carry, *xs), [num_carry])
    kept_carry = _map(
        lambda new_c, c: lax.select(i < dynamic_length, new_c, c),
        stepped_carry, carry)
    return [i + 1, *kept_carry, *ys]

  # Scalar integer aval shared by `dynamic_length` and the counter `i`.
  aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_))
  const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  in_avals = [aval] + const_avals + [aval] + carry_avals + x_avals
  return _make_closed_jaxpr(masked, in_avals)
def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
  """Padding rule for scan: rebind the scan with a padded body jaxpr.

  The body is rebuilt from ``pe.pad_jaxpr`` applied to the original body and
  its consts; every other param is forwarded unchanged. ``in_avals`` and
  ``out_avals`` are not used.
  """
  padded_parts = pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)
  padded_body = core.ClosedJaxpr(*padded_parts)
  return scan_p.bind(*args, jaxpr=padded_body, **params)
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
                   ) -> Tuple[List[bool], core.JaxprEqn]:
  """Dead-code-elimination rule for ``scan``.

  Args:
    used_outputs: one flag per output of ``eqn`` (carry outputs first), True
      where the output is used.
    eqn: the ``scan`` equation to prune.

  Returns:
    A pair ``(used_inputs, new_eqn)``: one flag per operand of ``eqn`` saying
    whether it is still needed, and a ``scan`` equation restricted to the used
    operands and outputs, with a pruned body jaxpr.
  """
  params = eqn.params
  body = params['jaxpr']
  n_consts, n_carry = params['num_consts'], params['num_carry']
  n_xs = len(body.in_avals) - n_consts - n_carry
  carry_live, ys_live = split_list(used_outputs, [n_carry])

  # A carry output has to be kept whenever the matching carry input is needed
  # by the pruned body, so iterate until the set of live carries stops growing.
  for _ in range(1 + n_carry):
    used_outputs = carry_live + ys_live
    keep_inputs = [False] * n_consts + carry_live + [False] * n_xs
    jaxpr_dce, used_inputs = pe.dce_jaxpr(
        body.jaxpr, used_outputs, instantiate=keep_inputs)
    used_consts, carry_in_live, _ = split_list(used_inputs, [n_consts, n_carry])
    if list(carry_in_live) == list(carry_live):
      break
    carry_live = _map(operator.or_, carry_live, carry_in_live)
  else:
    assert False, "Fixpoint not reached"

  # NOTE(review): this validates the original body jaxpr rather than
  # `jaxpr_dce` -- confirm that is intended.
  if config.jax_enable_checks:
    core.check_jaxpr(body.jaxpr)

  pruned_params = {
      **params,
      'num_consts': sum(used_consts),
      'num_carry': sum(carry_in_live),
      'linear': tuple(itertools.compress(params['linear'], used_inputs)),
      'jaxpr': core.ClosedJaxpr(jaxpr_dce, body.consts),
  }
  kept_invars = list(itertools.compress(eqn.invars, used_inputs))
  kept_outvars = list(itertools.compress(eqn.outvars, used_outputs))
  # TODO(mattjj,sharadmv): don't assume effects are never DCE'd?
  new_eqn = pe.new_jaxpr_eqn(
      kept_invars, kept_outvars, eqn.primitive, pruned_params, eqn.effects,
      eqn.source_info)
  assert len(new_eqn.invars) == len(pruned_params['jaxpr'].in_avals)
  assert len(new_eqn.outvars) == len(pruned_params['jaxpr'].out_avals)
  return used_inputs, new_eqn
# TODO(mattjj): de-duplicate code with _scan_partial_eval
def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Partial-eval rule for ``scan`` used by ``pe.partial_eval_jaxpr_custom``.

  Splits the scan equation ``eqn`` into a known equation and a staged equation
  according to which operands are unknown.

  Args:
    saveable: forwarded as the ``saveable`` argument of
      ``pe.partial_eval_jaxpr_custom`` when partially evaluating the body.
    unks_in: one flag per operand of ``eqn``, True where the operand is
      unknown.
    inst_in: one flag per operand of ``eqn``; operands that are ``core.Var``s
      with a False flag are reported back in the returned variable list.
    eqn: the ``scan`` equation to split.

  Returns:
    A tuple ``(eqn_known, eqn_staged, unks_out, inst_out, new_vars)``.
    ``eqn_known`` is a ``closed_call_p`` equation that evaluates the hoisted
    loop-invariant computation followed by a scan over the known loop body.
    ``eqn_staged`` is a scan equation that takes the intensive residuals, all
    of the original operands and the extensive residuals. ``unks_out`` and
    ``inst_out`` are the per-output flags produced by the last
    ``pe.partial_eval_jaxpr_custom`` call. ``new_vars`` holds the newly
    instantiated operands followed by the residual variables.
  """
  jaxpr = eqn.params['jaxpr']
  num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
  num_ys = len(jaxpr.out_avals) - num_carry

  # Fixpoint (trivial on 'inst_in', since we might as well make all inputs
  # available as DCE can subsequently prune any unused ones)
  const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
  for _ in range(1 + len(carry_uk)):
    unks_in = const_uk + carry_uk + xs_uk
    jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr.jaxpr, in_unknowns=unks_in, in_inst=True,
            ensure_out_unknowns=carry_uk + [False] * num_ys,
            ensure_out_inst=True, saveable=saveable)
    carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
    if carry_uk_out == carry_uk:
      break
    else:
      # A carry that comes out unknown must also go in unknown on the next try.
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)

  # Move all residual binders to the back of jaxpr_staged so they're extensive.
  # TODO(mattjj): make jaxpr_staged only take instantiated inputs
  res_avals = jaxpr_staged.in_avals[:num_res]
  jaxpr_staged = pe.move_binders_to_back(
      jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))

  # Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
  # passing in_inst argument to partial_eval_jaxpr_custom above).
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  inst_in = [True] * len(inst_in)

  # As an optimization, hoist loop-invariant residuals out of the loop rather
  # than using extensive outputs for them. See _scan_partial_eval for comments.
  num_const_known = len(const_uk) - sum(const_uk)
  num_carry_known = len(carry_uk) - sum(carry_uk)
  num_xs_known = len( xs_uk) - sum( xs_uk)
  # Known consts are marked known here, while known carry and xs are marked
  # unknown (loop-dependent) so only const-only computation gets hoisted.
  jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \
      pe.partial_eval_jaxpr_nounits(
          jaxpr_known,
          [False] * num_const_known + [True] * (num_carry_known + num_xs_known),
          [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res)
  # jaxpr_known_hoist produces intensive residuals followed by the constants for
  # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts.
  _, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res])
  jaxpr_staged = pe.move_binders_to_front(
      jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
  num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
  del loop_dep, num_carry_known, num_xs_known, const_uk

  # Create residual variables.
  intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
  # Extensive residuals gain a leading axis whose size is the scan length.
  ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
               for a in ext_avals_mapped]
  newvar = core.gensym()
  intensive_res = _map(newvar, intensive_avals)
  extensive_res = _map(newvar, ext_avals)

  # Create known eqn, which is a call_p combining evaluation of
  # jaxpr_known_hoist and a scan of jaxpr_known_loop.
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  # jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist
  # (corresponding to consts_known_lp_avals) followed by known carry and xs.
  linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
  _, linear_known_ = split_list(linear_known_, [num_const_known])
  linear_known = [False] * len(consts_known_lp_avals) + linear_known_
  params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
                      num_consts=len(consts_known_lp_avals),
                      num_carry=len(carry_uk)-sum(carry_uk),
                      linear=tuple(linear_known))

  @lu.wrap_init
  def known(*ins_known):
    # Evaluate the hoisted part once, then scan the loop part using its outputs
    # as constants.
    consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known])
    out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist)
    intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
    out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
    return [*intensive_res, *out_loop]
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      known, [v.aval for v in ins_known])
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*intensive_res, *out_binders_known, *extensive_res],
      core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
      eqn.source_info)

  # Create the staged eqn.
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  # Residual operands are never marked linear.
  linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
                   [False] * len(extensive_res))
  params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
                       num_consts=len(intensive_res) + eqn.params['num_consts'],
                       linear=tuple(linear_staged))
  eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res],
                                out_binders_staged, eqn.primitive,
                                params_staged, jaxpr_staged.effects,
                                eqn.source_info)

  new_vars = [*new_inst, *intensive_res, *extensive_res]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry,
                    jaxpr, linear, unroll):
  """Type-check the parameters and operands of a ``scan`` equation.

  Validates each parameter's Python type and range, that ``linear`` has one
  entry per operand, that the body's carry input and output types match, and
  that the operand types are compatible with the body's input types.
  ``bind_time`` is accepted but not consulted.

  Returns:
    A pair of the output avals (the initial-carry avals followed by the body's
    non-carry output avals with a leading axis of size ``length`` added) and
    the effects of the body jaxpr.

  Raises:
    core.JaxprTypeError: if the ``linear`` length or any aval check fails.
  """
  avals = [atom.aval for atom in in_atoms]

  # Each condition short-circuits on the type test, so evaluating all of them
  # before the first check runs is safe.
  param_checks = (
      (reverse, 'reverse', 'bool', type(reverse) is bool),
      (num_consts, 'num_consts', 'non-negative int',
       type(num_consts) is int and num_consts >= 0),
      (num_carry, 'num_carry', 'non-negative int',
       type(num_carry) is int and num_carry >= 0),
      (jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr),
      (linear, 'linear', 'tuple of bool',
       type(linear) is tuple and all(type(x) is bool for x in linear)),
      (unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0),
      (length, 'length', 'non-negative int',
       type(length) is int and length >= 0),
  )
  for value, name, expected, ok in param_checks:
    _typecheck_param('scan', value, name, expected, ok)

  if len(linear) != len(avals):
    raise core.JaxprTypeError(
        f'scan param linear has length {len(linear)} for {len(avals)} operands')

  # Operand avals, grouped the same way as the body jaxpr's inputs.
  const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
  const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
  # The body sees one slice of each sequence operand and emits one slice of
  # each sequence output.
  x_avals_mapped = [core.mapped_aval(length, 0, a) for a in x_avals]
  y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a)
             for a in y_avals_mapped]

  carry_types_agree = all(
      _map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr))
  if not carry_types_agree:
    raise core.JaxprTypeError(
        f'scan input carry input and output types mismatch: '
        f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')

  consts_ok = all(_map(core.typecompat, const_avals_jaxpr, const_avals))
  if not consts_ok:
    raise core.JaxprTypeError(
        f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
        f'called with consts of type\n{_avals_short(const_avals)}')

  init_ok = all(_map(core.typecompat, init_avals_jaxpr, init_avals))
  if not init_ok:
    raise core.JaxprTypeError(
        f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
        f'called with initial carry of type\n{_avals_short(init_avals)}')

  xs_ok = all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped))
  if not xs_ok:
    raise core.JaxprTypeError(
        f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
        f'called with sequence of type\n{_avals_short(x_avals)}')

  return [*init_avals, *y_avals], jaxpr.effects
def _scan_pp_rule(eqn, context, settings):
  """Pretty-print a ``scan`` equation, hiding uninformative parameters.

  ``linear`` is always hidden. ``length`` is hidden when the equation has no
  operands beyond its consts and carry. ``unroll``, ``num_carry``,
  ``num_consts`` and ``reverse`` are hidden when they equal 1, 0, 0 and a
  falsy value respectively.
  """
  params = eqn.params
  shown = dict(params)
  del shown['linear']
  if len(eqn.invars) == params['num_consts'] + params['num_carry']:
    del shown['length']
  for key, hidden_value in (('unroll', 1), ('num_carry', 0), ('num_consts', 0)):
    if shown[key] == hidden_value:
      del shown[key]
  if not shown['reverse']:
    del shown['reverse']

  lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  name_doc = pp.text(eqn.primitive.name)
  params_doc = core.pp_kv_pairs(sorted(shown.items()), context, settings)
  operands_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  if settings.source_info:
    annotation = source_info_util.summarize(eqn.source_info)
  else:
    annotation = None
  equals_doc = pp.text(" = ", annotation=annotation)
  return [lhs, equals_doc, name_doc, params_doc, operands_doc]
def scan_bind(*args, **params):
  """Bind ``scan_p``, type-checking the call first when checks are enabled.

  With ``config.jax_enable_checks`` set, the operands are replaced by dummy
  variables carrying their avals, run through ``_scan_typecheck``, and the
  body jaxpr is validated with ``core.check_jaxpr`` before binding.
  """
  if config.jax_enable_checks:
    # Placeholder variables: only their avals matter to the type checker.
    dummy_atoms = [core.Var(0, '', aval) for aval in _map(core.get_aval, args)]
    _scan_typecheck(True, *dummy_atoms, **params)
    core.check_jaxpr(params['jaxpr'].jaxpr)
  return core.AxisPrimitive.bind(scan_p, *args, **params)
# The `scan` primitive and the rules registered for it.
scan_p = core.AxisPrimitive("scan")
scan_p.multiple_results = True
# Binding goes through scan_bind, which type-checks when checks are enabled.
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
# Automatic differentiation and partial evaluation.
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
# Lowering: scan is lowered by lowering the function `_scan_impl`.
xla.register_initial_style_primitive(scan_p)
mlir.register_lowering(scan_p,
                       mlir.lower_fun(_scan_impl, multiple_results=True))
# Batching, jaxpr type checking, and the jaxpr-to-jaxpr rules.
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
# The jaxpr type checker calls _scan_typecheck with bind_time=False.
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
### while_loop
@api_boundary
def while_loop(cond_fun: Callable[[T], BooleanNumeric],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The `Haskell-like type signature`_ in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  .. note::
    :py:func:`while_loop` compiles ``cond_fun`` and ``body_fun``, so while it
    can be combined with :py:func:`jit`, it's usually unnecessary.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.

  Raises:
    TypeError: if ``cond_fun`` or ``body_fun`` is not callable, or if
      ``cond_fun`` does not return a single boolean scalar.
    NotImplementedError: if ``cond_fun`` or ``body_fun`` has effects that are
      not among the allowed effects.

  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
  """
  if not (callable(body_fun) and callable(cond_fun)):
    raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
  if config.jax_disable_jit:
    # With jit disabled, first try to run the loop directly in Python.
    try:
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
      # transformation on it), so we fall back to the primitive version.
      pass

  def _create_jaxpr(init_val):
    # Trace cond_fun and body_fun against the avals of the flattened init_val
    # and validate that cond_fun yields exactly one boolean scalar.
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals, "while_cond")
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals, "while_loop")
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
      msg = "cond_fun must return a boolean scalar, but got pytree {}."
      raise TypeError(msg.format(cond_tree))
    pred_aval = cond_jaxpr.out_avals[0]
    if (not isinstance(pred_aval, ShapedArray)
        or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

  # The body input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
  if changed:
    new_init_val, = tree_unflatten(in_tree, new_init_vals)
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
  cond_jaxpr, cond_consts, body_consts, body_tree = rest

  # in_tree describes the 1-tuple `(init_val,)`, so it has exactly one child.
  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  disallowed_effects = effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `while`: {disallowed_effects}')
  # Operand order: cond consts, then body consts, then the flattened carry.
  outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
  """Abstract evaluation for ``while``.

  Returns the body's output avals raised to shaped avals, together with the
  joined effects of ``cond_jaxpr`` and ``body_jaxpr``.

  Raises:
    NotImplementedError: if any joined effect is not in ``allowed_effects``.
  """
  del args, kwargs
  joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  disallowed_effects = joined_effects - allowed_effects
  if not disallowed_effects:
    out_avals = [raise_to_shaped(aval) for aval in body_jaxpr.out_avals]
    return out_avals, joined_effects
  raise NotImplementedError(
      f'Effects not supported in `while`: {disallowed_effects}')
def _while_loop_batching_rule(axis_size, axis_name, main_type, args, dims,
                              cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching rule for ``while``.

  Args:
    axis_size: size of the batch axis, used when broadcasting unbatched carry
      values.
    axis_name: forwarded to the ``batching.batch_jaxpr*`` helpers.
    main_type: forwarded to the ``batching.batch_jaxpr*`` helpers.
    args: operands of the ``while``: cond consts, body consts, then the
      initial carry.
    dims: batch dimension of each operand, or ``batching.not_mapped``.
    cond_nconsts: number of leading operands that are cond consts.
    cond_jaxpr: closed jaxpr of the loop predicate.
    body_nconsts: number of operands after the cond consts that are body
      consts.
    body_jaxpr: closed jaxpr of the loop body.

  Returns:
    A pair ``(outs, carry_dims)``: the outputs of the batched ``while`` and
    the batch dimension of each output.
  """
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
  cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts])
  cconst_dims, bconst_dims, init_dims = split_list(dims, [cond_nconsts, body_nconsts])

  carry_bat = init_bat
  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations to
  # reach a fixpoint.
  for _ in range(1 + len(carry_bat)):
    _, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat,
        axis_name=axis_name, main_type=main_type)
    if carry_bat == carry_bat_out:
      break
    carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  # Knowing how the carry is batched now, we can determine if the predicate is
  # batched.
  _, (pred_bat,) = batching.batch_jaxpr(
      cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False,
      axis_name=axis_name, main_type=main_type)

  if pred_bat:
    # If the predicate is batched, we have to batch *all* of the carry
    # regardless of if the body needs it.
    carry_bat = [True] * len(carry_bat)
    carry_dims = [0] * len(carry_bat)
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims,
        carry_dims, axis_name=axis_name, main_type=main_type)
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, [0],
        axis_name=axis_name, main_type=main_type)
  else:
    # If the predicate is not batched, we can look at the `cond_jaxpr`'s out
    # shape to determine the rank of the predicate. From this rank we pick the
    # dims of the carry to be batched to ensure that the predicate shape is a
    # prefix of the carry in and out shapes. We can then batch the `body_jaxpr`
    # according to these new batch dims.
    cond_rank = len(cond_jaxpr.out_avals[0].shape)
    carry_dims = [cond_rank if b else None for b in carry_bat]
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims,
        axis_name=axis_name, main_type=main_type)
    # Now we need to rebatch the `cond_jaxpr` according to the new dims of the
    # carry.
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,),
        axis_name=axis_name, main_type=main_type)

  # To prepare the `init` to the `while_p`, we broadcast values if they are
  # unbatched and need to have an out axis. If their current batch axis does not
  # match the one it needs to be for the translation rule to work, we move it
  # into place.
  new_init = []
  for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
    if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
      new_init.append(batching.broadcast(x, axis_size, new_axis))
    elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
      new_init.append(x)
    else:
      # A carry that arrived batched is always batched on the way out.
      assert new_axis is not batching.not_mapped
      new_init.append(batching.moveaxis(x, old_axis, new_axis))

  # The consts are passed through as-is; only the carry was re-laid-out above.
  outs = while_p.bind(*(cconsts + bconsts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  return outs, carry_dims
def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  """JVP rule for ``while``.

  Binds a single ``while`` whose carry holds the primal carry followed by the
  nonzero carry tangents. The predicate jaxpr is given extra input binders for
  those tangents, which none of its equations read.

  Returns:
    A pair ``(out_carry, out_tangents)``; tangents of carry entries found to be
    zero are returned as symbolic ``ad_util.Zero`` values.
  """
  tangent_nz = [type(t) is not ad_util.Zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(
      tangent_nz, [cond_nconsts, body_nconsts])

  # Fixpoint over which carry tangents are nonzero: a tangent that comes out of
  # the body nonzero has to go in nonzero as well.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  # Materialize symbolic zeros wherever the fixpoint needs a real tangent.
  needs_value = cconst_nz + body_nonzeros
  dense_tangents = [ad.instantiate_zeros(t) if nz else t
                    for t, nz in zip(tangents, needs_value)]

  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(
      dense_tangents, [cond_nconsts, body_nconsts])
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts
  n_bconst_dot, n_init_dot = len(bconst_dot), len(init_dot)

  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [n_bconst_dot, n_init_dot],
      [num_carry], [n_init_dot])

  # Extend the predicate's inputs with one fresh binder per carry tangent.
  newvar = core.gensym([cond_jaxpr.jaxpr])
  tangent_binders = [newvar(core.get_aval(x)) for x in init_dot]
  cond_inner = cond_jaxpr.jaxpr
  cond_jaxpr_augmented = core.ClosedJaxpr(
      core.Jaxpr(cond_inner.constvars,
                 cond_inner.invars + tangent_binders,
                 cond_inner.outvars,
                 cond_inner.eqns,
                 cond_inner.effects),
      cond_jaxpr.consts)

  out = while_p.bind(
      *cconst, *bconst, *bconst_dot, *init, *init_dot,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + n_bconst_dot,
      body_jaxpr=body_jvp_rearranged)

  out_carry, out_carry_dot = split_list(out, [num_carry])
  # Re-expand the compacted tangent outputs, filling gaps with symbolic zeros.
  remaining_dots = iter(out_carry_dot)
  out_tangents = []
  for primal_out, nz in zip(out_carry, nonzeros_out):
    if nz:
      out_tangents.append(next(remaining_dots))
    else:
      out_tangents.append(ad_util.Zero.from_value(primal_out))
  return out_carry, out_tangents
def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
  """Partial-evaluation rule for ``while`` on a ``JaxprTrace``.

  Args:
    trace: the partial-eval trace processing the primitive.
    *tracers: operand tracers: cond consts, body consts, then the carry.
    cond_nconsts: number of leading operands that are cond consts.
    cond_jaxpr: closed jaxpr of the loop predicate.
    body_nconsts: number of operands after the cond consts that are body
      consts.
    body_jaxpr: closed jaxpr of the loop body.

  Returns:
    One output per carry entry: known values from a ``while`` over the known
    part of the loop where the carry is known, and tracers from the default
    (fully staged) processing of the whole ``while`` elsewhere. When the
    predicate is unknown, or the operands are all known or all unknown, the
    result of the default processing is returned directly.
  """
  # As long as some carry (and hence output) are known and the output of
  # `cond_jaxpr` is known, we use a portion of the loop body to compute the
  # known outputs of the `while_loop`. For the unknown outputs we generate a
  # jaxpr to run the whole while, including recomputing the known parts,
  # basically like building in checkpointing/rematerialization. This means that
  # we don't actually save any computation by partial evaluation if there are
  # unknown outputs.
  #
  # What this achieves is twofold: jax.linearize works, and we can give a proper
  # error for reverse differentiation of `while`.
  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  cond_consts_uk, body_consts_uk, carry_init_uk = \
      split_list(unknowns, [cond_nconsts, body_nconsts])
  # Fixpoint computation of unknown carry. Each iteration promotes at least one
  # carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk, _ = pe.partial_eval_jaxpr_nounits(  # type: ignore
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits(  # type: ignore
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)

  if cond_uk[0] or not any(unknowns) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while.
  in_consts = [t.pval.get_known() for uk, t in
               zip(cond_consts_uk + body_consts_uk + carry_uk, tracers)
               if not uk]
  cond_nconsts_known = len(cond_consts_uk) - sum(cond_consts_uk)
  body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk)
  num_known_outs = len(carry_uk) - sum(carry_uk)
  # TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs
  # Drop the residual outputs by building a fresh jaxpr rather than assigning
  # to `outvars` in place: the object returned by the partial evaluator may be
  # shared with other users, and mutating it would change it for them too.
  known_body = body_jaxpr_known.jaxpr
  body_jaxpr_known = core.ClosedJaxpr(
      core.Jaxpr(known_body.constvars,
                 known_body.invars,
                 known_body.outvars[:num_known_outs],
                 known_body.eqns,
                 known_body.effects),
      body_jaxpr_known.consts)
  out_known = while_p.bind(
      *in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known,
      body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known)
  del body_jaxpr_known, known_body

  # Run the whole while_loop to get all the outputs, then merge with known ones
  out_tracers_ = trace.default_process_primitive(while_p, tracers, params)
  out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk]
  return util.merge_lists(carry_uk, out_known, out_tracers)
# TODO(mattjj): de-duplicate code with _while_partial_eval
def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Partial-eval rule for ``while`` used by ``pe.partial_eval_jaxpr_custom``.

  Args:
    saveable: ignored; the body is always partially evaluated with
      ``ad_checkpoint.nothing_saveable``.
    unks_in: one flag per operand of ``eqn`` (cond consts, body consts,
      carry), True where the operand is unknown.
    inst_in: one flag per operand of ``eqn``; operands that are ``core.Var``s
      with a False flag are returned as newly instantiated variables.
    eqn: the ``while`` equation to split.

  Returns:
    A tuple ``(eqn_known, eqn_staged, unks_out, inst_out, new_inst)``.
    ``eqn_known`` is a ``while`` equation over the known operands and known
    carry outputs, ``eqn_staged`` is ``eqn`` itself, ``unks_out`` flags the
    unknown carry outputs and ``inst_out`` is all True. If the predicate turns
    out to be unknown, ``eqn_known`` is None and every output is reported as
    unknown.
  """
  del saveable  # We can't save any residuals anyway (w/o dynamic shapes)!
  cond_jaxpr = eqn.params['cond_jaxpr']
  cond_nconsts = eqn.params['cond_nconsts']
  body_jaxpr = eqn.params['body_jaxpr']
  body_nconsts = eqn.params['body_nconsts']

  cond_consts_uk, body_consts_uk, carry_init_uk = \
      split_list(unks_in, [cond_nconsts, body_nconsts])

  # Fixpoint to compute known part of the body (trivial on 'inst_in', since we
  # make all inputs available as DCE can subsequently prune any unused ones)
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_unks_in = body_consts_uk + carry_uk
    jaxpr_known_, _, carry_uk_out, _, num_res = \
        pe.partial_eval_jaxpr_custom(
            body_jaxpr.jaxpr, in_unknowns=body_unks_in, in_inst=True,
            ensure_out_unknowns=carry_uk, ensure_out_inst=True,
            saveable=ad_checkpoint.nothing_saveable)
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  # The known body must not produce residuals.
  assert not num_res
  body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts)
  del jaxpr_known_, carry_uk_out, num_res

  # Instantiate all inputs (b/c jaxpr_staged will take all inputs).
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]

  # Compute the known part of cond_fun (basically pruning inputs on known side).
  cond_unks_in = cond_consts_uk + carry_uk
  cond_jaxpr_known_, _, [cond_uk], _, _ = \
      pe.partial_eval_jaxpr_custom(
          cond_jaxpr.jaxpr, cond_unks_in, in_inst=True,
          ensure_out_unknowns=False, ensure_out_inst=True,
          saveable=ad_checkpoint.nothing_saveable)
  # NOTE(mattjj): I think it should be impossible for the condition to be
  # unknown, but asserting that caused a test failure in diffrax. So
  # we handle it: if it is unknown, stage out the whole cond function.
  if cond_uk:
    return None, eqn, [True] * len(carry_uk), [True] * len(carry_uk), new_inst
  cond_jaxpr_known = core.ClosedJaxpr(cond_jaxpr_known_, cond_jaxpr.consts)
  del cond_uk

  # Build the known eqn.
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(carry_uk, eqn.outvars)
  params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known,
                      cond_nconsts=len(cond_consts_uk) - sum(cond_consts_uk),
                      body_nconsts=len(body_consts_uk) - sum(body_consts_uk))
  effects_known = core.join_effects(cond_jaxpr_known.effects,
                                    body_jaxpr_known.effects)
  eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p,
                               params_known, effects_known, eqn.source_info)

  # Staged eqn is same as input eqn.
  eqn_staged = eqn

  unks_out = carry_uk
  inst_out = [True] * len(unks_out)
  return eqn_known, eqn_staged, unks_out, inst_out, new_inst
def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")
# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
# x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
# token, pred = cond(token, x)
# if not pred:
# break
# token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while we can neither (1) return multiple values
# from a `cond` nor (2) break out of a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
# return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
# token, x = body(token, x)
# token, pred = cond(token, x)
# ```
def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
                    body_nconsts):
  """Lower the ``while`` primitive to an ``mhlo.WhileOp``.

  If ``cond_jaxpr`` has ordered effects, the loop is first rewritten so that
  the predicate is carried as loop state and is computed once before the loop
  and again at the end of each body iteration; the rewritten function is then
  lowered with ``mlir.lower_fun``.

  Otherwise a single ``mhlo.WhileOp`` is emitted. Its loop state is one token
  per ordered effect of the body followed by all operands (cond consts, body
  consts, carry). When the predicate is not a scalar, the loop condition is
  the OR-reduction of the predicate over all of its axes, and the body keeps
  the previous carry value wherever the predicate is False.

  Returns:
    The IR values of the final carry.
  """
  pred_aval = cond_jaxpr.out_avals[0]
  # A predicate with a non-empty shape is treated as batched.
  batched = bool(pred_aval.shape)
  cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
                          core.ordered_effects]
  if cond_ordered_effects:
    def cond(args):
      return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
    def body(args):
      return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
    def new_cond(pred_args):
      # The predicate was already computed; just read it from the loop state.
      pred, _ = pred_args
      return pred
    def new_body(pred_args):
      # Run the body, then evaluate the predicate for the next iteration.
      _, args = pred_args
      args = body(args)
      pred = cond(args)
      return pred, args
    def fun(*args):
      pred = cond(args)
      _, out = while_loop(new_cond, new_body, (pred, args))
      return out
    return mlir.lower_fun(fun)(ctx, *args)

  loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  body_effects = [eff for eff in body_jaxpr.effects
                  if eff in core.ordered_effects]
  num_tokens = len(body_effects)
  tokens = [ctx.tokens_in.get(eff) for eff in body_effects]
  token_types = [mlir.token_type() for _ in tokens]
  # Tokens are threaded through the loop as extra leading loop-carried values.
  loop_carry_types = [*token_types, *loop_carry_types]
  flat_loop_carry_types = util.flatten(loop_carry_types)
  args = [*tokens, *args]

  flat_args = mlir.flatten_lowering_ir_args(args)
  while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)

  # Loop condition
  cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
  with ir.InsertionPoint(cond_block):
    flat_cond_args = [
        cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
    # Remove tokens from cond args
    cond_args = cond_args[num_tokens:]
    # x: cond consts, z: carry; the body consts are not used by the predicate.
    x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
    cond_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(name_stack, 'cond'))
    ((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
                                       _map(mlir.ir_constants, cond_jaxpr.consts),
                                       *(x + z))
    if batched:
      # Reduce the batched predicate to a scalar with a logical OR over all of
      # its axes.
      pred_ctx = mlir.LoweringRuleContext(
          module_context=ctx.module_context,
          primitive=None,
          avals_in=[pred_aval],
          avals_out=[pred_aval.update(shape=())],
          tokens_in=mlir.TokenSet(),
          tokens_out=None)
      pred, = lax._unary_reduce_lower(
          mhlo.OrOp,
          lambda dtype: np.array(False, dtype),
          pred_ctx,
          pred,
          axes=tuple(range(len(pred_aval.shape))))
    mhlo.ReturnOp([pred])

  # Loop body
  body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        body_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
    # Tokens are at the front of the args list to the while loop
    token_args, body_args = util.split_list(body_args, [num_tokens])
    tokens_in = mlir.TokenSet(zip(body_effects, token_args))
    # x: cond consts, y: body consts, z: carry.
    x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
    body_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(name_stack, 'body'))
    new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
        tokens_in, _map(mlir.ir_constants, body_jaxpr.consts), *(y + z))
    out_tokens = [tokens_out.get(eff) for eff in body_effects]
    if batched:
      # Re-evaluate the unreduced predicate on the incoming carry and use it to
      # choose, per element, between the updated and the previous carry.
      body_pred_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack,
                                           'body_pred'))
      ((body_pred,),), _ = mlir.jaxpr_subcomp(
          body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
          _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
      new_z = _map(
          partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
          body_jaxpr.out_avals)

    # The consts are returned unchanged so they stay part of the loop state.
    mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
                   *util.flatten(y), *util.flatten(new_z)])

  outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
  tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
  if tokens:
    ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
  return z
def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
                     body_nconsts):
  """Typecheck rule for the ``while`` primitive.

  Joins the effects of the cond and body jaxprs and rejects any effect that
  is not in ``allowed_effects``.

  Returns:
    A pair ``(out_avals, effects)``: the body jaxpr's output avals and the
    joined effects of both jaxprs.

  Raises:
    NotImplementedError: if the joined effects contain an unsupported effect.
  """
  # TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
  effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  unsupported = effects - allowed_effects
  if unsupported:
    raise NotImplementedError(
        f'Effects not supported in `while`: {unsupported}')
  return body_jaxpr.out_avals, effects
# Define the `while` primitive and register its rules with each JAX subsystem.
while_p = core.AxisPrimitive('while')
# The primitive produces a sequence of outputs rather than a single value.
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
# Forward-mode autodiff rule.
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
# Transposition is delegated to `_while_transpose_error`.
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
# MLIR lowering and the jaxpr typecheck rule.
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
def _pred_bcast_select_mhlo(
    pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
    ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
  """Emits MHLO that selects between ``xs`` and ``ys`` according to ``pred``.

  Args:
    pred_aval: abstract value of ``pred``; its shape must be a prefix of the
      shape of ``x_y_aval``.
    pred: the predicate IR value.
    xs: single-element sequence holding the value chosen where ``pred`` holds.
    ys: single-element sequence holding the value chosen otherwise.
    x_y_aval: abstract value shared by the elements of ``xs`` and ``ys``.

  Returns:
    The resulting IR values. For tokens the two inputs are joined with an
    ``AfterAllOp`` instead of being selected between.
  """
  if x_y_aval is core.abstract_token:
    # Tokens cannot be selected; join both so the result depends on each.
    lhs_token, = xs
    rhs_token, = ys
    token_type = mlir.aval_to_ir_type(x_y_aval)
    return [mhlo.AfterAllOp(token_type, [lhs_token, rhs_token]).result]

  assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
  on_true, = xs
  on_false, = ys
  assert on_true.type == on_false.type, (on_true.type, on_false.type)
  pred_rank = len(pred_aval.shape)
  assert (pred_aval.shape == x_y_aval.shape[:pred_rank]), (
      pred_aval.shape, x_y_aval)

  # Broadcast the predicate over the trailing dimensions of the operands; the
  # predicate's own dimensions map onto the leading operand dimensions.
  operand_type = mlir.aval_to_ir_type(x_y_aval)
  bool_type = mlir.dtype_to_ir_type(np.dtype(np.bool_))
  full_pred_type = ir.RankedTensorType.get(operand_type.shape, bool_type)
  leading_dims = mlir.dense_int_elements(list(range(pred_rank)))
  full_pred = mhlo.BroadcastInDimOp(full_pred_type, pred, leading_dims).result
  return mhlo.SelectOp(full_pred, on_true, on_false).results
### fori_loop
def _fori_cond_fun(loop_carry):
  """while_loop predicate for fori_loop: continue while ``i < upper``.

  ``loop_carry`` is the triple ``(i, upper, value)``; the value is ignored.
  """
  index, stop, _unused = loop_carry
  keep_going = lax.lt(index, stop)
  return keep_going
@weakref_lru_cache
def _fori_body_fun(body_fun):
  """Adapts a fori_loop ``body_fun(i, x)`` into a while_loop body.

  The returned function maps ``(i, upper, x)`` to
  ``(i + 1, upper, body_fun(i, x))``. Only a weak reference to ``body_fun``
  is captured by the returned closure.
  """
  body_ref = weakref.ref(body_fun)

  def while_body_fun(loop_carry):
    index, stop, value = loop_carry
    next_index = lax.add(index, lax._const(index, 1))
    return next_index, stop, body_ref()(index, value)

  return while_body_fun
@weakref_lru_cache
def _fori_scan_body_fun(body_fun):
  """Adapts a fori_loop ``body_fun(i, x)`` into a scan body.

  The returned function maps carry ``(i, x)`` (and an ignored per-step input)
  to ``((i + 1, body_fun(i, x)), None)``. Only a weak reference to
  ``body_fun`` is captured by the returned closure.
  """
  body_ref = weakref.ref(body_fun)

  def scanned_fun(loop_carry, _):
    index, value = loop_carry
    new_carry = (index + 1, body_ref()(index, value))
    return new_carry, None

  return scanned_fun
@api_boundary
def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.

  The `Haskell-like type signature`_ in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  Unlike that Python version, ``fori_loop`` is implemented in terms of either a
  call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the
  trip count is static (meaning known at tracing time, perhaps because ``lower``
  and ``upper`` are Python integer literals) then the ``fori_loop`` is
  implemented in terms of ``scan`` and reverse-mode autodiff is supported;
  otherwise, a ``while_loop`` is used and reverse-mode autodiff is not
  supported.  See those functions' docstrings for more information.

  As with ``range(lower, upper)``, the body is never executed when
  ``upper <= lower``.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  .. note::
    :py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with
    :py:func:`jit`, it's usually unnecessary.

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.

  Raises:
    TypeError: if ``body_fun`` is not callable, or if ``lower`` and ``upper``
      have different (canonicalized) dtypes.

  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
  """
  if not callable(body_fun):
    raise TypeError("lax.fori_loop: body_fun argument should be callable.")

  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
  upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
  if lower_dtype != upper_dtype:
    msg = ("lower and upper arguments to fori_loop must have equal types, "
           "got {} and {}")
    raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))

  # If we can specialize on the trip count, call scan instead of a while_loop
  # to enable efficient reverse-mode differentiation.
  if (isinstance(core.get_aval(lower), ConcreteArray) and
      isinstance(core.get_aval(upper), ConcreteArray)):
    try:
      lower_ = int(lower)
      upper_ = int(upper)
    except TypeError:
      use_scan = False
    else:
      use_scan = True
  else:
    use_scan = False

  if use_scan:
    # An inverted range (upper < lower) means zero iterations, exactly like
    # Python's ``range``; clamp so scan is never given a negative length.
    trip_count = max(upper_ - lower_, 0)
    if config.jax_disable_jit and trip_count == 0:
      # non-jit implementation of scan does not support length=0
      return init_val

    (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
                          None, length=trip_count)
  else:
    _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
                              (lower, upper, init_val))
  return result
### map and miscellaneous rules
@api_boundary
def map(f, xs):
  """Map a function over leading array axes.

  Like Python's builtin map, except inputs and outputs are in the form of
  stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
  need to apply a function element by element for reduced memory usage or
  heterogeneous computation with other control flow primitives.

  When ``xs`` is an array type, the semantics of ``map`` are given by this
  Python implementation::

    def map(f, xs):
      return np.stack([f(x) for x in xs])

  Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
  the same advantages over a Python loop apply: ``xs`` may be an arbitrary
  nested pytree type, and the mapped computation is compiled only once.

  Args:
    f: a Python function to apply element-wise over the first axis or axes of
      ``xs``.
    xs: values over which to map along the leading axis.

  Returns:
    Mapped values.
  """
  def scan_body(_, x):
    # No state is threaded between steps: the carry is always the empty tuple.
    return (), f(x)

  _, stacked_outputs = scan(scan_body, (), xs)
  return stacked_outputs
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
  """Batching rule for ``rng_bit_generator``: run it once per key and stack.

  An unbatched key is passed straight to the primitive. A batched key has its
  batch axis moved to the front and is mapped over with ``map``, so both
  outputs carry their batch dimension at axis 0.
  """
  key, = batched_args
  bd, = batch_dims
  generate = partial(lax.rng_bit_generator_p.bind, shape=shape, dtype=dtype,
                     algorithm=algorithm)
  if bd is batching.not_mapped:
    return generate(key), (None, None)
  leading_keys = batching.moveaxis(key, bd, 0)
  new_keys, random_bits = map(generate, leading_keys)
  return (new_keys, random_bits), (0, 0)

batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule  # type: ignore
### associative_scan
@api_boundary
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
  """Performs a scan with an associative binary operation, in parallel.

  For an introduction to associative scans, see [BLE1990]_.

  Args:
    fn: A Python callable implementing an associative binary operation with
      signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
      must satisfy the equation
      ``fn(a, fn(b, c)) == fn(fn(a, b), c)``.

      The inputs and result are (possibly nested Python tree structures of)
      array(s) matching ``elems``. Each array has a dimension in place
      of the ``axis`` dimension. `fn` should be applied elementwise over
      the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
      elementwise function.)

      The result ``r`` has the same shape (and structure) as the two inputs
      ``a`` and ``b``.
    elems: A (possibly nested Python tree structure of) array(s), each with
      an ``axis`` dimension of size ``num_elems``.
    reverse: A boolean stating if the scan should be reversed with respect to
      the ``axis`` dimension.
    axis: an integer identifying the axis over which the scan should occur.
      It is canonicalized against the rank of the first array in ``elems``
      before use, including for the ``reverse`` flip.

  Returns:
    A (possibly nested Python tree structure of) array(s) of the same shape
    and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
    result of recursively applying ``fn`` to combine the first ``k`` elements
    of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
    the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.

  Raises:
    TypeError: if ``fn`` is not callable.
    ValueError: if the arrays in ``elems`` differ in size along ``axis``.

  Example 1: partial sums of an array of numbers:

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
  DeviceArray([0, 1, 3, 6], dtype=int32)

  Example 2: partial products of an array of matrices

  >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
  >>> partial_prods = lax.associative_scan(jnp.matmul, mats)
  >>> partial_prods.shape
  (4, 2, 2)

  Example 3: reversed partial sums of an array of numbers

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
  DeviceArray([6, 6, 5, 3], dtype=int32)

  .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
    Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
    University.
  """
  if not callable(fn):
    raise TypeError("lax.associative_scan: fn argument should be callable.")
  elems_flat, tree = tree_flatten(elems)

  # Canonicalize `axis` before its first use. The `reverse` flip below hands
  # the axis to `lax.rev`, which expects a non-negative dimension index, so a
  # negative `axis` must be normalized first.
  axis = util.canonicalize_axis(axis, elems_flat[0].ndim)

  if reverse:
    elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]

  def combine(a_flat, b_flat):
    # Lower `fn` to operate on flattened sequences of elems.
    a = tree_unflatten(tree, a_flat)
    b = tree_unflatten(tree, b_flat)
    c = fn(a, b)
    c_flat, _ = tree_flatten(c)
    return c_flat

  # Check that all inputs have a consistent size `num_elems` along `axis`.
  num_elems = int(elems_flat[0].shape[axis])
  if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
    raise ValueError('Array inputs to associative_scan must have the same '
                     'first dimension. (saw: {})'
                     .format([elem.shape for elem in elems_flat]))

  # Summary of algorithm:
  #
  # Consider elements of `_scan(elems)` at odd indices. That's the same as first
  # summing successive pairs of elements of `elems` and performing a scan on
  # that half sized tensor. We perform the latter scan by recursion.
  #
  # Now consider the even elements of `_scan(elems)`. These can be computed
  # from the odd elements of `_scan(elems)` by adding each odd element of
  # `_scan(elems)` to the matching even element in the original `elems`.
  #
  # We return the odd and even elements interleaved.
  #
  # For the base case of the recursion we return the first element
  # of `elems` followed by the sum of the first two elements computed as
  # a (small two-down-to-one) reduction step.
  def _scan(elems):
    """Perform scan on `elems`."""

    num_elems = elems[0].shape[axis]

    if num_elems < 2:
      return elems

    # Combine adjacent pairs of elements.
    reduced_elems = combine(
        [slicing.slice_in_dim(elem, 0, -1, stride=2, axis=axis)
         for elem in elems],
        [slicing.slice_in_dim(elem, 1, None, stride=2, axis=axis)
         for elem in elems])

    # Recursively compute scan for partially reduced tensors.
    odd_elems = _scan(reduced_elems)

    if num_elems % 2 == 0:
      even_elems = combine(
          [slicing.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
          [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis)
           for e in elems])
    else:
      even_elems = combine(
          odd_elems,
          [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis)
           for e in elems])

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [
        lax.concatenate([slicing.slice_in_dim(elem, 0, 1, axis=axis), result],
                        dimension=axis)
        for (elem, result) in zip(elems, even_elems)]
    return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))

  scans = _scan(elems_flat)

  if reverse:
    scans = [lax.rev(scanned, [axis]) for scanned in scans]

  return tree_unflatten(tree, scans)
def _interleave(a, b, axis):
  """Interleaves two statically-shaped arrays along ``axis``.

  Elements of ``a`` land on even positions and elements of ``b`` on odd
  positions of the result. ``a`` must be as long as ``b`` along ``axis``, or
  exactly one element longer.
  """
  len_a, len_b = a.shape[axis], b.shape[axis]
  assert len_a == len_b or len_a == len_b + 1
  same_length = len_a == len_b

  # Spread each operand out with one interior padding slot between elements
  # (and `b` shifted by one), so that the two padded arrays do not overlap.
  a_pad = [(0, 0, 0)] * a.ndim
  a_pad[axis] = (0, 1, 1) if same_length else (0, 0, 1)
  b_pad = [(0, 0, 0)] * b.ndim
  b_pad[axis] = (1, 0, 1) if same_length else (1, 1, 1)

  padded_a = lax.pad(a, lax._const(a, 0), a_pad)
  padded_b = lax.pad(b, lax._const(b, 0), b_pad)
  # Padding is zero/False, so add (or `or` for booleans) merges the two.
  if a.dtype == np.bool_:
    return lax.bitwise_or(padded_a, padded_b)
  return lax.add(padded_a, padded_b)
### Cumulative reductions.
def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative sum along `axis`.

  Args:
    operand: the array to accumulate over.
    axis: the axis to accumulate along; coerced with ``int``.
    reverse: whether to accumulate in reverse order along ``axis``; coerced
      with ``bool``.

  Returns:
    The result of binding ``cumsum_p`` to ``operand``.
  """
  params = dict(axis=int(axis), reverse=bool(reverse))
  return cumsum_p.bind(operand, **params)
def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative product along `axis`.

  Args:
    operand: the array to accumulate over.
    axis: the axis to accumulate along; coerced with ``int``.
    reverse: whether to accumulate in reverse order along ``axis``; coerced
      with ``bool``.

  Returns:
    The result of binding ``cumprod_p`` to ``operand``.
  """
  params = dict(axis=int(axis), reverse=bool(reverse))
  return cumprod_p.bind(operand, **params)
def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative maximum along `axis`.

  Args:
    operand: the array to accumulate over.
    axis: the axis to accumulate along; coerced with ``int``.
    reverse: whether to accumulate in reverse order along ``axis``; coerced
      with ``bool``.

  Returns:
    The result of binding ``cummax_p`` to ``operand``.
  """
  params = dict(axis=int(axis), reverse=bool(reverse))
  return cummax_p.bind(operand, **params)
def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative minimum along `axis`.

  Args:
    operand: the array to accumulate over.
    axis: the axis to accumulate along; coerced with ``int``.
    reverse: whether to accumulate in reverse order along ``axis``; coerced
      with ``bool``.

  Returns:
    The result of binding ``cummin_p`` to ``operand``.
  """
  params = dict(axis=int(axis), reverse=bool(reverse))
  return cummin_p.bind(operand, **params)
def _cumred_shape_rule(x, *, axis: int, reverse: bool):
  """Shape rule for cumulative reductions: the output has the input's shape.

  Raises:
    ValueError: if ``axis`` is negative or not smaller than ``x.ndim``.
  """
  if not 0 <= axis < x.ndim:
    raise ValueError(
        f"axis {axis} is out of bounds for array of shape {x.shape}")
  return x.shape
def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
  """Transpose rule for cumsum: a cumsum of the cotangent in the other direction.

  Returns a one-element list holding the cotangent for ``operand``.
  """
  opposite_direction = not reverse
  operand_cotangent = cumsum(t, axis=axis, reverse=opposite_direction)
  return [operand_cotangent]
def cumred_reduce_window_impl(window_reduce: Callable, x, *, axis: int,
                              reverse: bool):
  """Implements a cumulative reduction as a single windowed reduction.

  The window spans the full length ``n`` of ``axis`` and the input is padded
  by ``n - 1`` on one side of that axis (the high side when ``reverse``,
  otherwise the low side).

  Args:
    window_reduce: callable invoked as
      ``window_reduce(x, window_dims, strides, padding)``.
    x: the array to reduce.
    axis: the axis to accumulate along.
    reverse: whether to accumulate from the end of ``axis``.

  Returns:
    ``x`` itself when ``axis`` has size zero, otherwise the result of
    ``window_reduce``.
  """
  extent = x.shape[axis]
  if extent == 0:
    return x

  rank = x.ndim
  window_dims = [1] * rank
  window_dims[axis] = extent
  strides = [1] * rank
  padding = [(0, 0)] * rank
  if reverse:
    padding[axis] = (0, extent - 1)
  else:
    padding[axis] = (extent - 1, 0)
  return window_reduce(x, window_dims, strides, padding)
def cumred_gpu_impl(window_reduce: Callable, reduce_fn: Callable, x, *,
                    axis: int, reverse: bool):
  """GPU cumulative reduction: picks an implementation from the axis size.

  Args:
    window_reduce: windowed reduction used for short axes.
    reduce_fn: binary reducer used with ``associative_scan`` for long axes.
    x: the array to reduce.
    axis: the axis to accumulate along.
    reverse: whether to accumulate from the end of ``axis``.
  """
  # On GPU, reduce_window is executed in a single fusion and associative_scan
  # is split into multiple to materialize intermediate calculations.
  # On small inputs reduce_window is faster being a single fusion,
  # but on larger ones is slower because of O(n^2) complexity.
  # This conservative value of the threshold was obtained via benchmarking.
  small_axis_limit = 32
  if x.shape[axis] <= small_axis_limit:
    return cumred_reduce_window_impl(window_reduce, x, axis=axis,
                                     reverse=reverse)
  return associative_scan(reduce_fn, x, reverse=reverse, axis=axis)
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
                       reverse: bool):
  """Batching rule for cumulative reductions.

  The batch dimension is left where it is. The reduction axis, given in terms
  of the unbatched operand, is shifted up by one when it sits at or after
  the batch dimension.

  Returns:
    A pair of the bound primitive's output and the unchanged batch dimension.
  """
  (operand,), (bdim,) = batched_args, batch_dims
  if axis < bdim:
    batched_axis = axis
  else:
    batched_axis = axis + 1
  result = prim.bind(operand, axis=batched_axis, reverse=reverse)
  return result, bdim
def _cumred_dtype_rule(name, operand, *args, **kw):
  """Dtype rule for cumulative reductions: numeric dtypes only.

  Args:
    name: primitive name, used in the error message.
    operand: the value whose ``dtype`` is checked.

  Returns:
    The canonicalized dtype of ``operand``.

  Raises:
    TypeError: if the operand dtype is not a subtype of ``np.number``.
  """
  operand_dtype = operand.dtype
  if dtypes.issubdtype(operand_dtype, np.number):
    return dtypes.canonicalize_dtype(operand_dtype)
  dtype_name = np.dtype(operand_dtype).name
  raise TypeError(f"{name} does not accept dtype {dtype_name}. "
                  "Accepted dtypes are subtypes of number.")
def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn):
  """Creates a cumulative-reduction primitive and registers its rules.

  Args:
    name: name of the new primitive.
    reduce_fn: binary reducer used by the associative-scan lowerings.
    reduce_window_fn: windowed reduction used by the reduce_window lowerings.

  Returns:
    The new primitive, with its batching rule and per-platform lowerings
    registered.
  """
  dtype_rule = partial(_cumred_dtype_rule, name)
  reducer_p = lax.standard_primitive(_cumred_shape_rule, dtype_rule, name)
  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule,
                                                   reducer_p)

  lowerings = [
      # Default for platforms not treated specially below.
      (None, partial(associative_scan, reduce_fn)),
      # On GPU, we choose between window reduction and associative scan
      # based on the input size.
      ('cuda', partial(cumred_gpu_impl, reduce_window_fn, reduce_fn)),
      ('rocm', partial(cumred_gpu_impl, reduce_window_fn, reduce_fn)),
      # On TPU, an implementation using reduce_window is handled specially by
      # the compiler and is efficient. On other backends, it is O(n^2).
      ('tpu', partial(cumred_reduce_window_impl, reduce_window_fn)),
  ]
  for platform, impl in lowerings:
    mlir.register_lowering(
        reducer_p,
        mlir.cache_lowering(mlir.lower_fun(impl, multiple_results=False)),
        platform=platform)
  return reducer_p
# Instantiate the four cumulative-reduction primitives, pairing each binary
# reducer with the matching reduce_window implementation.
cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, windowed_reductions._reduce_window_sum)
# cumsum is registered as a linear primitive with its own transpose rule; the
# other three get JVP rules further below.
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, windowed_reductions._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax.max, windowed_reductions._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax.min, windowed_reductions._reduce_window_min)
def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  """JVP rule for cumulative reductions, via ``associative_scan``.

  Args:
    primals: primal inputs to the primitive.
    tangents: tangents matching ``primals``.
    axis: the axis accumulated along.
    reverse: whether accumulation runs from the end of ``axis``.
    combine_fn: the binary reducer of the primitive being differentiated.

  Returns:
    The result of ``api.jvp`` applied to the equivalent associative scan.
  """
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  scan_fn = partial(associative_scan, combine_fn, axis=axis, reverse=reverse)
  return api.jvp(scan_fn, primals, tangents)
# cumprod, cummin and cummax share one JVP rule, each specialized with the
# same binary reducer that its primitive was created with.
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)