# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for conditional control flow primitives."""
import collections
import functools
from functools import partial
import inspect
import itertools
import operator

from typing import Callable, Sequence, List, Tuple

from jax import core
from jax import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, extend_name_stack, split_list,
                           partition_list)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import numpy as np

from jax._src.lax.control_flow.common import (
    _abstractify,
    _avals_short,
    _check_tree_and_avals,
    _initial_style_jaxprs_with_common_consts,
    _make_closed_jaxpr,
    _prune_zeros,
    _typecheck_param,
    allowed_effects,
    )

map, unsafe_map = safe_map, map


# For backward compatibility with a previous switch/cond calling convention,
# we allow a single (pytree) `operand` argument to be passed by keyword. We use
# a sentinel object as its default value to indicate when it is _not_ passed.
_no_operand_sentinel = object()

@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
           operand=_no_operand_sentinel):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, *operands):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](*operands)

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operands: Operands (A) input to whichever branch is applied.

  Returns:
    Value (B) of ``branch(*operands)`` for the branch that was selected based
    on ``index``.
  """
  if not all(callable(branch) for branch in branches):
    raise TypeError("lax.switch: branches argument should be a sequence of callables.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if len(np.shape(index)) != 0:
    raise TypeError(
        f"Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err

  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  branches = tuple(branches)

  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    return branches[0](*operands)

  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  if (config.jax_disable_jit and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](*operands)

  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals, primitive_name='switch')
  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
                          out_trees[0], jaxprs[0].out_avals,
                          out_tree, jaxpr.out_avals)
  joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `switch`: {disallowed_effects}')

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(out_trees[0], out)
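
# Example usage (an illustrative sketch, not executed as part of this module;
# it assumes only the public `jax`, `jax.nn`, and `jax.numpy` APIs):
#
#   import jax
#   import jax.numpy as jnp
#
#   def apply_activation(i, x):
#     return jax.lax.switch(i, [jnp.tanh, jax.nn.relu, jnp.sin], x)
#
#   apply_activation(1, jnp.array([-1., 2.]))  # relu branch
#   apply_activation(9, jnp.array([-1., 2.]))  # index clamped to 2: sin branch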

def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
          operand=_no_operand_sentinel, linear=None):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Provided the arguments are correctly typed, ``cond()`` has semantics
  equivalent to this Python implementation::

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)

  ``pred`` must be a scalar type.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operands: Operands (A) input to either branch depending on ``pred``. The
      type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if pred is None:
    raise TypeError("cond predicate is None")
  if isinstance(pred, Sequence) or np.ndim(pred) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of " +
        (f"type {type(pred)}" if isinstance(pred, Sequence)
         else f"shape {np.shape(pred)}."))

  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    msg = ("Pred type must be either boolean or number, got {}.")
    raise TypeError(msg.format(pred)) from err

  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      msg = ("Pred type must be either boolean or number, got {}.")
      raise TypeError(msg.format(pred_dtype))

  if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
    if pred:
      return true_fun(*operands)
    else:
      return false_fun(*operands)

  ops, ops_tree = tree_flatten(operands)
  if linear is None:
    linear_ops = [False] * len(ops)
  else:
    linear_ops, ops_tree2 = tree_flatten(linear)
    if ops_tree != ops_tree2:
      raise TypeError('linear tree and operand tree mismatch')
  ops_avals = tuple(map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals, 'cond')
  true_jaxpr, false_jaxpr = jaxprs
  out_tree, false_out_tree = out_trees

  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)
  joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  index = lax.convert_element_type(pred, np.int32)

  linear = [False] * len(consts) + linear_ops
  out = cond_p.bind(
      index, *consts, *ops,
      branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
  return tree_unflatten(out_tree, out)

@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
  # detect an attempt to call the former, deprecated cond
  try:
    ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
  except TypeError:
    pass
  else:
    assert not ba.kwargs  # no catch-all **kwargs in _cond_with_per_branch
    _, _, maybe_true_fun, _, maybe_false_fun = ba.args
    if callable(maybe_true_fun) and callable(maybe_false_fun):
      return _cond_with_per_branch_args(*ba.args)

  return _cond(*args, **kwargs)
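
# Example usage (an illustrative sketch, not executed as part of this module):
# both branches are traced to jaxprs, so `cond` composes with `jit` even when
# the predicate is abstract. Assumes only public APIs.
#
#   import jax
#   import jax.numpy as jnp
#
#   @jax.jit
#   def safe_log(x):
#     return jax.lax.cond(x > 0, jnp.log, jnp.zeros_like, x)
#
#   safe_log(jnp.float32(2.))   # log(2.), true branch
#   safe_log(jnp.float32(-1.))  # 0., false branch chosen at run time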

def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  ``pred`` must be a scalar type; collection types (list, tuple) are not
  supported.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  return _cond(pred,
               lambda op: true_fun(op[0]),
               lambda op: false_fun(op[1]),
               (true_operand, false_operand))
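
# For reference, the deprecated per-branch-operand convention handled above
# looks like this (an illustrative sketch; `x` and `y` stand for operands
# whose branch outputs must have matching types):
#
#   jax.lax.cond(pred,
#                x, lambda x: x + 1.,  # true_operand, true_fun
#                y, lambda y: y * 2.)  # false_operand, false_fun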

def _cond_abstract_eval(*args, branches, **kwargs):
  joined_effects = core.join_effects(*(b.effects for b in branches))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')
  return map(raise_to_shaped, branches[0].out_avals), joined_effects

def _bcast_select(pred, on_true, on_false):
  if np.ndim(pred) != np.ndim(on_true):
    idx = list(range(np.ndim(pred)))
    pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
  return lax.select(pred, on_true, on_false)

def _bcast_select_n(pred, *cases):
  if np.ndim(pred) != np.ndim(cases[0]):
    idx = list(range(np.ndim(pred)))
    pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx)
  return lax.select_n(pred, *cases)

def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
  index, *ops = args
  index_dim, *op_dims = dims

  if index_dim is not batching.not_mapped:
    # Convert to a lax.select. While we could get away with not broadcasting
    # some operands yet, because all outputs must be broadcast together anyway
    # for the select we broadcast the input operands for simplicity and leave
    # optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    index, *ops = (
        batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims))

    in_batched = [True] * len(branches[0].in_avals)
    out_batched = [True] * len(branches[0].out_avals)

    branches_batched = [
        batching.batch_jaxpr(
            jaxpr, axis_size, in_batched, out_batched, axis_name, main_type)[0]
        for jaxpr in branches]

    branch_outs = []
    for i, jaxpr in enumerate(branches_batched):
      # Perform a select on the inputs for safety of reverse-mode autodiff; see
      # https://github.com/google/jax/issues/1052
      predicate = lax.eq(index, lax._const(index, i))
      ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
      branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
    out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
    return out, [0 if b else None for b in out_batched]
  else:
    ops_bat = [d is not batching.not_mapped for d in op_dims]
    ops = [batching.moveaxis(x, d, 0) if b else x
           for b, x, d in zip(ops_bat, ops, op_dims)]

    branches_out_bat = [
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, main_type)[1]
        for jaxpr in branches]
    out_bat = [any(bat) for bat in zip(*branches_out_bat)]
    branches_batched = tuple(
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, main_type)[0]
        for jaxpr in branches)

    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(
        index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims
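
# Example of what this batching rule enables (an illustrative sketch using
# only public APIs): under `vmap` with a batched predicate, the cond becomes
# a select over the outputs of both branches.
#
#   import jax
#   import jax.numpy as jnp
#
#   def f(pred, x):
#     return jax.lax.cond(pred, jnp.sin, jnp.cos, x)
#
#   jax.vmap(f)(jnp.array([True, False, True]), jnp.arange(3.))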

def _cond_jvp(primals, tangents, branches, linear):
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]

  index_nz, *ops_nz = nonzeros
  assert index_nz is False

  branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
                     for jaxpr in branches]
  out_nz = [any(nz) for nz in zip(*branches_out_nz)]

  branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
                       for jaxpr in branches)

  index, *ops = primals
  _, *ops_dot = tangents
  ops_dot = _prune_zeros(ops_dot)

  ops_lin = tuple(linear)
  linear_jvp = ops_lin + (True,) * len(ops_dot)
  out = cond_p.bind(
      index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)
  out_primals, out_tangents = split_list(out, [len(out_nz)])
  out_tangents_iter = iter(out_tangents)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nz)]
  return out_primals, out_tangents
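
# Example exercising this JVP rule (an illustrative sketch using only public
# APIs): tangents flow through whichever branch the predicate selects.
#
#   import jax
#   import jax.numpy as jnp
#
#   def f(x):
#     return jax.lax.cond(x > 0, jnp.sin, jnp.cos, x)
#
#   jax.jvp(f, (1.,), (1.,))  # (sin(1.), cos(1.)): true branch and its
#                             # derivative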

def _cond_partial_eval(trace, *tracers, branches, linear):
  in_unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = in_unknowns

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    # TODO(mattjj): remove this path when old remat is removed
    params = dict(branches=branches, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(
        branch_jaxpr, ops_uk, instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  branches_known, branches_unknown, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
        pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
    branches_known.append(branch_jaxpr_known)
    branches_unknown.append(branch_jaxpr_unknown)
    branch_res_avals.append(res_avals)

  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
  num_res = len(all_res_avals)

  num_known_outs = len(out_uks) - sum(out_uks)
  branches_known = _join_cond_outputs(
      branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
      branches_unknown, all_res_avals, res_avals_per_branch)
  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

  in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
  linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
  out_consts_res = cond_p.bind(*in_consts, branches=branches_known,
                               linear=tuple(linear_known))
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  index_tracer = trace.instantiate_const(tracers[0])
  ops_tracers = [trace.instantiate_const(t)
                 for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
  res_tracers = map(trace.new_instantiated_const, res)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
                 for aval in branches_unknown[0].out_avals]
  linear_unknown = ([False] * num_res +
                    [l for l, uk in zip(linear, in_unknowns[1:]) if uk])
  params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  eqn = pe.new_eqn_recipe(
      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
      core.join_effects(*(j.effects for j in branches_unknown)), source)
  for t in out_tracers: t.recipe = eqn
  return util.merge_lists(out_uks, out_consts, out_tracers)
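
# Example exercising this partial-eval rule (an illustrative sketch):
# `jax.linearize` evaluates the known (primal) parts of the cond now and
# stages the unknown (tangent) parts into a new cond.
#
#   import jax
#
#   def f(x):
#     return jax.lax.cond(x > 0, lambda x: x * x, lambda x: -x, x)
#
#   y, f_lin = jax.linearize(f, 2.)  # y == 4.
#   f_lin(1.)                        # 4., i.e. d(x*x)/dx at x == 2.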

# TODO(mattjj): de-duplicate with _cond_partial_eval
def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  index_uk, *ops_uk = unks_in
  branches = eqn.params['branches']

  # Instantiate all inputs (b/c jaxpr_staged will take all inputs).
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  del inst_in

  # NOTE(mattjj): I think it should be impossible for the index to be unknown,
  # but asserting that caused a test failure in diffrax. So we handle it: if it
  # is unknown, stage out the whole cond.
  if index_uk:
    all_true = [True] * len(branches[0].out_avals)
    return None, eqn, all_true, all_true, new_inst

  # First, compute output unknowns (unks_out), where an output of the cond is
  # unknown if it would be unknown on any of the branches.
  unks_out: List[bool] = [False] * len(eqn.outvars)
  for jaxpr in branches:
    _, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
        jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
        ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
    unks_out = map(operator.or_, unks_out, unks_out_)

  # Next, use the computed output unknowns to build a known jaxpr and a staged
  # jaxpr for each branch.
  branches_known_ : List[core.ClosedJaxpr] = []
  branches_staged_: List[core.ClosedJaxpr] = []
  branch_res_avals: List[core.AbstractValue] = []
  for jaxpr in branches:
    jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
            ensure_out_unknowns=unks_out, ensure_out_inst=True,
            saveable=saveable)
    branches_known_.append( core.ClosedJaxpr(jaxpr_known, jaxpr.consts))
    branches_staged_.append(core.ClosedJaxpr(jaxpr_staged, jaxpr.consts))
    branch_res_avals.append(branches_staged_[-1].in_avals[:num_res])

  # Residuals may differ across branches, so we merge them, then use the merged
  # residuals to join the outputs of all branches to the same type.
  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
  num_res = len(all_res_avals)
  num_known_outs = len(unks_out) - sum(unks_out)
  branches_known = _join_cond_outputs(
      branches_known_, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_staged = _join_cond_pe_staged_jaxpr_inputs(
      branches_staged_, all_res_avals, res_avals_per_branch)
  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

  # Create residual variables.
  newvar = core.gensym()
  res_binders = map(newvar, all_res_avals)

  # Build the known eqn.
  ins_known, _ = partition_list(unks_in, eqn.invars)  # includes index invar
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  linear_known = [l for l, uk in zip(eqn.params['linear'], ops_uk) if not uk]
  params_known = dict(branches=branches_known, linear=tuple(linear_known))
  effects_known = core.join_effects(*(b.effects for b in branches_known))
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
      effects_known, eqn.source_info)

  # Build the staged eqn.
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  linear_staged = [False] * len(res_binders) + list(eqn.params['linear'])
  params_staged = dict(branches=branches_staged, linear=tuple(linear_staged))
  effects_staged = core.join_effects(*(b.effects for b in branches_staged))
  eqn_staged = pe.new_jaxpr_eqn(
      [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
      cond_p, params_staged, effects_staged, eqn.source_info)

  new_vars = [*new_inst, *res_binders]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
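
# Example of where this rule is used (an illustrative sketch; the claim that
# `jax.checkpoint` routes through it is an assumption based on the
# partial_eval_jaxpr_custom_rules registration below): under new-style remat,
# this rule decides which branch values are saved as residuals and which are
# recomputed on the backward pass.
#
#   import jax
#   import jax.numpy as jnp
#
#   def f(x):
#     return jax.lax.cond(x > 0, jnp.sin, jnp.cos, x)
#
#   jax.grad(jax.checkpoint(f))(1.)  # cos(1.)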

# When partially evaluating conditionals, each branch produces residuals
# depending on the computation carried out by the branch, and a corresponding
# staged jaxpr that accepts those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in a
# conditional. The residual-consuming jaxprs are assembled together in a jaxpr
# conditional. The following helper functions ensure that both collections of
# jaxprs (those evaluated and those staged) are valid for joint use under their
# respective conditionals.
#
# In particular, the residuals derived from each original branch may have
# distinct types. Because the branches of conditionals must have identical type
# signatures, we join residuals together across branches into a common format.

# In order to set up a type signature that all branches can conform to, it
# would suffice to concatenate all branches' residuals. But concatenation can
# result in redundant inputs and outputs, and might lead to memory allocation
# that scales unnecessarily with the branch count. This function finds common
# residual types across branches for reuse, so as to avoid redundant
# allocation. It returns a list L of types (avals) representing the collection
# of residuals merged according to type, and, for each branch, a lookup table
# to match its residuals to their positions/types in L. Example input/output:
#
# [x], [y], [x, x]              -> [x, y, x],    [[0], [1], [0, 2]]
# [x], [x], [x, x]              -> [x, x],       [[0], [0], [0, 1]]
# [y, x, x], [x, z, y], [z, x]  -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
def _merge_branch_residuals(branch_res_avals):
  def enumerate_equal(xs):
    counts = {v: itertools.count() for v in set(xs)}
    return [(x, next(counts[x])) for x in xs]
  branch_res_tagged_avals = map(enumerate_equal, branch_res_avals)
  all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
  indices = {v: i for i, v in enumerate(all_tagged_avals)}
  branch_indices = [
      [indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
  all_avals = [x for x, _ in all_tagged_avals]
  return all_avals, branch_indices
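
# The merge relies only on residual types being hashable, so plain strings can
# stand in for avals to illustrate it (a sketch matching the second example
# above):
#
#   _merge_branch_residuals([['x'], ['x'], ['x', 'x']])
#   # -> (['x', 'x'], [[0], [0], [0, 1]])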

# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
                       num_non_res_outputs):
  def augment_jaxpr(jaxpr, res_indices):
    @lu.wrap_init
    def f_aug(*args):
      outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
      outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
      aug_residuals = map(ad_util.zeros_like_aval, all_res_avals)
      aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
      return outs + list(aug_residuals)

    return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

  return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

# This function augments branch inputs to agree with the merged residual
# format: each branch is made to accept all residuals, even though it will
# ignore those that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
  all_res_vars = map(newvar, all_res_avals)

  def augment_jaxpr(jaxpr, res_indices):
    num_res = len(res_indices)
    res_vars = jaxpr.jaxpr.invars[:num_res]
    non_res_vars = jaxpr.jaxpr.invars[num_res:]

    aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
    aug_invars = aug_res_vars + non_res_vars
    jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
                           jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
                           jaxpr.jaxpr.effects)
    jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
    return jaxpr_aug

  return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

def _ordered_unique(xs):
  d = collections.OrderedDict((x, None) for x in xs)
  return list(d.keys())

def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
                   ) -> Tuple[List[bool], core.JaxprEqn]:
  closed_branches = eqn.params['branches']
  branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

  # First, compute which inputs are used in any branch (not including `pred`).
  used_inputs: List[bool] = [False] * (len(eqn.invars) - 1)  # -1 for pred
  for jaxpr in branches:
    _, used_inputs_ = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
    used_inputs = map(operator.or_, used_inputs, used_inputs_)

  # Next, compute DCEd branches, instantiating according to used_inputs.
  dce_branches_ = [pe.dce_jaxpr(jaxpr, used_outputs, instantiate=used_inputs)[0]
                   for jaxpr in branches]
  dce_branches = [core.ClosedJaxpr(jaxpr, closed_jaxpr.consts)
                  for closed_jaxpr, jaxpr in zip(closed_branches, dce_branches_)]

  # Finally, update parameters and form the new eqn.
  dce_linear = [l for l, used in zip(eqn.params['linear'], used_inputs) if used]
  new_params = dict(eqn.params, branches=tuple(dce_branches),
                    linear=tuple(dce_linear))
  new_effects = core.join_effects(*(b.effects for b in dce_branches))
  new_eqn = pe.new_jaxpr_eqn(
      [v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, new_effects, eqn.source_info)

  assert all(len(new_eqn.invars ) == 1 + len(jaxpr.in_avals )
             for jaxpr in new_params['branches'])
  assert all(len(new_eqn.outvars) == len(jaxpr.out_avals)
             for jaxpr in new_params['branches'])
  return [True, *used_inputs], new_eqn


def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
  primal_avals = map(raise_to_shaped, primal_avals)

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
    cts_in = ad.backward_pass(
        jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
    _, cts_in = split_list(cts_in, [num_res])
    return map(ad.instantiate_zeros_aval, primal_avals, cts_in)

  return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)

def _cond_transpose(reduce_axes, cts, *args, branches, linear):
  index, *ops = args
  in_avals = map(raise_to_shaped, branches[0].in_avals)
  num_res = len(ops) - sum(linear)

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
  lin_in_avals = [raise_to_shaped(a, weak_type=False)
                  for a, l in zip(in_avals, linear) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  cts = map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
  linear_trans = (False,) * num_res + (True,) * len(cts)

  out = cond_p.bind(
      index, *res, *cts, branches=branches_trans, linear=linear_trans)
  assert all(map(core.typecheck, lin_in_avals, out))

  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out
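
# Example exercising this transpose rule (an illustrative sketch using only
# public APIs): reverse-mode differentiation linearizes the cond and then
# transposes it, yielding another cond over the transposed linear branches.
#
#   import jax
#   import jax.numpy as jnp
#
#   def f(x):
#     return jax.lax.cond(x > 0, lambda x: 3. * x, jnp.sin, x)
#
#   jax.grad(f)(2.)   # 3., from the transposed true branch
#   jax.grad(f)(-2.)  # cos(-2.), from the transposed false branch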

def _cond_axis_substitution(params, subst, traverse):
  if not traverse:
    return params
  branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches'])
  return dict(params, branches=branches)

def _cond_typecheck(*in_atoms, branches, linear):
  avals = [x.aval for x in in_atoms]
  tc = partial(_typecheck_param, 'cond')
  tc(branches, 'branches', 'tuple of ClosedJaxpr',
     type(branches) is tuple and
     all(type(x) is core.ClosedJaxpr for x in branches))
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))

  if len(branches) == 0:
    raise core.JaxprTypeError('cond requires at least one branch function')
  if len(linear) + 1 != len(avals):
    raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for '
                              f'{len(avals) - 1} non-predicate operands')

  jaxpr0 = branches[0]
  jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
  jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
  joined_effects = core.join_effects(*(b.effects for b in branches))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  for i, jaxpr in enumerate(branches[1:]):
    if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
      raise core.JaxprTypeError(
        f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
        f'branch {i+1} takes {len(jaxpr.in_avals)}')
    if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
      raise core.JaxprTypeError(
        f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
        f'branch {i+1} outputs {len(jaxpr.out_avals)}')
    if not all(map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
      raise core.JaxprTypeError(
        f'cond branches 0 and {i+1} have mismatching input types: '
        f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
    if not all(map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
      raise core.JaxprTypeError(
        f'cond branches 0 and {i+1} have mismatching output types: '
        f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

  if len(avals) != 1 + len(jaxpr0.in_avals):
    raise core.JaxprTypeError(
      f'cond called with {len(avals) - 1} non-predicate operands, '
      f'but branches take {len(jaxpr0.in_avals)} inputs')

  index_aval, *op_avals = avals
  if index_aval.dtype != np.int32:
    raise core.JaxprTypeError(
      f'cond called with index of type {index_aval.dtype} instead of int32')
  if not all(map(core.typecompat, jaxpr0.in_avals, op_avals)):
    raise core.JaxprTypeError(
      f'cond branches take input types {jaxpr0_in_avals_str}, '
      f'called with operands of type {_avals_short(op_avals)}')
  if any(b.effects != branches[0].effects for b in branches[1:]):
    raise core.JaxprTypeError(
      f'cond branches must have matching effect types: '
      f'{[b.effects for b in branches]}')
  return jaxpr0.out_avals, joined_effects
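
# For instance (an illustrative sketch), branches whose output types disagree
# are rejected at trace time, before this jaxpr-level typecheck ever runs:
#
#   import jax
#
#   jax.lax.cond(True,
#                lambda x: x,        # returns f32[]
#                lambda x: (x, x),   # returns (f32[], f32[])
#                1.)
#   # -> TypeError: branch outputs must match in pytree structure and types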

def cond_bind(*args, branches, linear):
  if config.jax_enable_checks:
    avals = map(core.get_aval, args)
    in_atoms = [core.Var(0, '', a) for a in avals]  # dummies
    _cond_typecheck(*in_atoms, branches=branches, linear=linear)
    for jaxpr in branches:
      core.check_jaxpr(jaxpr.jaxpr)
  return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)

cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
core.axis_substitution_rules[cond_p] = _cond_axis_substitution
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule

def _cond_lowering(ctx, index, *args, branches, linear):
  del linear  # Unused.
  joined_effects = core.join_effects(*(branch.effects for branch in branches))
  ordered_effects = [eff for eff in joined_effects
                     if eff in core.ordered_effects]
  num_tokens = len(ordered_effects)
  tokens_in = ctx.tokens_in.subset(ordered_effects)
  output_token_types = [mlir.token_type() for _ in ordered_effects]
  output_types = [
      *output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
  flat_output_types = util.flatten(output_types)

  # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
  # have no arguments; the computation within the block uses implicit
  # captures.
  case_op = mhlo.CaseOp(flat_output_types, index=index,
                        num_branches=len(branches))
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
  for i, jaxpr in enumerate(branches):
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      sub_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
      out_vals, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr.jaxpr, tokens_in,
          map(mlir.ir_constants, jaxpr.consts),
          *map(mlir.wrap_singleton_ir_values, args))
      out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
      out_vals = [*out_tokens, *out_vals]
      mhlo.ReturnOp(util.flatten(out_vals))

  tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
  tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
  ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
  return outputs

mlir.register_lowering(cond_p, _cond_lowering)
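
# Example of inspecting this lowering (an illustrative sketch; the exact
# `compiler_ir` signature is an assumption about the AOT API of this JAX
# version): the cond shows up as an mhlo case op in the lowered module.
#
#   import jax
#   import jax.numpy as jnp
#
#   def f(x):
#     return jax.lax.cond(x > 0, jnp.sin, jnp.cos, x)
#
#   print(jax.jit(f).lower(1.).compiler_ir(dialect='mhlo'))
#   # ... "mhlo.case"(%...) ( { ... }, { ... }) ...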