2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2022 The JAX Authors.
|
2022-06-02 11:50:03 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""Module for conditional control flow primitives."""
|
2023-10-27 03:04:39 +03:00
|
|
|
from __future__ import annotations
|
2023-07-21 14:20:39 -04:00
|
|
|
|
2022-06-02 11:50:03 -07:00
|
|
|
import collections
|
2024-06-26 14:44:52 -04:00
|
|
|
from collections.abc import Callable, Sequence
|
2022-06-02 11:50:03 -07:00
|
|
|
import functools
|
|
|
|
from functools import partial
|
|
|
|
import inspect
|
|
|
|
import itertools
|
2022-07-19 08:53:23 -07:00
|
|
|
import operator
|
2024-06-26 14:44:52 -04:00
|
|
|
from typing import Any, TypeVar
|
2022-06-02 11:50:03 -07:00
|
|
|
|
|
|
|
from jax.tree_util import tree_flatten, tree_unflatten
|
|
|
|
from jax._src import ad_util
|
[better_errors] Refactor more uses of partial_eval.tracing_debug_info (part 1)
We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.
This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.
2025-01-25 07:16:25 +02:00
|
|
|
from jax._src import api_util
|
2023-10-11 08:45:30 -07:00
|
|
|
from jax._src import config
|
2023-02-14 23:00:40 -08:00
|
|
|
from jax._src import core
|
2023-03-27 13:29:59 -07:00
|
|
|
from jax._src import dispatch
|
2022-06-02 11:50:03 -07:00
|
|
|
from jax._src import dtypes
|
2023-02-01 17:50:00 -08:00
|
|
|
from jax._src import effects
|
2023-02-14 23:00:40 -08:00
|
|
|
from jax._src import linear_util as lu
|
2022-06-02 11:50:03 -07:00
|
|
|
from jax._src import source_info_util
|
|
|
|
from jax._src import util
|
2024-10-01 08:03:11 -07:00
|
|
|
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
|
2023-02-17 12:45:39 -08:00
|
|
|
from jax._src.state.types import AbstractRef, RefEffect
|
2024-11-05 07:16:32 -08:00
|
|
|
from jax._src.core import replace_jaxpr_effects
|
2023-03-27 13:29:59 -07:00
|
|
|
from jax._src.interpreters import ad
|
|
|
|
from jax._src.interpreters import batching
|
|
|
|
from jax._src.interpreters import mlir
|
|
|
|
from jax._src.interpreters import partial_eval as pe
|
|
|
|
from jax._src.interpreters import xla
|
2022-06-02 11:50:03 -07:00
|
|
|
from jax._src.lax import lax
|
|
|
|
from jax._src.traceback_util import api_boundary
|
2023-02-27 11:37:10 -08:00
|
|
|
from jax._src.util import (safe_map, split_list, partition_list)
|
2022-06-02 11:50:03 -07:00
|
|
|
from jax._src.lib.mlir import ir
|
2022-12-15 20:59:34 -08:00
|
|
|
from jax._src.lib.mlir.dialects import hlo
|
2022-06-02 11:50:03 -07:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from jax._src.lax.control_flow.common import (
|
|
|
|
_avals_short,
|
|
|
|
_check_tree_and_avals,
|
|
|
|
_initial_style_jaxprs_with_common_consts,
|
|
|
|
_make_closed_jaxpr,
|
|
|
|
_prune_zeros,
|
|
|
|
_typecheck_param,
|
|
|
|
)
|
|
|
|
|
2022-07-19 08:53:23 -07:00
|
|
|
# Within this module ``map`` means ``safe_map`` (from ``jax._src.util``); the
# builtin is kept reachable under the name ``unsafe_map``.
unsafe_map = map
map = safe_map

# ``switch`` and ``cond`` still accept the older calling convention in which a
# single (pytree) operand is passed by keyword as ``operand``. This private
# sentinel is that keyword's default: because it is compared by identity, any
# user-supplied value (including ``None``) is recognised as "operand passed".
_no_operand_sentinel = object()
|
|
|
|
|
|
|
|
@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
           operand=_no_operand_sentinel):
  """Apply exactly one of the ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, *operands):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](*operands)

  Internally this wraps XLA's `Conditional
  <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
  operator. However, when transformed with :func:`~jax.vmap` to operate over a
  batch of indices, every branch is evaluated and the results are combined
  with :func:`~jax.lax.select_n`.

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
      All branches must return the same output structure. Any iterable of
      callables is accepted; it is materialized into a tuple before use.
    operands: Operands (A) input to whichever branch is applied.
    operand: Keyword form kept for backward compatibility with an older calling
      convention: a single (pytree) operand. It cannot be combined with
      positional ``operands``.

  Returns:
    Value (B) of ``branch(*operands)`` for the branch that was selected based
    on ``index``.

  Raises:
    TypeError: if any branch is not callable, if ``operand`` is combined with
      positional operands, or if ``index`` is not an integer scalar.
    ValueError: if ``branches`` is empty.
    NotImplementedError: if the branches have effects that are not allowed
      inside control flow.
  """
  # Materialize first: the callability check iterates ``branches``, which would
  # otherwise exhaust a generator and make it look empty (or partial) later.
  branches = tuple(branches)
  if not all(callable(branch) for branch in branches):
    raise TypeError("lax.switch: branches argument should be a sequence of callables.")
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got {operand=} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if len(np.shape(index)) != 0:
    raise TypeError(
        "Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")

  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err

  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    # Nothing to choose between: call the only branch directly.
    return branches[0](*operands)

  # Clamp the index into [0, len(branches) - 1] as documented above.
  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  # Under disable_jit with a concrete index, run the chosen branch eagerly.
  if (config.disable_jit.value and core.is_concrete(index)):
    return branches[int(index)](*operands)

  dbgs = [api_util.debug_info("switch", branch, operands, {})
          for branch in branches]
  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(map(core.get_aval, ops))

  if config.mutable_array_checks.value:
    api_util._check_no_aliased_ref_args(dbgs[0], ops_avals, ops)

  # Trace every branch to a jaxpr; closed-over values become shared consts.
  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals, dbgs)
  if config.mutable_array_checks.value:
    api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops)
  # All branches must agree with branch 0 on output pytree structure and avals.
  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals("branch 0 output",
                          out_trees[0], jaxprs[0].out_avals,
                          f"branch {i + 1} output",
                          out_tree, jaxpr.out_avals)
  joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
  disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `switch`: {disallowed_effects}')
  out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs))
  return tree_unflatten(out_trees[0], out)
|
|
|
|
|
|
|
|
|
|
|
|
def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
          operand=_no_operand_sentinel):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Wraps XLA's `Conditional
  <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
  operator.

  Provided arguments are correctly typed, ``cond()`` has equivalent
  semantics to this Python implementation, where ``pred`` must be a
  scalar type::

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)

  In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
  the two branches is executed (up to compiler rewrites and optimizations).
  However, when transformed with :func:`~jax.vmap` to operate over a batch of
  predicates, ``cond`` is converted to :func:`~jax.lax.select`.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
      Integer and floating-point scalars are also accepted and are treated as
      ``pred != 0``.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operands: Operands (A) input to either branch depending on ``pred``. The
      type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.
    operand: Keyword form kept for backward compatibility with an older calling
      convention: a single (pytree) operand. It cannot be combined with
      positional ``operands``.

  Returns:
    Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.

  Raises:
    TypeError: if either branch is not callable; if ``operand`` is combined
      with positional operands; if ``pred`` is ``None``, not a scalar, or not
      of boolean/integer/floating type; or if the two branches' outputs differ
      in pytree structure or types.
    ValueError: if either branch returns a ``Ref``.
    NotImplementedError: if the branches have effects that are not allowed
      inside control flow.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  # Legacy keyword form: a single operand passed as ``operand=...``.
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got {operand=} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  if pred is None:
    raise TypeError("cond predicate is None")
  # Sequences are rejected explicitly, in addition to anything with ndim != 0.
  if isinstance(pred, Sequence) or np.ndim(pred) != 0:
    raise TypeError(
        f"Pred must be a scalar, got {pred} of " +
        (f"type {type(pred)}" if isinstance(pred, Sequence)
         else f"shape {np.shape(pred)}."))

  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    msg = ("Pred type must be either boolean or number, got {}.")
    raise TypeError(msg.format(pred)) from err

  # Numeric predicates are normalized to booleans via ``pred != 0``.
  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      msg = ("Pred type must be either boolean or number, got {}.")
      raise TypeError(msg.format(pred_dtype))

  # Under disable_jit with a concrete predicate, run the chosen branch eagerly.
  if config.disable_jit.value and core.is_concrete(pred):
    if pred:
      return true_fun(*operands)
    else:
      return false_fun(*operands)

  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(map(core.get_aval, ops))

  dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {})
  if config.mutable_array_checks.value:
    api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops)
  dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
  # Trace both branches; closed-over values become shared ``consts``.
  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals,
      [dbg_true_fun, dbg_false_fun])
  true_jaxpr, false_jaxpr = jaxprs
  if config.mutable_array_checks.value:
    api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops)

  out_tree, false_out_tree = out_trees
  if any(isinstance(out_aval, AbstractRef) for out_aval in
         true_jaxpr.out_avals + false_jaxpr.out_avals):
    raise ValueError("Cannot return `Ref`s from `cond`.")

  _check_tree_and_avals("true_fun output",
                        out_tree, true_jaxpr.out_avals,
                        "false_fun output",
                        false_out_tree, false_jaxpr.out_avals)
  # prune passthrough outputs: an output that both branches forward unchanged
  # from the same input position does not need to go through cond_p. It is
  # dropped from both branches and re-inserted from ``ops`` after the bind.
  true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr)
  false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr)
  # in_fwd[k] is the shared forwarded input index for output k, else None.
  in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)]
  keep = [f is None for f in in_fwd]
  true_jaxpr = pe.prune_closed_jaxpr_outputs(true_jaxpr, keep)
  false_jaxpr = pe.prune_closed_jaxpr_outputs(false_jaxpr, keep)

  joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
  disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  # cond_p selects branches by integer index: branches[0] is false_jaxpr and
  # branches[1] is true_jaxpr, so False -> 0 and True -> 1.
  index = lax.convert_element_type(pred, np.int32)
  # Both branches carry the joined effects so their types agree.
  false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects)
  true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects)

  out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr))
  num_consts = len(consts)
  out_ = iter(out)

  # Re-interleave cond_p results with the forwarded inputs. Branch inputs are
  # ``consts`` followed by ``ops`` (matching the bind above), hence the offset.
  # NOTE(review): assumes a forwarded index never falls inside the leading
  # consts (fwd >= num_consts); otherwise the subtraction would index ``ops``
  # from the end -- confirm against _initial_style_jaxprs_with_common_consts.
  out = [
    next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts])
    for fwd in in_fwd
  ]
  # Every non-forwarded cond_p output must have been consumed exactly once.
  assert next(out_, None) is None
  return tree_unflatten(out_tree, out)
|
|
|
|
|
|
|
|
@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
  # Public entry point. Calls that bind to the former, deprecated signature
  # ``cond(pred, true_operand, true_fun, false_operand, false_fun)`` are routed
  # to ``_cond_with_per_branch_args``; everything else goes to ``_cond``.
  try:
    bound = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
  except TypeError:
    bound = None

  if bound is not None:
    # _cond_with_per_branch_args has no catch-all **kwargs, so all arguments
    # end up in ``bound.args``.
    assert not bound.kwargs
    _, true_operand, true_fun, false_operand, false_fun = bound.args
    # If the 2nd and 3rd arguments are both callables, this is really the
    # modern form ``cond(pred, true_fun, false_fun, op1, op2)``.
    is_modern_two_operand = callable(true_operand) and callable(true_fun)
    if not is_modern_two_operand and callable(true_fun) and callable(false_fun):
      return _cond_with_per_branch_args(*bound.args)

  return _cond(*args, **kwargs)
|
|
|
|
|
|
|
|
def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Legacy form of ``cond`` that takes a separate operand for each branch.

  Behaves like this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  ``pred`` must be a scalar; collection types (list, tuple) are not supported.
  Both operands travel to ``_cond`` as a single ``(true_operand,
  false_operand)`` pair, and each branch picks out its own element.

  Raises:
    TypeError: if ``true_fun`` or ``false_fun`` is not callable.
  """
  if not callable(true_fun) or not callable(false_fun):
    raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
  operand_pair = (true_operand, false_operand)
  return _cond(pred,
               lambda op: true_fun(op[0]),
               lambda op: false_fun(op[1]),
               operand_pair)
