rocm_jax/jax/_src/lax/control_flow/conditionals.py

867 lines
36 KiB
Python
Raw Normal View History

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for conditional control flow primitives."""
import collections
import functools
from functools import partial
import inspect
import itertools
import operator
from typing import Callable, Sequence, List, Tuple
from jax import core
from jax._src import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
2023-01-18 10:17:01 -08:00
from jax._src.core import replace_jaxpr_effects
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src import state
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, extend_name_stack, split_list,
partition_list)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_avals_short,
_check_tree_and_avals,
_initial_style_jaxprs_with_common_consts,
_make_closed_jaxpr,
_prune_zeros,
_typecheck_param,
allowed_effects,
)
# Within this module the name `map` refers to `safe_map`; the builtin `map`
# stays reachable as `unsafe_map`.
unsafe_map = map
map = safe_map

# Backward compatibility with an older switch/cond calling convention: a
# single (pytree) `operand` may be passed by keyword. This private sentinel is
# its default, so "not passed" is distinguishable from every real value
# (including None).
_no_operand_sentinel = object()
@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
           operand=_no_operand_sentinel):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, *operands):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](*operands)

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operands: Operands (A) input to whichever branch is applied.

  Returns:
    Value (B) of ``branch(*operands)`` for the branch that was selected based
    on ``index``.
  """
  # Materialize `branches` exactly once, before validating it. Checking
  # callability directly on a one-shot iterable (e.g. a generator) would
  # exhaust it, and the switch would then wrongly report an empty sequence.
  branches = tuple(branches)
  if not all(callable(branch) for branch in branches):
    raise TypeError("lax.switch: branches argument should be a sequence of callables.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got {operand=} "
                      f"and positional operands {operands}")
    operands = (operand,)
    del operand

  if len(np.shape(index)) != 0:
    raise TypeError(
        f"Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err

  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    # A single branch needs no conditional at all.
    return branches[0](*operands)

  # Out-of-range indices are clamped into [0, len(branches) - 1].
  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  if (config.jax_disable_jit and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](*operands)

  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals, primitive_name='switch')

  # Every branch must produce the same output pytree and avals as branch 0.
  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
                          out_trees[0], jaxprs[0].out_avals,
                          out_tree, jaxpr.out_avals)
  joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `switch`: {disallowed_effects}')
  if joined_effects:
    # Raise index in case of effects to allow data-dependence-based discharging
    # of those effects (even if they don't have an explicit data dependence).
    index = core.raise_as_much_as_possible(index)

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(out_trees[0], out)
def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
          operand=_no_operand_sentinel, linear=None):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Wraps XLA's `Conditional
  <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
  operator.

  Provided arguments are correctly typed, ``cond()`` has equivalent
  semantics to this Python implementation, where ``pred`` must be a
  scalar type::

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)

  In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
  the two branches is executed (up to compiler rewrites and optimizations).
  However, when transformed with :func:`~jax.vmap` to operate over a batch of
  predicates, ``cond`` is converted to :func:`~jax.lax.select`.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operands: Operands (A) input to either branch depending on ``pred``. The
      type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got {operand=} "
                      f"and positional operands {operands}")
    operands = (operand,)
    del operand

  if pred is None:
    raise TypeError("cond predicate is None")
  if isinstance(pred, Sequence) or np.ndim(pred) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of " +
        (f"type {type(pred)}" if isinstance(pred, Sequence)
         else f"shape {np.shape(pred)}."))

  bad_pred_msg = "Pred type must be either boolean or number, got {}."
  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    raise TypeError(bad_pred_msg.format(pred)) from err

  # Numeric predicates are interpreted by their truthiness (nonzero -> True).
  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      raise TypeError(bad_pred_msg.format(pred_dtype))

  if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
    if pred:
      return true_fun(*operands)
    else:
      return false_fun(*operands)

  ops, ops_tree = tree_flatten(operands)
  if linear is None:
    linear_ops = [False] * len(ops)
  else:
    linear_ops, ops_tree2 = tree_flatten(linear)
    if ops_tree != ops_tree2:
      raise TypeError('linear tree and operand tree mismatch')
  ops_avals = tuple(map(_abstractify, ops))

  # Reject Ref operands before tracing either branch. This avoids the tracing
  # work, and the caller gets this error rather than whatever the branch code
  # might raise when traced with a Ref.
  if any(isinstance(op_aval, state.ShapedArrayRef) for op_aval in ops_avals):
    raise ValueError("Cannot pass `Ref`s into `cond`.")

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals, 'cond')
  true_jaxpr, false_jaxpr = jaxprs
  out_tree, false_out_tree = out_trees

  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)
  joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  # cond_p is index-based: False selects branch 0, True selects branch 1.
  index = lax.convert_element_type(pred, np.int32)
  if joined_effects:
    # Raise index in case of effects to allow data-dependence-based discharging
    # of those effects (even if they don't have an explicit data dependence).
    index = core.raise_as_much_as_possible(index)
  # Give both branches the same (joined) effect set.
  false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects)
  true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects)

  linear = [False] * len(consts) + linear_ops
  out = cond_p.bind(
      index, *consts, *ops,
      branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
  return tree_unflatten(out_tree, out)
@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
  # Public entry point (docstring copied from `_cond` by `functools.wraps`).
  # Calls are first matched against the former, deprecated convention
  # `cond(pred, true_operand, true_fun, false_operand, false_fun)`. That path
  # is taken only when both function slots hold callables. Otherwise the call
  # goes to the current `_cond`.
  try:
    legacy_call = inspect.signature(_cond_with_per_branch_args).bind(
        *args, **kwargs)
  except TypeError:
    legacy_call = None

  if legacy_call is not None:
    assert not legacy_call.kwargs  # no catch-all **kwargs in _cond_with_per_branch
    _, _, maybe_true_fun, _, maybe_false_fun = legacy_call.args
    if callable(maybe_true_fun) and callable(maybe_false_fun):
      return _cond_with_per_branch_args(*legacy_call.args)

  return _cond(*args, **kwargs)
def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Deprecated ``cond`` convention with one operand per branch.

  Equivalent to::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  Both operands are packed into a single pair passed to ``_cond``. Each branch
  reads only its own element. ``pred`` must be a scalar; collection types
  (list, tuple) are not supported.
  """
  if not callable(true_fun) or not callable(false_fun):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  operand_pair = (true_operand, false_operand)
  return _cond(pred,
               lambda op: true_fun(op[0]),
               lambda op: false_fun(op[1]),
               operand_pair)
def _cond_abstract_eval(*avals, branches, **_):
  """Effectful abstract-evaluation rule for ``cond_p``.

  ``avals[0]`` is the index aval; ``avals[1:]`` are the operand avals.

  Returns:
    The shaped output avals of the first branch, and the effects of the whole
    ``cond``. Ref effects internal to the branches are dropped. Ref effects
    that can be attributed to a Ref-typed operand are re-expressed in terms of
    that operand's aval.

  Raises:
    NotImplementedError: if any branch has an effect not in ``allowed_effects``.
  """
  # Computed once and reused below.
  joined_effects = core.join_effects(*(b.effects for b in branches))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')
  state_effects = {eff for eff in joined_effects
                   if isinstance(eff, state.RefEffect)}
  # Zipped against the operands below, so this presumably yields one effect
  # collection per branch input -- TODO confirm against `state` helpers.
  jaxpr_aval_effects = state.get_ref_state_effects(
      [v.aval for v in branches[0].jaxpr.invars], joined_effects)
  aval_effects = [{eff.replace(ref_aval=aval) for eff in effs}
                  for aval, effs in zip(avals[1:], jaxpr_aval_effects)
                  if isinstance(aval, state.ShapedArrayRef)]
  nonlocal_state_effects = core.join_effects(*aval_effects)
  all_effects = (joined_effects - state_effects) | nonlocal_state_effects
  return map(raise_to_shaped, branches[0].out_avals), all_effects
def _bcast_select(pred, on_true, on_false):
if np.ndim(pred) != np.ndim(on_true):
idx = list(range(np.ndim(pred)))
pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
return lax.select(pred, on_true, on_false)
def _bcast_select_n(pred, *cases):
if np.ndim(pred) != np.ndim(cases[0]):
idx = list(range(np.ndim(pred)))
pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx)
return lax.select_n(pred, *cases)
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
  """Batching (vmap) rule for ``cond_p``.

  ``args``/``dims`` hold the index followed by the operands, together with
  their batch dimensions (``batching.not_mapped`` when unbatched).

  * Unbatched index: emit one ``cond`` whose branches are batched so that they
    agree on which outputs carry a leading batch dimension.
  * Batched index: run every branch on fully batched operands and pick each
    example's result with ``select_n`` on the index.

  Returns:
    ``(outputs, output batch dims)``.

  Raises:
    NotImplementedError: if any branch has a state (``RefEffect``) effect or
      an IO callback effect.
  """
  index, *ops = args
  index_dim, *op_dims = dims

  # TODO(sharadmv): clean this up by adding a specific blocklist
  if any(isinstance(eff, state.RefEffect)
         for branch in branches for eff in branch.jaxpr.effects):
    raise NotImplementedError(
        "State effect not supported in vmap-of-cond.")
  from jax._src.callback import _IOEffect, _OrderedIOEffect
  if any(eff in branch.effects
         for eff in [_IOEffect, _OrderedIOEffect] for branch in branches):
    raise NotImplementedError(
        "IO effect not supported in vmap-of-cond.")

  if index_dim is batching.not_mapped:
    # Shared index: keep a real `cond` and move batched operands' batch dims
    # to the front.
    ops_bat = [d is not batching.not_mapped for d in op_dims]
    ops = [batching.moveaxis(x, d, 0) if b else x
           for b, x, d in zip(ops_bat, ops, op_dims)]
    # Pass 1: which outputs each branch batches. An output of the `cond` is
    # batched if it is batched in any branch.
    branches_out_bat = [
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name,
                             main_type)[1]
        for jaxpr in branches]
    out_bat = [any(bat) for bat in zip(*branches_out_bat)]
    # Pass 2: re-batch every branch, forcing that common output batching.
    branches_batched = tuple(
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name,
                             main_type)[0]
        for jaxpr in branches)
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(
        index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims

  # Batched index: convert to a lax.select. While we could get away with not
  # broadcasting some operands yet, because all outputs must be broadcast
  # together anyway for the select we broadcast the input operands for
  # simplicity and leave optimizations to XLA.
  # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
  index, *ops = (batching.bdim_at_front(x, d, axis_size)
                 for x, d in zip(args, dims))
  in_batched = [True] * len(branches[0].in_avals)
  out_batched = [True] * len(branches[0].out_avals)
  branches_batched = [
      batching.batch_jaxpr(jaxpr, axis_size, in_batched, out_batched,
                           axis_name, main_type)[0]
      for jaxpr in branches]

  def run_branch(i, jaxpr):
    # Perform a select on the inputs for safety of reverse-mode autodiff; see
    # https://github.com/google/jax/issues/1052
    predicate = lax.eq(index, lax._const(index, i))
    guarded_ops = [_bcast_select(predicate, x, lax.stop_gradient(x))
                   for x in ops]
    return core.jaxpr_as_fun(jaxpr)(*guarded_ops)

  branch_outs = [run_branch(i, jaxpr)
                 for i, jaxpr in enumerate(branches_batched)]
  out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
  return out, [0 if b else None for b in out_batched]
def _cond_jvp(primals, tangents, branches, linear):
  """JVP rule for ``cond_p``.

  Emits a single ``cond`` whose branches are the JVPs of the original
  branches. They take the primal operands followed by the pruned operand
  tangents. An output tangent is materialized if any branch makes it nonzero,
  so all branches share one output signature. The remaining outputs get
  symbolic zeros. The index never has a tangent.
  """
  index_nz, *ops_nz = (type(t) is not ad_util.Zero for t in tangents)
  assert index_nz is False

  branches_out_nz = [ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=False)[1]
                     for jaxpr in branches]
  out_nz = [any(flags) for flags in zip(*branches_out_nz)]
  branches_jvp = tuple(ad.jvp_jaxpr(jaxpr, ops_nz, instantiate=out_nz)[0]
                       for jaxpr in branches)

  index, *ops = primals
  _, *ops_dot = tangents
  ops_dot = _prune_zeros(ops_dot)
  # Original operands keep their flags; every appended tangent is linear.
  linear_jvp = tuple(linear) + (True,) * len(ops_dot)
  out = cond_p.bind(
      index, *ops, *ops_dot, branches=branches_jvp, linear=linear_jvp)

  out_primals, nz_tangents = split_list(out, [len(out_nz)])
  nz_iter = iter(nz_tangents)
  out_tangents = [next(nz_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nz)]
  return out_primals, out_tangents
def _cond_partial_eval(trace, *tracers, branches, linear):
  """Partial-evaluation rule for ``cond_p`` (``pe.custom_partial_eval_rules``).

  Splits the cond into two parts. The known part is bound immediately and
  also produces residuals. The unknown part is staged into ``trace`` as a new
  ``cond`` equation that consumes those residuals. Residuals are merged
  across branches so that every branch keeps a single common signature.

  Raises:
    NotImplementedError: if any branch has a state (``RefEffect``) effect.
  """
  in_unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = in_unknowns
  if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
         branch.jaxpr.effects):
    raise NotImplementedError(
        "State effect not supported in cond partial-eval.")

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    # TODO(mattjj): remove this path when old remat is removed
    params = dict(branches=branches, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  # An output of the cond is unknown if it is unknown in any branch.
  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(
        branch_jaxpr, ops_uk, instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  # Re-split each branch, instantiating the joint `out_uks` so that all
  # branches agree on which outputs are unknown.
  branches_known, branches_unknown, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
        pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
    branches_known.append(branch_jaxpr_known)
    branches_unknown.append(branch_jaxpr_unknown)
    branch_res_avals.append(res_avals)

  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
  num_res = len(all_res_avals)

  # Known branches output the known outputs followed by the (merged)
  # residuals. Unknown branches take the merged residuals as leading inputs.
  num_known_outs = len(out_uks) - sum(out_uks)
  branches_known = _join_cond_outputs(
      branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
      branches_unknown, all_res_avals, res_avals_per_branch)
  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

  # The index is known on this path, so `in_consts` starts with the index
  # value followed by the known operands.
  in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
  linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
  out_consts_res = cond_p.bind(*in_consts, branches=branches_known,
                               linear=tuple(linear_known))
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  # Stage the unknown cond: inputs are (index, *residuals, *unknown operands).
  index_tracer = trace.instantiate_const(tracers[0])
  ops_tracers = [trace.instantiate_const(t)
                 for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
  res_tracers = map(trace.new_instantiated_const, res)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
                 for aval in branches_unknown[0].out_avals]
  linear_unknown = ([False] * num_res +
                    [l for l, uk in zip(linear, in_unknowns[1:]) if uk])
  params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  eqn = pe.new_eqn_recipe(
      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
      core.join_effects(*(j.effects for j in branches_unknown)), source)
  for t in out_tracers: t.recipe = eqn
  # Presumably interleaves known consts and unknown tracers according to
  # `out_uks` -- confirm against `util.merge_lists`.
  return util.merge_lists(out_uks, out_consts, out_tracers)
# TODO(mattjj): de-duplicate with _cond_partial_eval
def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """``pe.partial_eval_jaxpr_custom`` rule for a ``cond_p`` equation.

  Args:
    saveable: policy forwarded to ``pe.partial_eval_jaxpr_custom``.
    unks_in: per-invar unknown flags; the first entry is the index.
    inst_in: per-invar flags; invars that are ``core.Var`` with a False flag
      are collected into ``new_inst`` below.
    eqn: the ``cond_p`` equation to split.

  Returns:
    ``(eqn_known, eqn_staged, unks_out, inst_out, new_vars)``. If the index is
    unknown, ``eqn_known`` is ``None`` and the whole ``eqn`` is staged.
  """
  index_uk, *ops_uk = unks_in
  branches = eqn.params['branches']

  # Instantiate all inputs (b/c jaxpr_staged will take all inputs).
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  del inst_in

  # NOTE(mattjj): I think it should be impossible for the index to be unknown,
  # but asserting that caused a test failure in diffrax. So we handle it: if it
  # is unknown, stage out the whole cond.
  if index_uk:
    all_true = [True] * len(branches[0].out_avals)
    return None, eqn, all_true, all_true, new_inst

  # First, compute output unknowns (unks_out), where an output of the cond is
  # unknown if it would be unknown on any of the branches.
  unks_out: List[bool] = [False] * len(eqn.outvars)
  for jaxpr in branches:
    _, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
        jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
        ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
    unks_out = map(operator.or_, unks_out, unks_out_)

  # Next, use the computed output unknowns to build a known jaxpr and a staged
  # jaxpr for each branch.
  branches_known_ : List[core.ClosedJaxpr] = []
  branches_staged_: List[core.ClosedJaxpr] = []
  branch_res_avals: List[core.AbstractValue] = []
  for jaxpr in branches:
    jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
            ensure_out_unknowns=unks_out, ensure_out_inst=True,
            saveable=saveable)
    branches_known_.append( core.ClosedJaxpr(jaxpr_known, jaxpr.consts))
    branches_staged_.append(core.ClosedJaxpr(jaxpr_staged, jaxpr.consts))
    # The staged jaxpr's first `num_res` inputs are this branch's residuals.
    branch_res_avals.append(branches_staged_[-1].in_avals[:num_res])

  # Residuals may differ across branches, so we merge them, then use the merged
  # residuals to join the outputs of all branches to the same type.
  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
  num_res = len(all_res_avals)
  num_known_outs = len(unks_out) - sum(unks_out)
  branches_known = _join_cond_outputs(
      branches_known_, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_staged = _join_cond_pe_staged_jaxpr_inputs(
      branches_staged_, all_res_avals, res_avals_per_branch)
  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

  # Create residual variables.
  newvar = core.gensym()
  res_binders = map(newvar, all_res_avals)

  # Build the known eqn.
  ins_known, _ = partition_list(unks_in, eqn.invars)  # includes index invar
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  linear_known = [l for l, uk in zip(eqn.params['linear'], ops_uk) if not uk]
  params_known = dict(branches=branches_known, linear=tuple(linear_known))
  effects_known = core.join_effects(*(b.effects for b in branches_known))
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
      effects_known, eqn.source_info)

  # Build the staged eqn.
  # NOTE(review): `inst_out` is left over from the last iteration of the
  # per-branch loop above. Every branch used `ensure_out_inst=True`, so it is
  # presumably all-True and the same for every branch -- confirm.
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  linear_staged = [False] * len(res_binders) + list(eqn.params['linear'])
  params_staged = dict(branches=branches_staged, linear=tuple(linear_staged))
  effects_staged = core.join_effects(*(b.effects for b in branches_staged))
  # Staged inputs: (index, *residuals, *all original operands).
  eqn_staged = pe.new_jaxpr_eqn(
      [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
      cond_p, params_staged, effects_staged, eqn.source_info)

  new_vars = [*new_inst, *res_binders]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
# When a conditional is partially evaluated, each branch produces residuals
# that depend on the computation carried out by that branch, plus a staged
# jaxpr that takes those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in one
# conditional. The residual-consuming jaxprs are assembled into another,
# staged conditional. The helpers below make both collections of jaxprs (the
# evaluated ones and the staged ones) valid for joint use under their
# respective conditionals.
#
# In particular, the residuals of different branches may have different types.
# Because all branches of a conditional must share one type signature, the
# residuals are joined across branches into a common format. Concatenating
# every branch's residuals would suffice, but it creates redundant inputs and
# outputs and can make memory allocation grow needlessly with the branch
# count. Instead, residual slots of equal type are shared across branches.
#
# Example input -> output:
#
#   [x], [y], [x, x]              -> [x, y, x],    [[0], [1], [0, 2]]
#   [x], [x], [x, x]              -> [x, x],       [[0], [0], [0, 1]]
#   [y, x, x], [x, z, y], [z, x]  -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
def _merge_branch_residuals(branch_res_avals):
  """Merge per-branch residual avals into one shared residual signature.

  Args:
    branch_res_avals: one sequence of residual avals per branch.

  Returns:
    ``(all_avals, branch_indices)``. ``all_avals`` is the merged list of
    residual types. ``branch_indices[b][k]`` is the slot in ``all_avals``
    assigned to the ``k``-th residual of branch ``b``.
  """
  def tag_occurrences(avals):
    # Pair each aval with the number of earlier occurrences of that same aval
    # in this branch. The n-th residual of a type in one branch can then share
    # a slot with the n-th residual of that type in another branch.
    seen = collections.defaultdict(int)
    tagged = []
    for aval in avals:
      tagged.append((aval, seen[aval]))
      seen[aval] += 1
    return tagged

  tagged_per_branch = [tag_occurrences(avals) for avals in branch_res_avals]

  # Assign slots in first-seen order (dicts preserve insertion order).
  slot_of = {}
  for tagged in tagged_per_branch:
    for key in tagged:
      slot_of.setdefault(key, len(slot_of))

  all_avals = [aval for aval, _ in slot_of]
  branch_indices = [[slot_of[key] for key in tagged]
                    for tagged in tagged_per_branch]
  return all_avals, branch_indices
# Pads each branch's outputs to the merged residual format: every branch
# returns zero-filled values in the residual slots it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
                       num_non_res_outputs):
  """Return one closed jaxpr per branch, with outputs padded to all residuals.

  Each input jaxpr is assumed to output ``num_non_res_outputs`` regular
  outputs followed by its own residuals. ``res_aval_indices_per_jaxpr`` gives,
  per jaxpr, the slot in ``all_res_avals`` for each of its residuals.
  """
  def pad_branch(jaxpr, res_indices):
    @lu.wrap_init
    def f_aug(*args):
      outs, residuals = split_list(core.jaxpr_as_fun(jaxpr)(*args),
                                   [num_non_res_outputs])
      zero_filled = map(ad_util.zeros_like_aval, all_res_avals)
      padded = util.subvals(zero_filled, zip(res_indices, residuals))
      return [*outs, *padded]
    return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

  return tuple(map(pad_branch, jaxprs, res_aval_indices_per_jaxpr))
# Widens each staged branch's inputs to the merged residual format: every
# branch accepts all residuals, ignoring the ones it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  """Return one closed jaxpr per branch, taking all merged residuals first.

  Each input jaxpr is assumed to take its own residuals as its leading
  invars. Those are placed at their slots (from ``res_aval_indices_per_jaxpr``)
  among fresh variables for ``all_res_avals``. The non-residual invars,
  outvars, equations and effects are kept unchanged.
  """
  newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
  all_res_vars = map(newvar, all_res_avals)

  def widen_inputs(closed, res_indices):
    inner = closed.jaxpr
    num_own_res = len(res_indices)
    own_res_vars = inner.invars[:num_own_res]
    widened_invars = [
        *util.subvals(all_res_vars, zip(res_indices, own_res_vars)),
        *inner.invars[num_own_res:],
    ]
    widened = core.Jaxpr(inner.constvars, widened_invars, inner.outvars,
                         inner.eqns, inner.effects)
    return core.ClosedJaxpr(widened, closed.consts)

  return tuple(map(widen_inputs, jaxprs, res_aval_indices_per_jaxpr))
def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())
def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
                   ) -> Tuple[List[bool], core.JaxprEqn]:
  """Dead-code elimination rule for a ``cond_p`` equation.

  An operand is kept if any branch still reads it after DCE. All branches are
  then pruned to exactly that common operand set, so their signatures stay
  aligned.

  Returns:
    ``(used_inputs, new_eqn)``. ``used_inputs`` is aligned with
    ``eqn.invars`` and the index is always marked used.
  """
  closed_branches = eqn.params['branches']
  open_branches = [closed.jaxpr for closed in closed_branches]

  # Operand usage (excluding the index) is the union over branches.
  used_inputs: List[bool] = [False] * (len(eqn.invars) - 1)  # -1 for pred
  for jaxpr in open_branches:
    _, branch_used = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
    used_inputs = map(operator.or_, used_inputs, branch_used)

  # DCE again, instantiating the common `used_inputs` in every branch.
  dce_branches = [
      core.ClosedJaxpr(
          pe.dce_jaxpr(jaxpr, used_outputs, instantiate=used_inputs)[0],
          closed.consts)
      for closed, jaxpr in zip(closed_branches, open_branches)]

  keep_in = [True, *used_inputs]
  dce_linear = [l for l, used in zip(eqn.params['linear'], used_inputs) if used]
  new_params = dict(eqn.params, branches=tuple(dce_branches),
                    linear=tuple(dce_linear))
  new_effects = core.join_effects(*(b.effects for b in dce_branches))
  new_eqn = pe.new_jaxpr_eqn(
      [v for v, keep in zip(eqn.invars, keep_in) if keep],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, new_effects, eqn.source_info)

  assert all(len(new_eqn.invars) == 1 + len(jaxpr.in_avals)
             for jaxpr in new_params['branches'])
  assert all(len(new_eqn.outvars) == len(jaxpr.out_avals)
             for jaxpr in new_params['branches'])
  return keep_in, new_eqn
def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
  """Build the transpose of one cond branch as a closed jaxpr.

  The first ``num_res`` inputs of ``jaxpr`` are residuals and the rest are
  linear primals. The result takes ``(*residuals, *output cotangents)`` and
  returns one cotangent per linear primal input, with zeros instantiated.
  """
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
  primal_avals = map(raise_to_shaped, primal_avals)

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    undefined_primals = [ad.UndefinedPrimal(aval) for aval in primal_avals]
    all_cts_in = ad.backward_pass(
        jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
        res + undefined_primals, cts_out)
    # Drop the (unused) cotangents of the residual inputs.
    _, primal_cts_in = split_list(all_cts_in, [num_res])
    return map(ad.instantiate_zeros_aval, primal_avals, primal_cts_in)

  return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
  """Transpose rule for ``cond_p``.

  Operands that are ``ad.UndefinedPrimal`` are the linear inputs. All others
  are treated as residuals and are expected to come first among the operands.
  A transposed ``cond`` is emitted that maps output cotangents to cotangents
  for the linear operands.

  Returns:
    ``[None, *per-operand cotangents]`` (``None`` for the index and for every
    non-linear operand).

  Raises:
    NotImplementedError: if any branch has a state (``RefEffect``) effect.
  """
  del linear  # could use for error checking, but see #14026
  index, *ops = args
  is_lin = [type(x) is ad.UndefinedPrimal for x in ops]
  in_avals = map(raise_to_shaped, branches[0].in_avals)
  num_res = len(ops) - sum(is_lin)
  if any(isinstance(eff, state.RefEffect)
         for branch in branches for eff in branch.jaxpr.effects):
    raise NotImplementedError("State effect not supported in cond transpose.")

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
  lin_in_avals = [raise_to_shaped(a, weak_type=False)
                  for a, l in zip(in_avals, is_lin) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  cts = map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
  linear_trans = (False,) * num_res + (True,) * len(cts)
  out = cond_p.bind(
      index, *res, *cts, branches=branches_trans, linear=linear_trans)
  assert all(map(core.typecheck, lin_in_avals, out))

  # Scatter the results back: one cotangent per linear operand, None elsewhere.
  cts_iter = iter(out)
  per_operand = [next(cts_iter) if l else None for l in is_lin]
  assert next(cts_iter, None) is None
  return [None] + per_operand
def _cond_axis_substitution(params, subst, traverse):
if not traverse:
return params
branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches'])
return dict(params, branches=branches)
def _cond_typecheck(*in_atoms, branches, linear):
  """Type-check an application of ``cond_p``.

  ``in_atoms[0]`` is the branch index; the remaining atoms are the operands.

  Returns:
    The first branch's output avals and the union of all branch effects.

  Raises:
    core.JaxprTypeError: on malformed ``branches``/``linear`` params, a missing
      branch, mismatched branch signatures, a non-int32 index, or operands
      incompatible with the branch inputs.
    NotImplementedError: if any branch has an effect not in ``allowed_effects``.
  """
  avals = [x.aval for x in in_atoms]
  tc = partial(_typecheck_param, 'cond')
  tc(branches, 'branches', 'tuple of ClosedJaxpr',
     type(branches) is tuple and
     all(type(x) is core.ClosedJaxpr for x in branches))
  tc(linear, 'linear', 'tuple of bool',
     type(linear) is tuple and all(type(x) is bool for x in linear))

  if len(branches) == 0:
    raise core.JaxprTypeError('cond requires at least one branch function')
  if len(linear) + 1 != len(avals):
    raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for '
                              f'{len(avals) - 1} non-predicate operands')

  jaxpr0 = branches[0]
  jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
  jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
  # Computed once; also the effects returned at the end.
  joined_effects = core.join_effects(*(b.effects for b in branches))
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  # Every branch must match branch 0 in arity and in input/output types.
  for i, jaxpr in enumerate(branches[1:]):
    if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
      raise core.JaxprTypeError(
          f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
          f'branch {i+1} takes {len(jaxpr.in_avals)}')
    if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
      raise core.JaxprTypeError(
          f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
          f'branch {i+1} outputs {len(jaxpr.out_avals)}')
    if not all(map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching input types: '
          f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
    if not all(map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching output types: '
          f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

  if len(avals) != 1 + len(jaxpr0.in_avals):
    raise core.JaxprTypeError(
        f'cond called with {len(avals) - 1} non-predicate operands, '
        f'but branches take {len(jaxpr0.in_avals)} inputs')

  index_aval, *op_avals = avals
  if index_aval.dtype != np.int32:
    raise core.JaxprTypeError(
        f'cond called with index of type {index_aval.dtype} instead of int32')
  if not all(map(core.typecompat, jaxpr0.in_avals, op_avals)):
    raise core.JaxprTypeError(
        f'cond branches take input types {jaxpr0_in_avals_str}, '
        f'called with operands of type {_avals_short(op_avals)}')

  return jaxpr0.out_avals, joined_effects
def cond_bind(*args, branches, linear):
  """Custom bind for ``cond_p``.

  When ``config.jax_enable_checks`` is set, the arguments are type-checked
  against the branches via ``_cond_typecheck``, using placeholder ``core.Var``
  atoms, and each branch jaxpr is checked with ``core.check_jaxpr`` before
  binding.
  """
  if not config.jax_enable_checks:
    return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)

  placeholder_atoms = [core.Var(0, '', aval)
                       for aval in map(core.get_aval, args)]  # dummies
  _cond_typecheck(*placeholder_atoms, branches=branches, linear=linear)
  for branch in branches:
    core.check_jaxpr(branch.jaxpr)
  return core.AxisPrimitive.bind(cond_p, *args, branches=branches, linear=linear)
# The `cond` primitive. It is bound as `cond_p.bind(index, *operands,
# branches=..., linear=...)`. `branches` is a tuple of ClosedJaxprs and
# `linear` holds one bool per non-index operand, as `_cond_typecheck` checks.
cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
# The impl rule delegates to `xla.apply_primitive`.
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
# `cond_bind` optionally type-checks (under `jax_enable_checks`) before binding.
cond_p.def_custom_bind(cond_bind)
# Transformation rules: JVP, transpose, partial eval, and batching.
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
# Jaxpr-level rules: type checking, axis-name substitution, custom partial
# eval of equations, and dead-code elimination.
core.custom_typechecks[cond_p] = _cond_typecheck
core.axis_substitution_rules[cond_p] = _cond_axis_substitution
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
def _cond_lowering(ctx, index, *args, branches, linear):
  """Lower ``cond_p`` to a single ``hlo.CaseOp``.

  Each branch jaxpr is lowered into one region of the CaseOp. The ordered
  effects of the branches are threaded through the op as tokens. The CaseOp
  result list starts with one token per ordered effect, followed by the
  flattened IR types of ``ctx.avals_out``.

  Args:
    ctx: the lowering context. Tokens for the ordered effects are taken from
      ``ctx.tokens_in``, and the resulting tokens are stored back with
      ``ctx.set_tokens_out``.
    index: the IR value that selects which branch runs.
    *args: IR values of the operands. Every branch reads them as implicit
      captures.
    branches: the branch jaxprs, one per CaseOp region.
    linear: unused.

  Returns:
    The cond's outputs (excluding tokens), grouped as one list of IR values
    per entry of ``ctx.avals_out``.
  """
  del linear  # Unused.
  joined_effects = core.join_effects(*(branch.effects for branch in branches))
  ordered_effects = [eff for eff in joined_effects
                     if eff in core.ordered_effects]
  num_tokens = len(ordered_effects)
  tokens_in = ctx.tokens_in.subset(ordered_effects)
  output_token_types = [mlir.token_type() for _ in ordered_effects]
  output_types = [
      *output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
  flat_output_types = util.flatten(output_types)

  # CaseOp takes a single argument 'index' and the corresponding blocks
  # have no arguments; the computation within the block uses implicit
  # captures.
  case_op = hlo.CaseOp(flat_output_types, index=index,
                       num_branches=len(branches))
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
  for i, jaxpr in enumerate(branches):
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      sub_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
      # Forward the IR values of the dimension variables
      # (LoweringContext.dim_var_values). Shapes inside the branch may be
      # expressed in terms of those variables and need them to be evaluated.
      out_vals, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr.jaxpr, tokens_in,
          map(mlir.ir_constants, jaxpr.consts),
          *map(mlir.wrap_singleton_ir_values, args),
          dim_var_values=ctx.dim_var_values)
      # Each region must return the tokens first, then its outputs, matching
      # the CaseOp result layout built above.
      out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
      out_vals = [*out_tokens, *out_vals]
      hlo.ReturnOp(util.flatten(out_vals))

  # Regroup the flat CaseOp results per output type, then peel off the
  # leading tokens.
  tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
  tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
  ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
  return outputs


mlir.register_lowering(cond_p, _cond_lowering)

@state.register_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
  """State-discharge rule for ``cond_p``.

  Each branch is discharged with ``state.discharge_state``, which turns Ref
  operands into plain values. ``cond_p`` is then bound on the discharged
  branches, and the results are split into the cond's regular outputs and the
  final values of the Ref operands.

  Args:
    in_avals: abstract values of all operands, including the leading index.
    out_avals: abstract values of the original (non-discharged) outputs.
    *args: the operand values, index first.
    branches: the branch jaxprs to discharge.
    linear: forwarded unchanged to ``cond_p.bind``.

  Returns:
    A pair ``(new_invals, out_vals)``. ``new_invals`` has one entry per
    element of ``in_avals``: the final value for operands whose aval is a
    ``state.ShapedArrayRef``, and ``None`` otherwise. ``out_vals`` are the
    cond's regular outputs.
  """
  # NOTE(review): consts are passed as () and any consts returned by
  # discharge_state ([1]) are dropped. This assumes cond branches carry no
  # consts; confirm.
  discharged_branches = tuple(
      core.ClosedJaxpr(state.discharge_state(branch.jaxpr, ())[0], ())
      for branch in branches)
  all_out_vals = cond_p.bind(*args, branches=discharged_branches,
                             linear=linear)
  # state.discharge_state returns the jaxpr's original outputs FIRST, followed
  # by the final values of its Ref inputs. The regular outputs are therefore
  # the prefix of length len(out_avals), and the Ref values are the suffix.
  # (The previous split had these reversed, which mixed the two up whenever
  # a cond had both Ref operands and outputs.)
  out_vals, out_ref_vals = util.split_list(all_out_vals, [len(out_avals)])
  ref_val_iter = iter(out_ref_vals)
  new_invals = [
      next(ref_val_iter) if isinstance(aval, state.ShapedArrayRef) else None
      for aval in in_avals]
  return new_invals, out_vals