From af2306c0a8a50c474b44415fca4ddd256e8b0202 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 1 Feb 2023 17:50:00 -0800 Subject: [PATCH] Refactor effects system to use effect types, not objects --- jax/_src/ad_checkpoint.py | 13 +- jax/_src/api.py | 9 +- jax/_src/callback.py | 17 +- jax/_src/checkify.py | 10 +- jax/_src/core.py | 13 +- jax/_src/custom_derivatives.py | 17 +- jax/_src/debugging.py | 46 ++-- jax/_src/dispatch.py | 9 +- jax/_src/effects.py | 46 ++++ jax/_src/interpreters/mlir.py | 11 +- jax/_src/interpreters/pxla.py | 32 +-- jax/_src/lax/control_flow/common.py | 5 +- jax/_src/lax/control_flow/conditionals.py | 12 +- jax/_src/lax/control_flow/loops.py | 18 +- jax/_src/lax/lax.py | 14 +- jax/_src/maps.py | 10 +- jax/_src/pjit.py | 4 +- jax/_src/state/types.py | 7 +- jax/core.py | 1 - jax/experimental/jax2tf/call_tf.py | 30 ++- jax/interpreters/partial_eval.py | 3 +- tests/jaxpr_effects_test.py | 290 ++++++++++++---------- tests/python_callback_test.py | 6 +- 23 files changed, 352 insertions(+), 271 deletions(-) create mode 100644 jax/_src/effects.py diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index e08c00098..808dbc38a 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -13,8 +13,8 @@ # limitations under the License. from functools import partial -from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any, - FrozenSet) +from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple, + Union) import types import jax @@ -25,6 +25,7 @@ from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import core from jax._src import linear_util as lu +from jax._src import effects from jax._src import source_info_util from jax._src import traceback_util from jax._src import util @@ -44,6 +45,7 @@ traceback_util.register_exclusion(__file__) map = safe_map zip = safe_zip +allowed_effects: effects.EffectTypeSet = effects.remat_allowed_effects ### Policies @@ -451,14 +453,11 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): return out_primals, out_tangents ad.primitive_jvps[remat_p] = remat_jvp -remat_allowed_effects: Set[core.Effect] = set() -remat_allowed_effects.add(lax_internal.InOutFeedEffect.Infeed) -remat_allowed_effects.add(lax_internal.InOutFeedEffect.Outfeed) +allowed_effects.add_type(lax_internal.InOutFeedEffect) def remat_partial_eval(trace, *tracers, jaxpr, **params): assert not jaxpr.constvars - disallowed_effects = {eff for eff in jaxpr.effects - if eff not in remat_allowed_effects} + disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( 'Effects not supported in partial-eval of `checkpoint`/`remat`: ' diff --git a/jax/_src/api.py b/jax/_src/api.py index 33b939aa9..e81e19f3c 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -43,6 +43,7 @@ from jax._src import callback as jcb from jax._src import core from jax._src import device_array from jax._src import dispatch +from jax._src import effects from jax._src import array from jax._src import dtypes from jax._src import source_info_util @@ -1071,10 +1072,10 @@ def xla_computation(fun: Callable, else: out_parts_flat = tuple(flatten_axes( "xla_computation out_parts", out_tree(), out_parts)) - unordered_effects = [eff for eff in jaxpr.effects - if eff not in core.ordered_effects] - ordered_effects = [eff for eff in jaxpr.effects - if eff in core.ordered_effects] + unordered_effects = list( + effects.ordered_effects.filter_not_in(jaxpr.effects)) + ordered_effects = list( + effects.ordered_effects.filter_in(jaxpr.effects)) lowering_result = mlir.lower_jaxpr_to_module( f"xla_computation_{fun_name}", core.ClosedJaxpr(jaxpr, consts), diff --git a/jax/_src/callback.py b/jax/_src/callback.py index f8511ef40..7f68947c5 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -24,6 +24,7 @@ from jax.interpreters import mlir from jax._src import core from jax._src import dtypes +from jax._src import effects from jax._src import util from jax._src import dispatch from jax._src.interpreters import ad @@ -158,17 +159,19 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any, io_callback_p = core.Primitive("io_callback") io_callback_p.multiple_results = True -class IOEffect: +class IOEffect(effects.Effect): __str__ = lambda _: "IO" -class OrderedIOEffect: + +class OrderedIOEffect(effects.Effect): __str__ = lambda _: "OrderedIO" + _IOEffect = IOEffect() _OrderedIOEffect = OrderedIOEffect() -mlir.lowerable_effects.add(_IOEffect) -mlir.lowerable_effects.add(_OrderedIOEffect) -core.control_flow_allowed_effects.add(_IOEffect) -core.control_flow_allowed_effects.add(_OrderedIOEffect) -core.ordered_effects.add(_OrderedIOEffect) +effects.lowerable_effects.add_type(IOEffect) +effects.lowerable_effects.add_type(OrderedIOEffect) +effects.control_flow_allowed_effects.add_type(IOEffect) +effects.control_flow_allowed_effects.add_type(OrderedIOEffect) +effects.ordered_effects.add_type(OrderedIOEffect) def io_callback_impl(*args, result_avals, callback: Callable[..., Any], diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 3c96c0ccb..a1422a1a2 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -36,13 +36,13 @@ from jax.tree_util import tree_unflatten from jax._src import linear_util as lu from jax._src import core from jax._src import custom_derivatives +from jax._src import effects from jax._src import prng from jax._src import source_info_util from jax._src import traceback_util from jax._src.config import config from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.lax import control_flow as cf from jax._src.sharding import GSPMDSharding from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, @@ -100,20 +100,18 @@ class JaxException(Exception): @functools.total_ordering @dataclasses.dataclass(eq=True, frozen=True) -class ErrorEffect: +class ErrorEffect(effects.Effect): error_type: Type[JaxException] shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...] - def __post_init__(self): - cf.allowed_effects.add(self) - mlir.lowerable_effects.add(self) - def __lt__(self, other: 'ErrorEffect'): shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable for sd in x.shape_dtypes) unpack = lambda x: (str(x.error_type), shape_dtypes(x)) return (unpack(self) < unpack(other)) +effects.control_flow_allowed_effects.add_type(ErrorEffect) +effects.lowerable_effects.add_type(ErrorEffect) class DivisionByZeroError(JaxException): diff --git a/jax/_src/core.py b/jax/_src/core.py index 9d2b79d33..16ae136ca 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -37,6 +37,7 @@ import numpy as np from jax._src import dtypes from jax._src import config as jax_config +from jax._src import effects from jax._src.config import FLAGS, config from jax.errors import (ConcretizationTypeError, TracerArrayConversionError, TracerIntegerConversionError, UnexpectedTracerError) @@ -59,12 +60,10 @@ map, unsafe_map = safe_map, map # -------------------- jaxprs -------------------- -Effect = Hashable -Effects = Set[Effect] -no_effects: Effects = set() -ordered_effects: Set[Effect] = set() -control_flow_allowed_effects: Set[Effect] = set() - +Effect = effects.Effect +Effects = effects.Effects +EffectTypeSet = effects.EffectTypeSet +no_effects: Effects = effects.no_effects class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects'] @@ -2829,7 +2828,7 @@ def _check_map(ctx_factory, prim, in_avals, params): if "call_jaxpr" not in params: raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter") call_jaxpr = params["call_jaxpr"] - ordered_effects_ = call_jaxpr.effects & ordered_effects + ordered_effects_ = effects.ordered_effects.filter_in(call_jaxpr.effects) if ordered_effects_: raise JaxprTypeError( f"Map primitive {prim} mapping ordered effects: {ordered_effects_}") diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 15e944888..b34343b5c 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -14,8 +14,7 @@ from functools import update_wrapper, reduce, partial import inspect -from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set, - Any) +from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Any) from jax.custom_transpose import custom_transpose from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, @@ -29,6 +28,7 @@ from jax.config import config from jax._src import core from jax._src import custom_api_util from jax._src import dtypes +from jax._src import effects from jax._src import linear_util as lu from jax._src import traceback_util from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p @@ -47,6 +47,8 @@ traceback_util.register_exclusion(__file__) map = safe_map zip = safe_zip +allowed_effects: effects.EffectTypeSet = ( + effects.custom_derivatives_allowed_effects) ### util @@ -380,18 +382,14 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args): yield outs, tuple(todo) # Ensure the aux output is immutable -allowed_effects: Set[core.Effect] = set() -allowed_effects.add(lax.InOutFeedEffect.Infeed) -allowed_effects.add(lax.InOutFeedEffect.Outfeed) - +allowed_effects.add_type(lax.InOutFeedEffect) custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts): # TODO(mattjj): could do more checking here... del in_avals, jvp_jaxpr_thunk, num_consts - disallowed_effects = {eff for eff in call_jaxpr.effects if eff not in - allowed_effects} + disallowed_effects = allowed_effects.filter_not_in(call_jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `custom_jvp`: {disallowed_effects}') @@ -714,8 +712,7 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): return core.jaxpr_as_fun(fun_jaxpr)(*args) def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): - disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in - allowed_effects} + disallowed_effects = allowed_effects.filter_not_in(fun_jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `custom_vjp`: {disallowed_effects}') diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 74e6d9529..26ab5c8fa 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -13,7 +13,6 @@ # limitations under the License. """Module for JAX debugging primitives and related functionality.""" -import enum import functools import string import sys @@ -25,24 +24,21 @@ import numpy as np import jax.numpy as jnp from jax import tree_util from jax import lax -from jax.config import config -from jax.experimental import pjit -from jax.interpreters import partial_eval as pe -from jax.interpreters import pxla - -from jax._src import ad_checkpoint from jax._src import core -from jax._src import custom_derivatives +from jax._src import effects from jax._src import linear_util as lu +from jax._src import pjit from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lax import control_flow as lcf +from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import Sharding, GSPMDSharding, NamedSharding +from jax.interpreters import partial_eval as pe +from jax.config import config # pytype: disable=import-error try: @@ -58,17 +54,23 @@ except: RICH_ENABLED = False # pytype: enable=import-error -DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT']) +class DebugEffect(effects.Effect): + pass +debug_effect = DebugEffect() -core.ordered_effects.add(DebugEffect.ORDERED_PRINT) -mlir.lowerable_effects.add(DebugEffect.PRINT) -mlir.lowerable_effects.add(DebugEffect.ORDERED_PRINT) -lcf.allowed_effects.add(DebugEffect.PRINT) -lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT) -ad_checkpoint.remat_allowed_effects.add(DebugEffect.PRINT) -ad_checkpoint.remat_allowed_effects.add(DebugEffect.ORDERED_PRINT) -custom_derivatives.allowed_effects.add(DebugEffect.PRINT) -custom_derivatives.allowed_effects.add(DebugEffect.ORDERED_PRINT) +class OrderedDebugEffect(effects.Effect): + pass +ordered_debug_effect = OrderedDebugEffect() + +effects.ordered_effects.add_type(OrderedDebugEffect) +effects.lowerable_effects.add_type(DebugEffect) +effects.lowerable_effects.add_type(OrderedDebugEffect) +effects.control_flow_allowed_effects.add_type(DebugEffect) +effects.control_flow_allowed_effects.add_type(OrderedDebugEffect) +effects.remat_allowed_effects.add_type(DebugEffect) +effects.remat_allowed_effects.add_type(OrderedDebugEffect) +effects.custom_derivatives_allowed_effects.add_type(DebugEffect) +effects.custom_derivatives_allowed_effects.add_type(OrderedDebugEffect) # `debug_callback_p` is the main primitive for staging out Python callbacks. debug_callback_p = core.Primitive('debug_callback') @@ -133,7 +135,7 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): return tuple( debug_callback_p.impl( *flat_args, effect=effect, callback=callback, **params)) - if effect in core.ordered_effects: + if effects.ordered_effects.contains(effect): token = ctx.tokens_in.get(effect)[0] result, token, keepalive = mlir.emit_python_callback( ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True) @@ -209,7 +211,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any, The value of `callback(*args, **kwargs)`. """ flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT + effect = ordered_debug_effect if ordered else debug_effect def _flat_callback(*flat_args): args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) callback(*args, **kwargs) @@ -273,7 +275,7 @@ inspect_sharding_p.def_impl(_inspect_sharding_impl) def _inspect_sharding_abstract_eval(aval, **_): del aval # Effectful abstract avoids DCE - return [], {DebugEffect.PRINT} + return [], {debug_effect} inspect_sharding_p.def_effectful_abstract_eval(_inspect_sharding_abstract_eval) def _inspect_sharding_batching_rule(args, _, *, callback): diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 5a94d7e77..d4c775f95 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -42,6 +42,7 @@ from jax._src import array from jax._src import core from jax._src import device_array from jax._src import dtypes +from jax._src import effects from jax._src import linear_util as lu from jax._src import path from jax._src import profiler @@ -507,10 +508,10 @@ def lower_xla_callable( if type(d) is pe.InDBIdx else d for d in a.shape)) if type(a) is core.DShapedArray else a, b) for a, b in out_type] module_name = f"jit_{fun.__name__}" - unordered_effects = [eff for eff in closed_jaxpr.effects - if eff not in core.ordered_effects] - ordered_effects = [eff for eff in closed_jaxpr.effects - if eff in core.ordered_effects] + unordered_effects = list( + effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) + ordered_effects = list( + effects.ordered_effects.filter_in(closed_jaxpr.effects)) lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, unordered_effects, ordered_effects, backend, backend.platform, diff --git a/jax/_src/effects.py b/jax/_src/effects.py new file mode 100644 index 000000000..7416f818c --- /dev/null +++ b/jax/_src/effects.py @@ -0,0 +1,46 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Type, Set + + +class Effect: + """A generic side-effect.""" + +Effects = Set[Effect] + +class EffectTypeSet: + + def __init__(self): + self._effect_types: Set[Type[Effect]] = set() + + def add_type(self, effect_type: Type[Effect]): + self._effect_types.add(effect_type) + + def contains(self, eff: Effect) -> bool: + return any(isinstance(eff, eff_type) for eff_type in self._effect_types) + + def filter_in(self, effects: Effects) -> Effects: + return {eff for eff in effects if self.contains(eff)} + + def filter_not_in(self, effects: Effects) -> Effects: + return {eff for eff in effects if not self.contains(eff)} + +no_effects: Effects = set() +ordered_effects: EffectTypeSet = EffectTypeSet() +lowerable_effects: EffectTypeSet = EffectTypeSet() +control_flow_allowed_effects: EffectTypeSet = EffectTypeSet() +custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet() +remat_allowed_effects: EffectTypeSet = EffectTypeSet() diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index fc78dcb8d..8e677a9e8 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -39,6 +39,7 @@ from jax._src import ad_util from jax._src import core from jax._src import device_array from jax._src import dtypes +from jax._src import effects as effects_lib from jax._src import source_info_util from jax._src import util from jax._src.lib import xla_bridge as xb @@ -58,7 +59,7 @@ Value = Any # = ir.Value # mypy implicitly sets this variable to true when type checking. MYPY = False -lowerable_effects: Set[core.Effect] = set() +lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects # IR Helpers @@ -695,7 +696,8 @@ def lower_jaxpr_to_module( if platform in _platforms_with_donation: input_output_aliases, donated_args = _set_up_aliases( in_avals, out_avals, donated_args) - if any(eff not in lowerable_effects for eff in jaxpr.effects): + unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) + if unlowerable_effects: raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') if any(donated_args): # TODO(tomhennigan): At call time we should mark these buffers as deleted. @@ -731,8 +733,7 @@ def lower_jaxpr_to_module( module_name = _module_name_regex.sub("_", module_name) ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get( module_name) - unlowerable_effects = {eff for eff in jaxpr.effects - if eff not in lowerable_effects} + unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) if unlowerable_effects: raise ValueError( f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}') @@ -1140,7 +1141,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, f"found for platform {ctx.platform}") eqn_ctx = ctx.replace(name_stack=source_info.name_stack) - effects = [eff for eff in eqn.effects if eff in core.ordered_effects] + effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) tokens_in = tokens.subset(effects) avals_in = map(aval, eqn.invars) rule_ctx = LoweringRuleContext( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c8b912b15..06ed98e9a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -57,6 +57,7 @@ from jax._src import core from jax._src import device_array from jax._src import dispatch from jax._src import dtypes +from jax._src import effects from jax._src import linear_util as lu from jax._src import profiler from jax._src import sharding as sharding_internal @@ -1487,12 +1488,12 @@ def lower_parallel_callable( backend.platform) module_name = f"pmap_{fun.__name__}" with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore - if any(eff in core.ordered_effects for eff in closed_jaxpr.effects): + ordered_effects = list( + effects.ordered_effects.filter_in(closed_jaxpr.effects)) + if ordered_effects: raise ValueError("Ordered effects not supported in `pmap`.") - unordered_effects = [eff for eff in closed_jaxpr.effects - if eff not in core.ordered_effects] - ordered_effects = [eff for eff in closed_jaxpr.effects - if eff in core.ordered_effects] + unordered_effects = list( + effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -3007,12 +3008,12 @@ def lower_sharding_computation( module_name = f"{api_name}_{fun_name}" if len(device_assignment) > 1: - if any(eff in core.ordered_effects for eff in closed_jaxpr.effects): + if any(effects.ordered_effects.contains(eff) for eff + in closed_jaxpr.effects): raise ValueError("Ordered effects are not supported for more than 1 device.") - unordered_effects = [eff for eff in closed_jaxpr.effects - if eff not in core.ordered_effects] - ordered_effects = [eff for eff in closed_jaxpr.effects - if eff in core.ordered_effects] + unordered_effects = list( + effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) + ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, @@ -3180,12 +3181,13 @@ def lower_mesh_computation( module: Union[str, xc.XlaComputation] module_name = f"{api_name}_{fun_name}" with core.extend_axis_env_nd(mesh.shape.items()): - if any(eff in core.ordered_effects for eff in closed_jaxpr.effects): + if any(effects.ordered_effects.contains(eff) for eff + in closed_jaxpr.effects): raise ValueError("Ordered effects not supported in mesh computations.") - unordered_effects = [eff for eff in closed_jaxpr.effects - if eff not in core.ordered_effects] - ordered_effects = [eff for eff in closed_jaxpr.effects - if eff in core.ordered_effects] + unordered_effects = list(effects.ordered_effects.filter_not_in( + closed_jaxpr.effects)) + ordered_effects = list(effects.ordered_effects.filter_in( + closed_jaxpr.effects)) lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index d8e15442d..fe430a408 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -19,8 +19,8 @@ from typing import Callable, Optional, Sequence from jax._src import core from jax._src import linear_util as lu -from jax._src.core import control_flow_allowed_effects as allowed_effects from jax._src.lax import lax +from jax._src.effects import control_flow_allowed_effects as allowed_effects from jax._src import ad_util from jax._src import util from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3 @@ -30,8 +30,7 @@ from jax.tree_util import tree_map, tree_unflatten map, unsafe_map = safe_map, map -allowed_effects.add(lax.InOutFeedEffect.Infeed) -allowed_effects.add(lax.InOutFeedEffect.Outfeed) +allowed_effects.add_type(lax.InOutFeedEffect) def _abstractify(x): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 52754d5fc..d75b5173b 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -31,6 +31,7 @@ from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import core from jax._src import dtypes +from jax._src import effects from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import util @@ -136,7 +137,7 @@ def switch(index, branches: Sequence[Callable], *operands, out_trees[0], jaxprs[0].out_avals, out_tree, jaxpr.out_avals) joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) - disallowed_effects = joined_effects - allowed_effects + disallowed_effects = allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') @@ -245,7 +246,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, out_tree, true_jaxpr.out_avals, false_out_tree, false_jaxpr.out_avals) joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects) - disallowed_effects = joined_effects - allowed_effects + disallowed_effects = allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') @@ -304,7 +305,7 @@ def _cond_with_per_branch_args(pred, def _cond_abstract_eval(*avals, branches, **_): joined_effects = core.join_effects(*(b.effects for b in branches)) - disallowed_effects = joined_effects - allowed_effects + disallowed_effects = allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') @@ -751,7 +752,7 @@ def _cond_typecheck(*in_atoms, branches, linear): jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals) jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals) joined_effects = core.join_effects(*(b.effects for b in branches)) - disallowed_effects = joined_effects - allowed_effects + disallowed_effects = allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') @@ -818,8 +819,7 @@ pe.dce_rules[cond_p] = _cond_dce_rule def _cond_lowering(ctx, index, *args, branches, linear): del linear # Unused. joined_effects = core.join_effects(*(branch.effects for branch in branches)) - ordered_effects = [eff for eff in joined_effects - if eff in core.ordered_effects] + ordered_effects = list(effects.ordered_effects.filter_in(joined_effects)) num_tokens = len(ordered_effects) tokens_in = ctx.tokens_in.subset(ordered_effects) output_token_types = [mlir.token_type() for _ in ordered_effects] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 7755660af..7231a6909 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -35,6 +35,7 @@ from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api from jax._src import dtypes +from jax._src import effects from jax._src import source_info_util from jax._src import util from jax._src.lax import lax @@ -266,7 +267,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], # Extract the subtree and avals for the first element of the return tuple out_tree_children[0], carry_avals_out, init_tree, carry_avals) - disallowed_effects = jaxpr.effects - allowed_effects + disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `scan`: {disallowed_effects}') @@ -1098,7 +1099,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric], body_tree, body_jaxpr.out_avals, in_tree_children[0], init_avals) effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) - disallowed_effects = effects - allowed_effects + disallowed_effects = allowed_effects.filter_not_in(effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') @@ -1110,7 +1111,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric], def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs): del args, kwargs joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) - disallowed_effects = joined_effects - allowed_effects + disallowed_effects = allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') @@ -1422,8 +1423,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): pred_aval = cond_jaxpr.out_avals[0] batched = bool(pred_aval.shape) - cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in - core.ordered_effects] + cond_ordered_effects = effects.ordered_effects.filter_in(cond_jaxpr.effects) if cond_ordered_effects: def cond(args): # Pred can be batched @@ -1448,8 +1448,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, return mlir.lower_fun(fun)(ctx, *args) loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in) - body_effects = [eff for eff in body_jaxpr.effects - if eff in core.ordered_effects] + body_effects = effects.ordered_effects.filter_in(body_jaxpr.effects) num_tokens = len(body_effects) tokens = [ctx.tokens_in.get(eff) for eff in body_effects] token_types = [mlir.token_type() for _ in tokens] @@ -1533,9 +1532,10 @@ def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): # TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) - if joined_effects - allowed_effects: + disallowed_effects = allowed_effects.filter_not_in(joined_effects) + if disallowed_effects: raise NotImplementedError( - f'Effects not supported in `while`: {joined_effects - allowed_effects}') + f'Effects not supported in `while`: {disallowed_effects}') return body_jaxpr.out_avals, joined_effects while_p = core.AxisPrimitive('while') diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 3bd96abe8..8a2f68156 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -39,6 +39,7 @@ from jax._src import core from jax._src import device_array from jax._src import dispatch from jax._src import dtypes +from jax._src import effects from jax._src import linear_util as lu from jax._src import pretty_printer as pp from jax._src import source_info_util @@ -4162,7 +4163,10 @@ def _after_all_lowering(ctx, *operands): mlir.register_lowering(after_all_p, _after_all_lowering) -InOutFeedEffect = enum.Enum('InOutFeedEffect', ['Infeed', 'Outfeed']) +class InOutFeedEffect(effects.Effect): + pass +infeed_effect = InOutFeedEffect() +outfeed_effect = InOutFeedEffect() def infeed(token, shape=None, partitions=None): @@ -4190,14 +4194,14 @@ def infeed(token, shape=None, partitions=None): def _infeed_abstract_eval(token, *, shapes, partitions): if token is not abstract_token: raise TypeError("First argument to infeed must be a token") - return (*shapes, abstract_token), {InOutFeedEffect.Infeed} + return (*shapes, abstract_token), {infeed_effect} infeed_p = Primitive("infeed") infeed_p.multiple_results = True infeed_p.def_impl(partial(xla.apply_primitive, infeed_p)) infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval) -mlir.lowerable_effects.add(InOutFeedEffect.Infeed) +mlir.lowerable_effects.add_type(InOutFeedEffect) def _infeed_lowering(ctx, token, *, shapes, partitions): @@ -4243,12 +4247,12 @@ def outfeed(token, xs, partitions = None): def _outfeed_abstract_eval(token, *xs, partitions): if token is not abstract_token: raise TypeError("First argument to outfeed must be a token") - return abstract_token, {InOutFeedEffect.Outfeed} + return abstract_token, {outfeed_effect} outfeed_p = Primitive("outfeed") outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p)) outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval) -mlir.lowerable_effects.add(InOutFeedEffect.Outfeed) +mlir.lowerable_effects.add_type(InOutFeedEffect) def _outfeed_lowering(ctx, token, *xs, partitions): diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 3f631a4c8..e6f983b82 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -25,6 +25,7 @@ from jax._src import core from jax._src import linear_util as lu from jax import stages from jax._src import dispatch +from jax._src import effects from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map, treedef_tuple) from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes, @@ -1375,7 +1376,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, sub_ctx = ctx.module_context.replace( name_stack=extend_name_stack(ctx.module_context.name_stack, wrap_name(name, 'xmap'))) - if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects): + if any(effects.ordered_effects.contains(eff) for eff + in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(), const_nodes, *tiled_ins, @@ -1442,7 +1444,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, sub_ctx = ctx.module_context.replace( name_stack=extend_name_stack(ctx.module_context.name_stack, wrap_name(name, 'xmap'))) - if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects): + if any(effects.ordered_effects.contains(eff) for eff + in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(), const_nodes, *sharded_global_in_nodes, @@ -1494,7 +1497,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, name_stack=extend_name_stack(ctx.module_context.name_stack, wrap_name(name, 'xmap')), axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes)) - if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects): + if any(effects.ordered_effects.contains(eff) for eff + in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes), diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 8a59852c9..b9d4be453 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1465,9 +1465,9 @@ pe.custom_staging_rules[pjit_p] = pjit_staging_rule def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, out_positional_semantics, **_): - disallowed_effects = jaxpr.effects - mlir.lowerable_effects + disallowed_effects = mlir.lowerable_effects.filter_not_in(jaxpr.effects) if disallowed_effects: - raise ValueError('Effects not supported in `pjit`.') + raise ValueError(f'Effects not supported in `pjit`: {disallowed_effects}.') if config.jax_array: return jaxpr.out_avals, jaxpr.effects return global_to_local(out_positional_semantics, jaxpr.out_avals, diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e9c960e7e..8cb8a699a 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -17,9 +17,9 @@ from __future__ import annotations from typing import Any, List, Optional, Sequence, Set, Union from jax._src import core +from jax._src import effects from jax._src.lib import xla_bridge, xla_client from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod -from jax._src.lax.control_flow import common xc = xla_client xb = xla_bridge @@ -31,10 +31,9 @@ zip, unsafe_zip = safe_zip, zip Array = Any -class RefEffect: +class RefEffect(effects.Effect): def __init__(self, ref_aval: ShapedArrayRef): self.ref_aval = ref_aval - common.allowed_effects.add(self) def __eq__(self, other): if not isinstance(other, self.__class__): @@ -61,6 +60,8 @@ class AccumEffect(RefEffect): def __str__(self): return f"Accum<{self.ref_aval}>" +effects.control_flow_allowed_effects.add_type(RefEffect) + StateEffect = Union[ReadEffect, WriteEffect, AccumEffect] # ## `Ref`s diff --git a/jax/core.py b/jax/core.py index b44e6d88c..7aa361c1d 100644 --- a/jax/core.py +++ b/jax/core.py @@ -156,7 +156,6 @@ from jax._src.core import ( np as np, opaque_dtypes as opaque_dtypes, operator as operator, - ordered_effects as ordered_effects, outfeed_primitives as outfeed_primitives, partial as partial, partialmethod as partialmethod, diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 8fa1213dc..78773021e 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -33,19 +33,20 @@ from jax import dlpack from jax import dtypes from jax import numpy as jnp from jax import tree_util -from jax._src import util -from jax._src import ad_util -from jax._src.lax import control_flow as lax_control_flow +from jax._src import core from jax._src import ad_checkpoint from jax._src import custom_derivatives -from jax.interpreters import mlir -from jax.interpreters import xla -from jax._src import core +from jax._src import ad_util +from jax._src import effects +from jax._src import util +from jax._src.lax import control_flow as lax_control_flow from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client from jax.experimental.jax2tf import jax2tf as jax2tf_internal +from jax.interpreters import mlir +from jax.interpreters import xla import numpy as np import tensorflow as tf # type: ignore[import] @@ -296,13 +297,16 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc return function_flat_tf.get_concrete_function(*args_flat_sig_tf) -# Mark the effectful instancess of call_tf -CallTfEffect = enum.Enum('CallTfEffect', ['EFFECT']) +# Mark the effectful instances of call_tf +class CallTfEffect(effects.Effect): + __str__ = lambda _: "CallTfEffect" -mlir.lowerable_effects.add(CallTfEffect.EFFECT) -lax_control_flow.allowed_effects.add(CallTfEffect.EFFECT) -ad_checkpoint.remat_allowed_effects.add(CallTfEffect.EFFECT) -custom_derivatives.allowed_effects.add(CallTfEffect.EFFECT) +call_tf_effect = CallTfEffect() + +effects.lowerable_effects.add_type(CallTfEffect) +effects.control_flow_allowed_effects.add_type(CallTfEffect) +effects.remat_allowed_effects.add_type(CallTfEffect) +effects.custom_derivatives_allowed_effects.add_type(CallTfEffect) def _call_tf_abstract_eval(*_, @@ -316,7 +320,7 @@ def _call_tf_abstract_eval(*_, def is_fully_known_shape(s): return s.rank is not None and all([d is not None for d in s]) - effects = {CallTfEffect.EFFECT} if has_side_effects else set() + effects = {call_tf_effect} if has_side_effects else set() if all([is_fully_known_shape(s) for s in concrete_function_flat_tf.output_shapes]): return ( diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 519ef1f8b..0f4998592 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -30,6 +30,7 @@ from jax._src import linear_util as lu from jax.config import config from jax._src import api_util from jax._src import core +from jax._src import effects from jax._src import dtypes from jax._src import profiler from jax._src import source_info_util @@ -1782,7 +1783,7 @@ class DynamicJaxprTrace(core.Trace): with core.new_sublevel(): jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic( f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name)) - ordered_effects = jaxpr.effects & core.ordered_effects + ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: raise ValueError("Ordered effects not supported for " f"map primitives: {ordered_effects}") diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index b5b05300b..735709398 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -21,6 +21,7 @@ import jax import jax.numpy as jnp from jax._src import core from jax import lax +from jax._src import effects from jax._src import linear_util as lu from jax.config import config from jax.interpreters import ad @@ -32,7 +33,6 @@ from jax._src import ad_checkpoint from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util -from jax._src.lax import control_flow as lcf from jax._src.lib import xla_bridge import numpy as np @@ -49,22 +49,34 @@ def effect_jvp_rule(primals, tangents, effect): return effect_p.bind(*primals, effect=effect), tangents ad.primitive_jvps[effect_p] = effect_jvp_rule -mlir.lowerable_effects.add('foo') -mlir.lowerable_effects.add('foo2') -mlir.lowerable_effects.add('bar') -mlir.lowerable_effects.add('while') -mlir.lowerable_effects.add('while1') -mlir.lowerable_effects.add('while2') -core.ordered_effects.add('foo') -core.ordered_effects.add('foo2') -core.ordered_effects.add('while1') -core.ordered_effects.add('while2') +class BasicEffect(effects.Effect): + def __init__(self, name): + self.name = name -lcf.allowed_effects.add('while') -lcf.allowed_effects.add('while1') -lcf.allowed_effects.add('while2') + __repr__ = lambda self: self.name -ad_checkpoint.remat_allowed_effects.add('remat') +class OrderedEffect(BasicEffect): pass +class UnlowerableEffect(effects.Effect): pass +class WhileEffect(effects.Effect): pass +class RematEffect(effects.Effect): pass + +foo_effect = OrderedEffect("foo") +foo2_effect = OrderedEffect("foo2") +bar_effect = BasicEffect("bar") +baz_effect = UnlowerableEffect() +while_effect = WhileEffect() +while1_effect = WhileEffect() +while2_effect = WhileEffect() +log_effect = OrderedEffect("log") +unordered_log_effect = BasicEffect("unordered_log") + +effects.lowerable_effects.add_type(BasicEffect) +effects.lowerable_effects.add_type(WhileEffect) +effects.ordered_effects.add_type(OrderedEffect) +effects.ordered_effects.add_type(WhileEffect) +effects.control_flow_allowed_effects.add_type(WhileEffect) + +effects.remat_allowed_effects.add_type(RematEffect) def trivial_effect_lowering(ctx, *, effect): @@ -92,10 +104,6 @@ def function_effect_lowering(ctx, *, effect): callback_p = core.Primitive('callback') callback_p.multiple_results = True -mlir.lowerable_effects.add('log') -mlir.lowerable_effects.add('unordered_log') -core.ordered_effects.add('log') - @callback_p.def_impl def _(*args, callback, out_avals, effect): del out_avals, effect @@ -110,7 +118,7 @@ def _(*avals, callback, out_avals, effect): def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect): del out_avals token_in = None - if effect in core.ordered_effects: + if effects.ordered_effects.contains(effect): token_in = ctx.tokens_in.get(effect)[0] out_op, token_out, keep_alive = mlir.emit_python_callback( @@ -149,31 +157,31 @@ class JaxprEffectsTest(jtu.JaxTestCase): def test_effectful_primitive_in_jaxpr_creates_effects(self): def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. jaxpr = jax.make_jaxpr(f)(2.) - self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects) - self.assertEqual({'foo'}, jaxpr.effects) + self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects) + self.assertEqual({foo_effect}, jaxpr.effects) def test_different_effects_in_jaxpr(self): def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x + 1. jaxpr = jax.make_jaxpr(f)(2.) - self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects) - self.assertEqual({'bar'}, jaxpr.jaxpr.eqns[1].effects) - self.assertEqual({'foo', 'bar'}, jaxpr.effects) + self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects) + self.assertEqual({bar_effect}, jaxpr.jaxpr.eqns[1].effects) + self.assertEqual({foo_effect, bar_effect}, jaxpr.effects) def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self): def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x + 1. jaxpr = jax.make_jaxpr(f)(2.).jaxpr # Edit jaxpr to make its type wrong - jaxpr = jaxpr.replace(effects={'foo'}) + jaxpr = jaxpr.replace(effects={foo_effect}) with self.assertRaisesRegex(core.JaxprTypeError, 'Equation effects are not subset of Jaxpr effects.'): @@ -186,48 +194,52 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase): def f(x): @lu.wrap_init def f_(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return [x] return core.call(f_, x)[0] - jax.make_jaxpr(f)(2.) + jaxpr = jax.make_jaxpr(f)(2.) + self.assertIn(foo_effect, jaxpr.jaxpr.effects) + self.assertIn(bar_effect, jaxpr.jaxpr.effects) def test_xla_call_primitive_inherits_effects(self): @jax.jit def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x jax.make_jaxpr(f)(2.) + jaxpr = jax.make_jaxpr(f)(2.) + self.assertIn(foo_effect, jaxpr.jaxpr.effects) + self.assertIn(bar_effect, jaxpr.jaxpr.effects) - @jtu.sample_product(flavor=["old", "new"]) - def test_remat_call_primitive_inherits_effects(self, flavor): - remat = jax.remat if flavor == "old" else ad_checkpoint.checkpoint + def test_remat_call_primitive_inherits_effects(self): - @remat + @jax.checkpoint def f(x): - x, = effect_p.bind(x, effect='foo') - x, = effect_p.bind(x, effect='bar') + x, = effect_p.bind(x, effect=foo_effect) + x, = effect_p.bind(x, effect=bar_effect) return x jax.make_jaxpr(f)(2.) with self.assertRaisesRegex(NotImplementedError, "Effects not supported"): jax.make_jaxpr(lambda x: jax.linearize(f, x)[1](x))(2.) def test_new_remat_allows_certain_effects(self): + remat_effect = RematEffect() @ad_checkpoint.checkpoint def f(x): - x, = effect_p.bind(x, effect='remat') + x, = effect_p.bind(x, effect=remat_effect) return x jaxpr = jax.make_jaxpr(f)(2.) - self.assertSetEqual(jaxpr.effects, {"remat"}) + self.assertSetEqual(jaxpr.effects, {remat_effect}) def test_custom_jvp_primitive_inherits_effects(self): @jax.custom_jvp def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x f.defjvp(lambda x, t: (x, t)) with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): @@ -237,8 +249,8 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase): @jax.custom_vjp def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x f.defvjp( fwd=lambda x: (x, ()), @@ -250,27 +262,27 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase): @jax.pmap def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x with self.assertRaisesRegex( ValueError, - "Ordered effects not supported for map primitives: {'foo'}"): + "Ordered effects not supported for map primitives: {.*}"): jax.make_jaxpr(f)(jnp.arange(jax.local_device_count())) def test_xmap_inherits_effects(self): def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x f = maps.xmap(f, in_axes=['a'], out_axes=['a']) jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count())) - self.assertSetEqual(jaxpr.effects, {"foo", "bar"}) + self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect}) def test_pjit_inherits_effects(self): def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=bar_effect) return x mesh = jax.sharding.Mesh(np.array(jax.devices()), ['x']) if config.jax_array: @@ -280,7 +292,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase): f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec) with mesh: jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count())) - self.assertSetEqual(jaxpr.effects, {"foo", "bar"}) + self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect}) class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @@ -291,7 +303,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): config.update('jax_enable_x64', False) self._old_lowering = mlir._lowerings[effect_p] def _effect_lowering(ctx, *, effect): - if effect in core.ordered_effects: + if effects.ordered_effects.contains(effect): expected_effects = [effect] else: expected_effects = [] @@ -308,12 +320,20 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): config.update('jax_enable_x64', self.old_x64) mlir.register_lowering(effect_p, self._old_lowering) + def test_can_lower_lowerable_effect(self): + @jax.jit + def f(x): + effect_p.bind(effect=foo_effect) + return x + 1. + f.lower(2.) + def test_cannot_lower_unlowerable_effect(self): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=baz_effect) return x + 1. - f.lower(2.) + with self.assertRaisesRegex(ValueError, "Cannot lower jaxpr with effects"): + f.lower(2.) def test_should_not_pass_tokens_into_unordered_effect(self): @@ -324,7 +344,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='bar') + effect_p.bind(effect=bar_effect) return x + 1. f.lower(2.) @@ -337,7 +357,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` needs to ' 'set `tokens_out`'): @@ -346,13 +366,13 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def test_lowering_that_sets_wrong_tokens_should_cause_error(self): def bad_effect_lowering(ctx, *, effect): - ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get('foo'))) + ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get(foo_effect))) return [] mlir.register_lowering(effect_p, bad_effect_lowering) @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` returns ' 'incorrect set of output token.'): @@ -367,7 +387,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. module = f.lower(2.).compiler_ir() main = module.body.operations[0] @@ -376,8 +396,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='foo2') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=foo2_effect) return x + 1. module = f.lower(2.).compiler_ir() main = module.body.operations[0] @@ -388,7 +408,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. module = f.lower(2.).compiler_ir() main = module.body.operations[0] @@ -397,8 +417,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='foo2') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=foo2_effect) return x + 1. module = f.lower(2.).compiler_ir() main = module.body.operations[0] @@ -413,7 +433,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. module = f.lower(2.).compiler_ir() main = module.body.operations[0] @@ -434,7 +454,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit def f(x): - effect_p.bind(effect='bar') + effect_p.bind(effect=bar_effect) return x + 1. module = f.lower(2.).compiler_ir() main = module.body.operations[0] @@ -450,7 +470,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def test_lowered_jaxpr_without_ordered_effects_takes_no_dummy_inputs(self): @jax.jit def f(x): - effect_p.bind(effect='bar') + effect_p.bind(effect=bar_effect) return x + 1. module = f.lower(1.).compiler_ir() input_types = module.body.operations[0].type.inputs @@ -465,7 +485,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. module = f.lower(1.).compiler_ir() input_types = module.body.operations[0].type.inputs @@ -481,8 +501,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_dummy_inputs(self): @jax.jit def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='foo2') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=foo2_effect) return x + 1. module = f.lower(1.).compiler_ir() input_types = module.body.operations[0].type.inputs @@ -500,14 +520,14 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def test_can_lower_and_run_jaxpr_with_ordered_effects(self): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. self.assertEqual(f(2.), 3.) def test_can_lower_and_run_jaxpr_with_unordered_effects(self): @jax.jit def f(x): - effect_p.bind(effect='bar') + effect_p.bind(effect=bar_effect) return x + 1. self.assertEqual(f(2.), 3.) @@ -517,7 +537,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit @jax.pmap def f(x): - effect_p.bind(effect='bar') + effect_p.bind(effect=bar_effect) return x + 1 with self.assertRaisesRegex( NotImplementedError, @@ -530,48 +550,48 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): @jax.jit @jax.pmap def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1 with self.assertRaisesRegex( ValueError, - "Ordered effects not supported for map primitives: {'foo'}"): + "Ordered effects not supported for map primitives: {foo}"): f(jnp.arange(jax.device_count())) def test_runtime_tokens_should_update_after_running_effectful_function(self): @jax.jit def f(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. - self.assertNotIn('foo', dispatch.runtime_tokens.tokens) + self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens) f(2.) - prev_token = dispatch.runtime_tokens.tokens['foo'] + prev_token = dispatch.runtime_tokens.tokens[foo_effect] f(2.) - curr_token = dispatch.runtime_tokens.tokens['foo'] + curr_token = dispatch.runtime_tokens.tokens[foo_effect] self.assertIsNot(prev_token, curr_token) def test_can_lower_multiple_effects(self): @jax.jit def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='foo2') + effect_p.bind(effect=foo_effect) + effect_p.bind(effect=foo2_effect) return x + 1. @jax.jit def g(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x + 1. - self.assertNotIn('foo', dispatch.runtime_tokens.tokens) - self.assertNotIn('foo2', dispatch.runtime_tokens.tokens) + self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens) + self.assertNotIn(foo2_effect, dispatch.runtime_tokens.tokens) + f(2.).block_until_ready() + foo_token = dispatch.runtime_tokens.tokens[foo_effect][0] + foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0] f(2.) - foo_token = dispatch.runtime_tokens.tokens['foo'][0] - foo2_token = dispatch.runtime_tokens.tokens['foo'][0] - f(2.) - self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0]) - self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0]) - foo_token = dispatch.runtime_tokens.tokens['foo'][0] - foo2_token = dispatch.runtime_tokens.tokens['foo2'][0] + self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0]) + self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0]) + foo_token = dispatch.runtime_tokens.tokens[foo_effect][0] + foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0] g(2.) - self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0]) - self.assertIs(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0]) + self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0]) + self.assertIs(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0]) @jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback class EffectOrderingTest(jtu.JaxTestCase): @@ -589,7 +609,7 @@ class EffectOrderingTest(jtu.JaxTestCase): @jax.jit def f(x): - return callback_p.bind(x, callback=log_value, effect='log', out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) f(2.) jax.effects_barrier() @@ -618,11 +638,11 @@ class EffectOrderingTest(jtu.JaxTestCase): # Expensive computation x = x.dot(x) x = jnp.log(x.sum()) - return callback_p.bind(x, callback=log_value, effect='log', out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) @functools.partial(jax.jit, device=jax.devices()[1]) def g(x): - return callback_p.bind(x, callback=log_value, effect='log', out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) f(jnp.ones((500, 500))) g(3.) @@ -641,12 +661,12 @@ class EffectOrderingTest(jtu.JaxTestCase): raise unittest.SkipTest("Test requires >= 2 devices.") tokens = [] def _noop(_): - tokens.append(dispatch.runtime_tokens.tokens['log'][0]) + tokens.append(dispatch.runtime_tokens.tokens[log_effect][0]) return () @functools.partial(jax.jit, device=jax.devices()[0]) def f(x): - return callback_p.bind(x, callback=_noop, effect='log', out_avals=[]) + return callback_p.bind(x, callback=_noop, effect=log_effect, out_avals=[]) t1 = threading.Thread(target=lambda: f(2.)) t2 = threading.Thread(target=lambda: f(3.)) @@ -673,7 +693,7 @@ class ParallelEffectsTest(jtu.JaxTestCase): def f(x): # foo is lowerable and ordered - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x with self.assertRaisesRegex( ValueError, "Ordered effects not supported in `pmap`."): @@ -683,7 +703,7 @@ class ParallelEffectsTest(jtu.JaxTestCase): def f(x): # bar is lowerable and unordered - effect_p.bind(effect='bar') + effect_p.bind(effect=bar_effect) return x jax.pmap(f)(jnp.arange(jax.local_device_count())) @@ -705,7 +725,7 @@ class ParallelEffectsTest(jtu.JaxTestCase): @jax.pmap def f(x): callback_p.bind( - x, callback=log_value, effect='unordered_log', out_avals=[]) + x, callback=log_value, effect=unordered_log_effect, out_avals=[]) return x + 1 f(jnp.arange(2)).block_until_ready() jax.effects_barrier() @@ -716,7 +736,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_effects_disallowed_in_cond(self): def f1(x): def true_fun(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x def false_fun(x): return x @@ -728,10 +748,10 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_allowed_effect_in_cond(self): def f(x): def true_fun(x): - effect_p.bind(effect='while') + effect_p.bind(effect=while_effect) return x def false_fun(x): - effect_p.bind(effect='while') + effect_p.bind(effect=while_effect) return x return lax.cond(x, true_fun, false_fun, x) f(2) @@ -739,16 +759,16 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_allowed_effect_in_cond_jvp(self): def f(x): def true_fun(x): - effect_p.bind(effect='while') + effect_p.bind(effect=while_effect) return x def false_fun(x): - effect_p.bind(effect='while') + effect_p.bind(effect=while_effect) return x return lax.cond(True, true_fun, false_fun, x) # test primal side gets effect primal_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(f, x)[0])(2.) - self.assertEqual(primal_jaxpr.effects, {'while'}) + self.assertEqual(primal_jaxpr.effects, {while_effect}) # and tangent side does not _, f_lin = jax.linearize(f, 2.) lin_jaxpr = f_lin.func.fun.args[0] @@ -763,8 +783,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): x, = primals t, = tangents # TODO(mattjj,sharadmv): don't require data dependence for jax.linearize! - # effect_p.bind(t, effect='while') - t, = effect_p.bind(t, effect='while') # data dep only on tangents + # effect_p.bind(t, effect=while_effect) + t, = effect_p.bind(t, effect=while_effect) # data dep only on tangents return x, t def f(x): @@ -780,15 +800,15 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): # and tangent side does _, f_lin = jax.linearize(f, 2.) lin_jaxpr = f_lin.func.fun.args[0] - self.assertEqual(lin_jaxpr.effects, {'while'}) + self.assertEqual(lin_jaxpr.effects, {while_effect}) def test_allowed_ordered_effect_in_cond(self): def f(x): def true_fun(x): - effect_p.bind(effect='while1') + effect_p.bind(effect=while1_effect) return x def false_fun(x): - effect_p.bind(effect='while1') + effect_p.bind(effect=while1_effect) return x return lax.cond(x, true_fun, false_fun, x) f(2) @@ -796,12 +816,12 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_multiple_allowed_ordered_effect_in_cond(self): def f(x): def true_fun(x): - effect_p.bind(effect='while1') - effect_p.bind(effect='while2') + effect_p.bind(effect=while1_effect) + effect_p.bind(effect=while2_effect) return x def false_fun(x): - effect_p.bind(effect='while1') - effect_p.bind(effect='while2') + effect_p.bind(effect=while1_effect) + effect_p.bind(effect=while2_effect) return x return lax.cond(x, true_fun, false_fun, x) f(2) @@ -810,7 +830,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def true_fun(x): return x def false_fun(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x return lax.cond(True, true_fun, false_fun, x) @@ -822,7 +842,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def cond_fun(x): return False def body_fun(x): - effect_p.bind(effect='while') + effect_p.bind(effect=while_effect) return x return lax.while_loop(cond_fun, body_fun, x) f(2) @@ -830,7 +850,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_allowed_effect_in_cond_body(self): def f(x): def cond_fun(x): - effect_p.bind(effect='while') + effect_p.bind(effect=while_effect) return False def body_fun(x): return x @@ -842,7 +862,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def cond_fun(x): return False def body_fun(x): - effect_p.bind(effect='while1') + effect_p.bind(effect=while1_effect) return x return lax.while_loop(cond_fun, body_fun, x) f(2) @@ -852,8 +872,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def cond_fun(x): return False def body_fun(x): - effect_p.bind(effect='while1') - effect_p.bind(effect='while2') + effect_p.bind(effect=while1_effect) + effect_p.bind(effect=while2_effect) return x return lax.while_loop(cond_fun, body_fun, x) f(2) @@ -861,7 +881,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_effects_disallowed_in_while(self): def f1(x): def cond_fun(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return False def body_fun(x): return x @@ -874,7 +894,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def cond_fun(x): return False def body_fun(x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return x return lax.while_loop(cond_fun, body_fun, x) @@ -884,7 +904,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_allowed_effect_in_scan(self): def f(x): def body_fun(carry, x): - effect_p.bind(effect='while') + effect_p.bind(effect=while_effect) return carry, x return lax.scan(body_fun, x, jnp.arange(5)) f(2) @@ -892,7 +912,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_allowed_ordered_effect_in_scan(self): def f(x): def body_fun(carry, x): - effect_p.bind(effect='while1') + effect_p.bind(effect=while1_effect) return carry, x return lax.scan(body_fun, x, jnp.arange(5)) f(2) @@ -900,8 +920,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def test_multiple_allowed_ordered_effect_in_scan(self): def f(x): def body_fun(carry, x): - effect_p.bind(effect='while1') - effect_p.bind(effect='while2') + effect_p.bind(effect=while1_effect) + effect_p.bind(effect=while2_effect) return carry, x return lax.scan(body_fun, x, jnp.arange(5)) f(2) @@ -910,7 +930,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase): def f(x): def body(carry, x): - effect_p.bind(effect='foo') + effect_p.bind(effect=foo_effect) return carry, x return lax.scan(body, x, jnp.arange(4)) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 1b98bc723..c3bd8f007 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -24,6 +24,7 @@ from jax import tree_util from jax._src import core from jax._src import debugging from jax._src import dispatch +from jax._src import effects from jax._src import sharding from jax._src import test_util as jtu from jax._src import util @@ -91,8 +92,7 @@ def callback(f, result_shape, *args, ordered: bool = False, **kwargs): core.ShapedArray(s.shape, s.dtype) for s in flat_result_shapes ] effect = ( - debugging.DebugEffect.ORDERED_PRINT - if ordered else debugging.DebugEffect.PRINT) + debugging.ordered_debug_effect if ordered else debugging.debug_effect) flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) def _flat_callback(*flat_args): args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) @@ -111,7 +111,7 @@ def callback_lowering(ctx, *args, effect, callback, **params): return tuple( callback_p.impl(*flat_args, effect=effect, callback=callback, **params)) - if effect in core.ordered_effects: + if effects.ordered_effects.contains(effect): token = ctx.tokens_in.get(effect)[0] result, token, keepalive = mlir.emit_python_callback( ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)