|
|
|
|
|
2023-02-17 12:45:39 -08:00
|
|
|
def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects:
  """Return the union of the effects of every branch in ``branches``.

  The ``cond_p`` primitive receives the branch index as its first input, ahead
  of the operands the branch jaxprs see. Every ``JaxprInputEffect`` is
  therefore shifted by one so that it refers to the matching ``cond_p`` input.
  """
  def shift_past_index(eff):
    if not isinstance(eff, effects.JaxprInputEffect):
      return eff
    return eff.replace(input_index=eff.input_index + 1)

  return {shift_past_index(eff) for branch in branches for eff in branch.effects}
|
|
|
|
|
2022-11-02 16:07:22 -07:00
|
|
|
def _cond_abstract_eval(*avals, branches, **_):
  """Abstract evaluation rule for ``cond_p``.

  The output avals are those of the first branch; ``avals`` and any other
  params are ignored. Returns ``(out_avals, joined_effects)``.

  Raises:
    NotImplementedError: if the joined branch effects contain an effect that
      is not allowed inside control flow.
  """
  joined = _join_cond_effects(branches)
  unsupported = effects.control_flow_allowed_effects.filter_not_in(joined)
  if not unsupported:
    return branches[0].out_avals, joined
  raise NotImplementedError(
      f'Effects not supported in `cond`: {unsupported}')
|
2022-06-02 11:50:03 -07:00
|
|
|
|
|
|
|
def _bcast_select(pred, on_true, on_false):
  """Elementwise ``lax.select`` that tolerates a ``pred`` of different rank.

  If ``pred``'s rank differs from ``on_true``'s, ``pred`` is broadcast to the
  shape of ``on_true``, with its axes mapped onto the leading axes.
  """
  pred_rank = np.ndim(pred)
  if pred_rank == np.ndim(on_true):
    return lax.select(pred, on_true, on_false)
  leading_axes = list(range(pred_rank))
  widened_pred = lax.broadcast_in_dim(pred, np.shape(on_true), leading_axes)
  return lax.select(widened_pred, on_true, on_false)
|
|
|
|
|
|
|
|
def _bcast_select_n(pred, *cases):
  """``lax.select_n`` that tolerates a ``pred`` of lower rank than the cases.

  If ``pred``'s rank differs from that of the first case, ``pred`` is
  broadcast to that case's shape, with its axes mapped onto the leading axes.
  """
  reference = cases[0]
  pred_rank = np.ndim(pred)
  if pred_rank == np.ndim(reference):
    return lax.select_n(pred, *cases)
  leading_axes = list(range(pred_rank))
  widened_pred = lax.broadcast_in_dim(pred, np.shape(reference), leading_axes)
  return lax.select_n(widened_pred, *cases)
|
|
|
|
|
2024-10-29 11:03:49 -07:00
|
|
|
def _cond_batching_rule(axis_data, args, dims, branches):
  """vmap (batching) rule for ``cond_p``.

  Args:
    axis_data: batching axis information; ``axis_data.size`` is the batch size.
    args: ``(index, *ops)`` -- the branch index followed by the operands.
    dims: batch dimension of each element of ``args`` (``batching.not_mapped``
      for unbatched values).
    branches: the closed branch jaxprs of the ``cond_p`` equation.

  Returns:
    ``(outputs, out_dims)``. If the index is batched, every branch is evaluated
    and the results are combined elementwise with ``select_n``, with all outputs
    batched along axis 0. Otherwise a single batched ``cond_p`` is emitted.

  Raises:
    NotImplementedError: if any branch has a state (``RefEffect``) or IO effect.
  """
  index, *ops = args
  index_dim, *op_dims = dims
  # TODO(sharadmv): clean this up by adding a specific blocklist
  if any(isinstance(eff, RefEffect) for branch in branches for eff in
      branch.jaxpr.effects):
    raise NotImplementedError(
        "State effect not supported in vmap-of-cond.")
  # Local import (presumably to avoid an import cycle -- TODO confirm).
  from jax._src.callback import _IOEffect, _OrderedIOEffect
  if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect]
      for branch in branches):
    raise NotImplementedError(
        "IO effect not supported in vmap-of-cond.")

  if index_dim is not batching.not_mapped:
    # Convert to a lax.select. While we could get away with not broadcasting
    # some operands yet, because all outputs must be broadcast together anyway
    # for the select we broadcast the input operands for simplicity and leave
    # optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    # Move (or create) the batch axis at position 0 for the index and all ops.
    index, *ops = (
        batching.bdim_at_front(x, d, axis_data.size) for x, d in zip(args, dims))

    in_batched = [True] * len(branches[0].in_avals)
    out_batched = [True] * len(branches[0].out_avals)

    branches_batched = [
        batching.batch_jaxpr(jaxpr, axis_data, in_batched, out_batched)[0]
        for jaxpr in branches]

    # Evaluate every branch on the full batch.
    branch_outs = []
    for i, jaxpr in enumerate(branches_batched):
      # Perform a select on the inputs for safety of reverse-mode autodiff; see
      # https://github.com/jax-ml/jax/issues/1052
      # Batch elements that did not select branch ``i`` see stop_gradient(x),
      # so gradients reach ``ops`` only through the selected branch.
      predicate = lax.eq(index, lax._const(index, i))
      ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
      branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
    # Per output, pick each batch element's value from its selected branch.
    out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
    return out, [0 if b else None for b in out_batched]
  else:
    # Index is unbatched: keep a real cond_p and batch the branch jaxprs.
    ops_bat = [d is not batching.not_mapped for d in op_dims]
    ops = [batching.moveaxis(x, d, 0) if b else x
           for b, x, d in zip(ops_bat, ops, op_dims)]

    # First pass: which outputs come out batched in any branch.
    branches_out_bat = [
        batching.batch_jaxpr(jaxpr, axis_data, ops_bat, False)[1]
        for jaxpr in branches]
    out_bat = [any(bat) for bat in zip(*branches_out_bat)]
    # Second pass: re-batch every branch forcing those outputs to be batched,
    # so all branches agree on their output types.
    branches_batched = tuple(
        batching.batch_jaxpr(jaxpr, axis_data, ops_bat, out_bat)[0]
        for jaxpr in branches)

    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(index, *ops, branches=branches_batched)
    return out, out_dims
|
|
|
|
|
2024-06-26 14:25:20 -04:00
|
|
|
def _cond_jvp(primals, tangents, branches):
  """Forward-mode (JVP) rule for ``cond_p``.

  ``primals`` is ``(index, *ops)`` and ``tangents`` holds their tangents; the
  index tangent must be a symbolic zero. Every branch is replaced by its JVP
  jaxpr and one ``cond_p`` is bound on the primals followed by the nonzero
  operand tangents.

  Returns:
    ``(out_primals, out_tangents)``, where outputs whose tangent is zero in
    every branch get an ``ad_util.Zero`` placeholder.
  """
  index_nz, *ops_nz = [type(t) is not ad_util.Zero for t in tangents]
  assert index_nz is False

  # An output tangent is materialized if it is nonzero in at least one branch,
  # so that the JVP jaxprs of all branches have identical outputs.
  nz_by_branch = []
  for branch in branches:
    nz_by_branch.append(ad.jvp_jaxpr(branch, ops_nz, instantiate=False)[1])
  out_nz = [any(column) for column in zip(*nz_by_branch)]

  jvp_branches = tuple(ad.jvp_jaxpr(branch, ops_nz, instantiate=out_nz)[0]
                       for branch in branches)

  index, *ops = primals
  _, *ops_dot = tangents
  nonzero_ops_dot = _prune_zeros(ops_dot)

  out = cond_p.bind(index, *ops, *nonzero_ops_dot, branches=jvp_branches)
  out_primals, nonzero_out_dots = split_list(out, [len(out_nz)])

  remaining_dots = iter(nonzero_out_dots)
  out_tangents = []
  for primal, is_nz in zip(out_primals, out_nz):
    if is_nz:
      out_tangents.append(next(remaining_dots))
    else:
      out_tangents.append(ad_util.Zero.from_primal_value(primal))
  return out_primals, out_tangents
|
|
|
|
|
2024-06-26 14:25:20 -04:00
|
|
|
def _cond_partial_eval(trace, *tracers, branches):
  """Partial-evaluation rule for ``cond_p``.

  ``tracers`` is ``(index, *ops)``. When the index is known, each branch is
  split into a known part (bound now on the known inputs, also producing
  residuals) and a staged part. The staged part is recorded as a new ``cond_p``
  equation on ``[index, *residuals, *unknown_ops]``.

  Returns:
    The output values in original order: known outputs as concrete values and
    unknown outputs as ``pe.JaxprTracer`` objects.

  Raises:
    NotImplementedError: if any branch has a state (``RefEffect``) effect.
  """
  # An input is unknown when its partial value has a (non-None) aval.
  in_unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = in_unknowns
  if any(isinstance(eff, RefEffect) for branch in branches for eff in
      branch.jaxpr.effects):
    raise NotImplementedError(
        "State effect not supported in cond partial-eval.")

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    # TODO(mattjj): remove this path when old remat is removed
    params = dict(branches=branches)
    return trace.default_process_primitive(cond_p, tracers, params)

  # First pass: an output is unknown if it is unknown in any branch.
  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(
        branch_jaxpr, ops_uk, instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  # Second pass: split every branch with those outputs forced unknown, so all
  # branches agree on which outputs are known/staged.
  branches_known, branches_unknown, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
        pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
    branches_known.append(branch_jaxpr_known)
    branches_unknown.append(branch_jaxpr_unknown)
    branch_res_avals.append(res_avals)

  # Branches may produce different residuals; merge them into one shared
  # residual list (helpers defined elsewhere in this module).
  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
  num_res = len(all_res_avals)

  num_known_outs = len(out_uks) - sum(out_uks)
  # Make every known branch emit [known outputs..., all residuals...] and every
  # staged branch accept the full residual list, so branches type-match.
  branches_known = _join_cond_outputs(
      branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
      branches_unknown, all_res_avals, res_avals_per_branch)
  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

  # The index is known here (index_uk is False), so it is in_consts[0] and
  # acts as the branch index of this bind.
  in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
  out_consts_res = cond_p.bind(*in_consts, branches=branches_known)
  # Residuals come last in the known branches' outputs.
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  index_tracer = trace.instantiate_const(tracers[0])
  ops_tracers = [trace.instantiate_const(t)
                 for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
  # ``map`` is the module-level safe_map; the list concatenation below relies
  # on it returning a list.
  res_tracers = map(trace.new_instantiated_const, res)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
                 for aval in branches_unknown[0].out_avals]
  params = dict(branches=branches_unknown)
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  # Staged cond_p inputs: [index, *residuals, *unknown ops].
  eqn = pe.new_eqn_recipe(
      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
      core.join_effects(*(j.effects for j in branches_unknown)), source)
  for t in out_tracers: t.recipe = eqn
  # Interleave known values and unknown tracers back into output order.
  return util.merge_lists(out_uks, out_consts, out_tracers)
|
|
|
|
|
2022-07-19 08:53:23 -07:00
|
|
|
# TODO(mattjj): de-duplicate with _cond_partial_eval
def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Partial-eval rule for ``cond_p`` used by ``pe.partial_eval_jaxpr_custom``.

  Splits ``eqn`` into a "known" cond equation, which computes the known outputs
  followed by residuals, and a "staged" cond equation, which takes the index,
  then the residuals, then all original operands.

  Args:
    saveable: policy forwarded to ``pe.partial_eval_jaxpr_custom``.
    unks_in: one unknown-flag per ``eqn.invars``; entry 0 is for the index.
    inst_in: one instantiated-flag per ``eqn.invars``.
    eqn: the ``cond_p`` equation to split.

  Returns:
    ``(eqn_known, eqn_staged, unks_out, inst_out, new_vars)``, where
    ``new_vars`` lists the previously uninstantiated ``Var`` inputs followed by
    the residual binders. If the index is unknown, ``eqn_known`` is ``None``
    and the whole ``eqn`` is staged.
  """
  index_uk, *ops_uk = unks_in
  branches = eqn.params['branches']

  # Instantiate all inputs (b/c jaxpr_staged will take all inputs).
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  del inst_in

  # NOTE(mattjj): I think it should be impossible for the index to be unknown,
  # but asserting that caused a test failure in diffrax. So we handle it: if it
  # is unknown, stage out the whole cond.
  if index_uk:
    all_true = [True] * len(branches[0].out_avals)
    return None, eqn, all_true, all_true, new_inst

  # First, compute output unknowns (unks_out), where an output of the cond is
  # unknown if it would be unknown on any of the branches.
  unks_out: list[bool] = [False] * len(eqn.outvars)
  for jaxpr in branches:
    _, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
        jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
        ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
    unks_out = map(operator.or_, unks_out, unks_out_)

  # Next, use the computed output unknowns to build a known jaxpr and a staged
  # jaxpr for each branch.
  branches_known_ : list[core.ClosedJaxpr] = []
  branches_staged_: list[core.ClosedJaxpr] = []
  branch_res_avals: list[list[core.AbstractValue]] = []
  for jaxpr in branches:
    # Every branch uses ensure_out_inst=True, so presumably every branch
    # reports the same inst_out; the last branch's value is used below.
    jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
            ensure_out_unknowns=unks_out, ensure_out_inst=True,
            saveable=saveable)
    branches_known_.append( core.ClosedJaxpr(jaxpr_known, jaxpr.consts))
    branches_staged_.append(core.ClosedJaxpr(jaxpr_staged, jaxpr.consts))
    # The staged jaxpr takes this branch's residuals as its leading inputs.
    branch_res_avals.append(branches_staged_[-1].in_avals[:num_res])

  # Residuals may differ across branches, so we merge them, then use the merged
  # residuals to join the outputs of all branches to the same type.
  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
  num_known_outs = len(unks_out) - sum(unks_out)
  branches_known = _join_cond_outputs(
      branches_known_, all_res_avals, res_avals_per_branch, num_known_outs)
  branches_staged = _join_cond_pe_staged_jaxpr_inputs(
      branches_staged_, all_res_avals, res_avals_per_branch)
  assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
             for j in branches_known[1:])

  # Create residual variables.
  newvar = core.gensym()
  res_binders = map(newvar, all_res_avals)

  # Build the known eqn.
  ins_known, _ = partition_list(unks_in, eqn.invars)  # includes index invar
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  params_known = dict(branches=branches_known)
  effects_known = _join_cond_effects(branches_known)
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
      effects_known, eqn.source_info)

  # Build the staged eqn: index first, then residuals, then original operands.
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  params_staged = dict(branches=branches_staged)
  effects_staged = _join_cond_effects(branches_staged)
  eqn_staged = pe.new_jaxpr_eqn(
      [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
      cond_p, params_staged, effects_staged, eqn.source_info)

  new_vars = [*new_inst, *res_binders]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
|
|
|
|
|
2022-06-02 11:50:03 -07:00
|
|
|
# When partially evaluating conditionals, each branch produces residuals
|
|
|
|
# depending on the computation carried out by the branch, and a corresponding
|
|
|
|
# staged jaxpr that accepts those residuals as its first few inputs. The
|
|
|
|
# residual-producing branches are staged as jaxprs and bound right away in a
|
|
|
|
# conditional. The residual-consuming jaxprs are assembled together in a jaxpr
|
|
|
|
# conditional. The following helper functions ensure that both collections of
|
|
|
|
# jaxprs (those evaluated and those staged) are valid for joint use under their
|
|
|
|
# respective conditionals.
|
|
|
|
#
|
|
|
|
# In particular, the residuals derived from each original branch may have
|
|
|
|
# distinct types. Because the branches of conditionals must have identical type
|
|
|
|
# signatures, we join residuals together across branches into a common format.
|
|
|
|
|
|
|
|
# In order to set up a type signature that all branches can conform to, it would
|
|
|
|
# suffice to concatenate all branches' residuals. But concatenation can result
|
|
|
|
# in redundant inputs and outputs, and might lead to memory allocation that
|
|
|
|
# scales unnecessarily with the branch count. This function finds common
|
|
|
|
# residual types across branches for reuse, so as to avoid redundant
|
|
|
|
# allocation. It returns a list L of types (avals) representing the collection
|
|
|
|
# of residuals merged according to type, and, for each branch, a lookup table to
|
|
|
|
# match its residuals to their positions/types in L. Example input/output:
|
|
|
|
#
|
|
|
|
# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]]
|
|
|
|
# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]]
|
|
|
|
# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
|
|
|
|
def _merge_branch_residuals(branch_res_avals):
|
|
|
|
def enumerate_equal(xs):
|
|
|
|
counts = {v: itertools.count() for v in set(xs)}
|
|
|
|
return [(x, next(counts[x])) for x in xs]
|
2022-07-19 08:53:23 -07:00
|
|
|
branch_res_tagged_avals = map(enumerate_equal, branch_res_avals)
|
2022-06-02 11:50:03 -07:00
|
|
|
all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
|
|
|
|
indices = {v: i for i, v in enumerate(all_tagged_avals)}
|
|
|
|
branch_indices = [
|
|
|
|
[indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
|
|
|
|
all_avals = [x for x, _ in all_tagged_avals]
|
|
|
|
return all_avals, branch_indices
|
|
|
|
|
|
|
|
# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
                       num_non_res_outputs):
  """Pads every branch's residual outputs out to the merged residual list.

  Args:
    jaxprs: closed jaxprs whose outputs are ``num_non_res_outputs`` regular
      outputs followed by that branch's own residuals.
    all_res_avals: the merged residual avals shared by all branches.
    res_aval_indices_per_jaxpr: for each jaxpr, the slot in ``all_res_avals``
      of each of its residuals.
    num_non_res_outputs: number of leading non-residual outputs.

  Returns:
    A tuple of closed jaxprs, one per input jaxpr, each returning the regular
    outputs followed by one value per entry of ``all_res_avals``.
  """
  def augment_jaxpr(jaxpr, res_indices):
    @lu.wrap_init
    def f_aug(*args):
      outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
      outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
      # Start with zeros in every merged slot, then overwrite the slots this
      # branch actually produces.
      aug_residuals = map(ad_util.zeros_like_aval, all_res_avals)
      aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
      return outs + list(aug_residuals)

    return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

  return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
|
2022-06-02 11:50:03 -07:00
|
|
|
|
|
|
|
# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
                                      res_aval_indices_per_jaxpr):
  """Widens every staged branch to take the full merged residual list.

  Args:
    jaxprs: closed staged jaxprs whose leading inputs are that branch's own
      residuals.
    all_res_avals: the merged residual avals shared by all branches.
    res_aval_indices_per_jaxpr: for each jaxpr, the slot in ``all_res_avals``
      of each of its residuals.

  Returns:
    A tuple of closed jaxprs with inputs
    ``(*merged_residuals, *original_non_residual_inputs)``.
  """
  # Fresh placeholder binders for the merged residual slots. A branch keeps its
  # own residual vars in the slots it reads; the other slots get these unused
  # placeholders.
  newvar = core.gensym(suffix='_')
  all_res_vars = map(newvar, all_res_avals)

  def augment_jaxpr(jaxpr, res_indices):
    num_res = len(res_indices)
    res_vars = jaxpr.jaxpr.invars[:num_res]
    non_res_vars = jaxpr.jaxpr.invars[num_res:]

    aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
    aug_invars = aug_res_vars + non_res_vars
    # NOTE(review): effects are copied unchanged although the non-residual
    # inputs shift position; confirm no input-indexed effects can appear here.
    jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
                           jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
                           jaxpr.jaxpr.effects)
    jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
    return jaxpr_aug

  return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
|
2022-06-02 11:50:03 -07:00
|
|
|
|
|
|
|
def _ordered_unique(xs):
|
|
|
|
d = collections.OrderedDict((x, None) for x in xs)
|
|
|
|
return list(d.keys())
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
                   ) -> tuple[list[bool], core.JaxprEqn | None]:
  """Dead-code elimination rule for ``cond_p``.

  Args:
    used_outputs: one flag per ``eqn.outvars`` saying whether it is used.
    eqn: the ``cond_p`` equation to prune.

  Returns:
    ``(used_inputs, new_eqn)``: one flag per ``eqn.invars`` (the index is
    always marked used when the equation is kept), and the pruned equation,
    or ``None`` if no output is used and the equation has no effects.
  """
  if not any(used_outputs) and not pe.has_effects(eqn):
    return [False] * len(eqn.invars), None

  closed_branches = eqn.params['branches']
  branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

  # First, compute which inputs are used in any branch (not including `pred`).
  used_inputs: list[bool] = [False] * (len(eqn.invars) - 1)  # -1 for pred
  for jaxpr in branches:
    _, used_inputs_ = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
    used_inputs = map(operator.or_, used_inputs, used_inputs_)

  # Next, compute DCEd branches, instantiating according to used_inputs so
  # that every branch keeps the same input signature.
  dce_branches_ = [pe.dce_jaxpr(jaxpr, used_outputs, instantiate=used_inputs)[0]
                   for jaxpr in branches]
  dce_branches = [core.ClosedJaxpr(jaxpr, closed_jaxpr.consts)
                  for closed_jaxpr, jaxpr in zip(closed_branches, dce_branches_)]

  # Finally, update parameters and form the new eqn. Effects are joined with
  # `_join_cond_effects`, as in the other cond rules in this file.
  new_params = dict(eqn.params, branches=tuple(dce_branches))
  new_effects = _join_cond_effects(dce_branches_)
  new_eqn = pe.new_jaxpr_eqn(
      [v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, new_effects, eqn.source_info)

  assert all(len(new_eqn.invars ) == 1 + len(jaxpr.in_avals )
             for jaxpr in new_params['branches'])
  assert all(len(new_eqn.outvars) == len(jaxpr.out_avals)
             for jaxpr in new_params['branches'])
  return [True, *used_inputs], new_eqn
|
|
|
|
|
|
|
|
|
2024-02-24 16:11:41 -08:00
|
|
|
def _transpose_cond_jaxpr(jaxpr, num_res):
  """Builds the transpose of a single branch jaxpr.

  The first ``num_res`` inputs of ``jaxpr`` are residuals and the remaining
  inputs are linear. The returned closed jaxpr takes
  ``(*residuals, *output_cotangents)`` and returns the instantiated cotangents
  of the linear inputs.
  """
  res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])

  @lu.wrap_init
  def transposed(*args):
    res, cts_out = split_list(args, [num_res])
    primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
    cts_in = ad.backward_pass(
        jaxpr.jaxpr, False, jaxpr.consts, primals, cts_out)
    # Drop the (unused) cotangents of the residual inputs.
    _, cts_in = split_list(cts_in, [num_res])
    return map(ad.instantiate_zeros, cts_in)

  return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
|
|
|
|
|
2024-06-26 14:25:20 -04:00
|
|
|
def _cond_transpose(cts, *args, branches):
  """Transpose rule for ``cond_p``.

  The index must be known. Operands that are ``ad.UndefinedPrimal`` are the
  linear inputs; the others are treated as residuals. The code slices
  ``ops[:num_res]``, so residuals are assumed to precede all linear operands.

  Returns:
    A list with ``None`` for the index, the cotangent for each linear operand,
    and ``None`` for each residual operand.

  Raises:
    NotImplementedError: if any branch carries a ``RefEffect``.
  """
  index, *ops = args
  assert type(index) is not ad.UndefinedPrimal
  linear = [type(x) is ad.UndefinedPrimal for x in ops]
  in_avals = branches[0].in_avals
  num_res = len(ops) - sum(linear)
  if any(isinstance(eff, RefEffect) for branch in branches
         for eff in branch.jaxpr.effects):
    raise NotImplementedError("State effect not supported in cond transpose.")

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
  lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  cts = map(ad.instantiate_zeros, cts)

  out = cond_p.bind(index, *res, *cts, branches=branches_trans)
  assert all(map(core.typecheck, lin_in_avals, out))

  # Scatter the results back to the linear operand positions.
  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out
|
|
|
|
|
2024-06-26 14:25:20 -04:00
|
|
|
def _cond_typecheck(bind_time, *in_atoms, branches):
  """Type-checks a ``cond_p`` application and returns its output type.

  Args:
    bind_time: when False, the first entry of ``in_atoms`` is not an operand
      and is dropped; ``core.custom_typechecks`` registers this rule with
      ``bind_time=False``. NOTE(review): presumably that entry is a typechecker
      context object; confirm.
    *in_atoms: the index followed by the branch operands; each has an ``.aval``.
    branches: tuple of ``core.ClosedJaxpr``, one per branch.

  Returns:
    ``(out_avals, effects)``: branch 0's output avals and the effects joined
    across branches.

  Raises:
    core.JaxprTypeError: for no branches, mismatched branch signatures, a wrong
      operand count, a non-int32 index, or operands incompatible with the
      branch input types.
    NotImplementedError: if the branches have effects disallowed in control
      flow.
  """
  if not bind_time:
    _, *in_atoms = in_atoms
  avals = [x.aval for x in in_atoms]
  tc = partial(_typecheck_param, 'cond')
  tc(branches, 'branches', 'tuple of ClosedJaxpr',
     type(branches) is tuple and
     all(type(x) is core.ClosedJaxpr for x in branches))

  if len(branches) == 0:
    raise core.JaxprTypeError('cond requires at least one branch function')

  jaxpr0 = branches[0]
  jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
  jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
  joined_effects = _join_cond_effects(branches)
  disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `cond`: {disallowed_effects}')

  # Every branch must have exactly branch 0's signature.
  for i, jaxpr in enumerate(branches[1:]):
    if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
      raise core.JaxprTypeError(
          f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
          f'branch {i+1} takes {len(jaxpr.in_avals)}')
    if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
      raise core.JaxprTypeError(
          f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
          f'branch {i+1} outputs {len(jaxpr.out_avals)}')
    if not all(map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching input types: '
          f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
    if not all(map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
      raise core.JaxprTypeError(
          f'cond branches 0 and {i+1} have mismatching output types: '
          f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')

  if len(avals) != 1 + len(jaxpr0.in_avals):
    raise core.JaxprTypeError(
        f'cond called with {len(avals) - 1} non-predicate operands, '
        f'but branches take {len(jaxpr0.in_avals)} inputs')

  index_aval, *op_avals = avals
  if index_aval.dtype != np.int32:
    raise core.JaxprTypeError(
        f'cond called with index of type {index_aval.dtype} instead of int32')
  if not all(map(core.typecompat, jaxpr0.in_avals, op_avals)):
    raise core.JaxprTypeError(
        f'cond branches take input types {jaxpr0_in_avals_str}, '
        f'called with operands of type {_avals_short(op_avals)}')
  return jaxpr0.out_avals, joined_effects
|
2022-06-02 11:50:03 -07:00
|
|
|
|
2024-10-29 11:03:49 -07:00
|
|
|
# The `cond` primitive and its transformation/typecheck rule registrations.
cond_p = core.Primitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = partial(_cond_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule
|
2022-06-02 11:50:03 -07:00
|
|
|
|
2024-06-26 14:25:20 -04:00
|
|
|
def _cond_lowering(ctx, index, *args, branches):
  """Lowers ``cond_p`` to an ``hlo.CaseOp`` with one region per branch.

  Ordered effects are threaded through the case op as leading token results.
  The outputs are returned and the output tokens are recorded on ``ctx``.
  """
  joined_effects = core.join_effects(*(branch.effects for branch in branches))
  ordered_effects = list(effects.ordered_effects.filter_in(joined_effects))
  num_tokens = len(ordered_effects)
  tokens_in = ctx.tokens_in.subset(ordered_effects)
  output_token_types = [mlir.token_type() for _ in ordered_effects]
  output_types = [
      *output_token_types, *map(mlir.aval_to_ir_type, ctx.avals_out)]
  flat_output_types = mlir.flatten_ir_types(output_types)

  # CaseOp takes a single argument 'index' and the corresponding blocks
  # have no arguments; the computation within the block uses implicit
  # captures.
  case_op = hlo.CaseOp(flat_output_types, index=index,
                       num_branches=len(branches))
  name_stack = ctx.name_stack.extend('cond')
  for i, jaxpr in enumerate(branches):
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      consts = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
      out_vals, tokens_out = mlir.jaxpr_subcomp(
          ctx.module_context, jaxpr.jaxpr, name_stack.extend(f'branch_{i}_fun'),
          tokens_in, consts, *args,
          dim_var_values=ctx.dim_var_values)
      out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
      out_vals = [*out_tokens, *out_vals]
      hlo.return_(mlir.flatten_ir_values(out_vals))

  tokens_and_outputs = mlir.unflatten_ir_values_like_types(
      case_op.results, output_types)
  tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
  ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
  return outputs


mlir.register_lowering(cond_p, _cond_lowering)
|
2022-11-02 16:07:22 -07:00
|
|
|
|
2024-10-01 08:03:11 -07:00
|
|
|
@register_partial_discharge_rule(cond_p)
def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches):
  """Partial state-discharge rule for ``cond_p``.

  Discharges the requested ref operands in every branch, then binds a new
  ``cond_p`` whose branches return the regular outputs followed by the final
  ref values. Ref values that every branch forwards unchanged from the same
  input are not threaded through the cond.

  Returns:
    ``(new_invals, out_vals)``: one entry per ``in_avals`` (the final value
    for each discharged ref input, otherwise ``None``), and the regular outputs.
  """
  assert not should_discharge[0], "Can't discharge the index."
  discharged_branches = tuple(
      discharge_state(branch.jaxpr, (), should_discharge=should_discharge[1:])[0]
      for branch in branches
  )
  # Don't thread the ref values through the cond if they never change.
  # For each ref output, record the input position it forwards, keeping it
  # only when all branches forward the same input.
  forwarded_outvars = None
  for branch in discharged_branches:
    invar_pos = {v: i for i, v in enumerate(branch.invars)}
    branch_forwarding = [
        invar_pos.get(v, None) if isinstance(v, core.Var) else None
        for v in branch.outvars[len(out_avals) :]
    ]
    if forwarded_outvars is None:
      forwarded_outvars = branch_forwarding
    else:
      forwarded_outvars = [
          i if i == j else None
          for i, j in zip(forwarded_outvars, branch_forwarding)
      ]
  assert forwarded_outvars is not None
  all_outvars_fwd = [None] * len(out_avals) + forwarded_outvars
  new_branches = tuple(
      core.ClosedJaxpr(
          branch.replace(outvars=[v for v, fwd in zip(branch.outvars, all_outvars_fwd)
                                  if fwd is None]), ())
      for branch in discharged_branches
  )
  out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches)
  out_vals, out_ref_vals_no_fwd = util.split_list(out_vals_no_fwd, [len(out_avals)])
  # Insert forwarded values into reference outputs
  ref_val_no_fwd_iter = iter(out_ref_vals_no_fwd)
  out_ref_vals = [next(ref_val_no_fwd_iter) if fwd is None else args[fwd]
                  for fwd in forwarded_outvars]
  # Map reference outputs back to their invars
  ref_val_iter = iter(out_ref_vals)
  new_invals = []
  for should, aval in zip(should_discharge, in_avals):
    discharged_inval = isinstance(aval, AbstractRef) and should
    new_invals.append(next(ref_val_iter) if discharged_inval else None)
  return new_invals, out_vals
|
2023-10-27 03:04:39 +03:00
|
|
|
|
|
|
|
|
|
|
|
_T = TypeVar("_T")

def platform_dependent(*args: Any,
                       default: Callable[..., _T] | None = None,
                       **per_platform: Callable[..., _T]):
  """Stages out platform-specific code.

  In JAX the actual platform on which a computation is run is determined
  very late, e.g., based on where the data is located. When using AOT
  lowering or serialization, the computation may be compiled and executed
  on a different machine, or even on a platform that is not available at
  lowering time. This means that it is not safe to write platform-dependent
  code using Python conditionals, e.g., based on the current default
  JAX platform. Instead, one can use ``platform_dependent``:

  Usage::

      def cpu_code(*args): ...
      def tpu_code(*args): ...
      def other_platforms_code(*args): ...
      res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
                               default=other_platforms_code)

  When the staged out code is executed on a CPU, this is equivalent to
  ``cpu_code(*args)``, on a TPU it is equivalent to ``tpu_code(*args)``, and
  on any other platform to ``other_platforms_code(*args)``.
  Unlike a Python conditional, all alternatives are traced
  and staged out to Jaxpr. This is similar to, and is implemented in terms of,
  :func:`~switch`, from which it inherits the behavior
  under transformations.

  Unlike :func:`~switch`, the choice of what gets executed is made earlier:
  in most cases during lowering when the lowering platform is known; in the
  rare case of multi-platform lowering and serialization, the StableHLO code
  will contain a conditional on the actual platform. This conditional is
  resolved just in time prior to compilation when the compilation platform is
  known. This means that the compiler actually never sees a conditional.

  Args:
    *args: JAX arrays passed to each of the branches. May be PyTrees.
    **per_platform: branches to use for different platforms. The branches are
      JAX callables invoked with ``*args``. The keywords are platform names,
      e.g., 'cpu', 'tpu', 'cuda', 'rocm'.
    default: optional default branch to use for a platform not mentioned in
      ``per_platform``. If there is no ``default`` there will be an error when
      the code is lowered for a platform not mentioned in ``per_platform``.

  Returns:
    The value ``per_platform[execution_platform](*args)``.
  """
  # Join identical branches
  platform_branches: list[tuple[list[str], Callable]] = []
  for pname, pbranch in per_platform.items():
    if pname == "gpu":
      raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.")
    for ps, b in platform_branches:
      if b == pbranch:
        ps.append(pname)
        break
    else:
      platform_branches.append(([pname], pbranch))

  platforms_lists, branches = util.unzip2(platform_branches)
  platform_index = platform_index_p.bind(
      platforms=tuple(tuple(ps) for ps in platforms_lists),
      has_default=(default is not None))

  if default is not None:
    branches = branches + (default,)
  # Use a switch, to get the proper transformation rules for free. Since
  # platform index has no dependence on the input data, it won't be vectorized
  # under vmap.
  # If the switch and the platform_index_p above are in the same compilation
  # unit then constant-folding will remove the unnecessary branches. However,
  # if we run in eager mode the switch below cannot be constant-folded and
  # the compilation may fail if some of the branches contain custom calls not
  # recognized on the compilation platform. Detect eager mode and keep only the
  # needed branch.
  try:
    # Note/TODO(mvoz): This actually rarely seems to concretize - we could look into
    # core.ensure_compile_time_eval to get better single-branch selection.
    platform_index_concrete = core.concrete_or_error(operator.index, platform_index)
  except core.ConcretizationTypeError:
    return switch(platform_index, branches, *args)
  else:
    assert 0 <= platform_index_concrete < len(branches)
    return branches[platform_index_concrete](*args)
|
2023-10-27 03:04:39 +03:00
|
|
|
|
|
|
|
# A primitive to compute the index of a platform into a list of platforms.
# Args:
#   platforms: Sequence[Sequence[str]]: a sequence of sequences of platform
#     names. If the current lowering platform is in one of the inner sequences
#     returns the index of that inner sequence in the outer sequence.
#   has_default: if True, and if the lowering platform is not found in
#     `platforms` then return `len(platforms)`. Otherwise, raise an error.
platform_index_p = core.Primitive("platform_index")
platform_index_p.multiple_results = False
platform_index_p.def_impl(functools.partial(dispatch.apply_primitive,
                                            platform_index_p))


@platform_index_p.def_abstract_eval
def _platform_index_aval(*_, **__):
  """The platform index is always a scalar int32, whatever the params."""
  return core.ShapedArray((), np.int32)
|
|
|
|
|
|
|
|
def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
                             *,
                             platforms: Sequence[Sequence[str]],
                             has_default: bool):
  """Lowers ``platform_index_p`` to an int32 constant per lowering platform.

  Each platform in ``platforms[i]`` lowers to the constant ``i``. When
  ``has_default`` is set, any other platform lowers to ``len(platforms)``.
  The per-platform dispatch is delegated to ``mlir.lower_per_platform``.
  """
  def lower_constant(
      ctx: mlir.LoweringRuleContext, *, i: int
  ) -> Sequence[ir.Value]:
    v = mlir.ir_constant(np.int32(i))
    assert isinstance(v, ir.Value), v
    return [v]

  platform_rules: dict[str, mlir.LoweringRule] = {}
  for i, ps in enumerate(platforms):
    rule = partial(lower_constant, i=i)
    for p in ps:
      platform_rules[p] = rule

  default_rule = (
      partial(lower_constant, i=len(platforms)) if has_default else None)
  return mlir.lower_per_platform(
      ctx,
      f"platform_index(platforms={platforms}, has_default={has_default})",
      platform_rules, default_rule, effects.no_effects)


mlir.register_lowering(platform_index_p, _platform_index_lowering)
|