Merge pull request #14549 from sharadmv:dbidx-effects

PiperOrigin-RevId: 510608031
This commit is contained in:
jax authors 2023-02-17 23:43:38 -08:00
commit c0107cc836
23 changed files with 352 additions and 271 deletions

View File

@ -13,8 +13,8 @@
# limitations under the License.
from functools import partial
from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
FrozenSet)
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
Union)
import types
import jax
@ -25,6 +25,7 @@ from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import effects
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
@ -44,6 +45,7 @@ traceback_util.register_exclusion(__file__)
map = safe_map
zip = safe_zip
allowed_effects: effects.EffectTypeSet = effects.remat_allowed_effects
### Policies
@ -451,14 +453,11 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
return out_primals, out_tangents
ad.primitive_jvps[remat_p] = remat_jvp
remat_allowed_effects: Set[core.Effect] = set()
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Infeed)
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Outfeed)
allowed_effects.add_type(lax_internal.InOutFeedEffect)
def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars
disallowed_effects = {eff for eff in jaxpr.effects
if eff not in remat_allowed_effects}
disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
'Effects not supported in partial-eval of `checkpoint`/`remat`: '

View File

@ -43,6 +43,7 @@ from jax._src import callback as jcb
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import effects
from jax._src import array
from jax._src import dtypes
from jax._src import source_info_util
@ -1071,10 +1072,10 @@ def xla_computation(fun: Callable,
else:
out_parts_flat = tuple(flatten_axes(
"xla_computation out_parts", out_tree(), out_parts))
unordered_effects = [eff for eff in jaxpr.effects
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in jaxpr.effects
if eff in core.ordered_effects]
unordered_effects = list(
effects.ordered_effects.filter_not_in(jaxpr.effects))
ordered_effects = list(
effects.ordered_effects.filter_in(jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
f"xla_computation_{fun_name}",
core.ClosedJaxpr(jaxpr, consts),

View File

@ -24,6 +24,7 @@ from jax.interpreters import mlir
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src import util
from jax._src import dispatch
from jax._src.interpreters import ad
@ -158,17 +159,19 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
io_callback_p = core.Primitive("io_callback")
io_callback_p.multiple_results = True
class IOEffect:
class IOEffect(effects.Effect):
__str__ = lambda _: "IO"
class OrderedIOEffect:
class OrderedIOEffect(effects.Effect):
__str__ = lambda _: "OrderedIO"
_IOEffect = IOEffect()
_OrderedIOEffect = OrderedIOEffect()
mlir.lowerable_effects.add(_IOEffect)
mlir.lowerable_effects.add(_OrderedIOEffect)
core.control_flow_allowed_effects.add(_IOEffect)
core.control_flow_allowed_effects.add(_OrderedIOEffect)
core.ordered_effects.add(_OrderedIOEffect)
effects.lowerable_effects.add_type(IOEffect)
effects.lowerable_effects.add_type(OrderedIOEffect)
effects.control_flow_allowed_effects.add_type(IOEffect)
effects.control_flow_allowed_effects.add_type(OrderedIOEffect)
effects.ordered_effects.add_type(OrderedIOEffect)
def io_callback_impl(*args, result_avals, callback: Callable[..., Any],

View File

@ -36,13 +36,13 @@ from jax.tree_util import tree_unflatten
from jax._src import linear_util as lu
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import control_flow as cf
from jax._src.sharding import GSPMDSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
@ -100,20 +100,18 @@ class JaxException(Exception):
@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect:
class ErrorEffect(effects.Effect):
error_type: Type[JaxException]
shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]
def __post_init__(self):
cf.allowed_effects.add(self)
mlir.lowerable_effects.add(self)
def __lt__(self, other: 'ErrorEffect'):
shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable
for sd in x.shape_dtypes)
unpack = lambda x: (str(x.error_type), shape_dtypes(x))
return (unpack(self) < unpack(other))
effects.control_flow_allowed_effects.add_type(ErrorEffect)
effects.lowerable_effects.add_type(ErrorEffect)
class DivisionByZeroError(JaxException):

View File

@ -37,6 +37,7 @@ import numpy as np
from jax._src import dtypes
from jax._src import config as jax_config
from jax._src import effects
from jax._src.config import FLAGS, config
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@ -59,12 +60,10 @@ map, unsafe_map = safe_map, map
# -------------------- jaxprs --------------------
Effect = Hashable
Effects = Set[Effect]
no_effects: Effects = set()
ordered_effects: Set[Effect] = set()
control_flow_allowed_effects: Set[Effect] = set()
Effect = effects.Effect
Effects = effects.Effects
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects']
@ -2829,7 +2828,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
if "call_jaxpr" not in params:
raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
ordered_effects_ = call_jaxpr.effects & ordered_effects
ordered_effects_ = effects.ordered_effects.filter_in(call_jaxpr.effects)
if ordered_effects_:
raise JaxprTypeError(
f"Map primitive {prim} mapping ordered effects: {ordered_effects_}")

View File

@ -14,8 +14,7 @@
from functools import update_wrapper, reduce, partial
import inspect
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
Any)
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Any)
from jax.custom_transpose import custom_transpose
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
@ -29,6 +28,7 @@ from jax.config import config
from jax._src import core
from jax._src import custom_api_util
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import traceback_util
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
@ -47,6 +47,8 @@ traceback_util.register_exclusion(__file__)
map = safe_map
zip = safe_zip
allowed_effects: effects.EffectTypeSet = (
effects.custom_derivatives_allowed_effects)
### util
@ -380,18 +382,14 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
yield outs, tuple(todo) # Ensure the aux output is immutable
allowed_effects: Set[core.Effect] = set()
allowed_effects.add(lax.InOutFeedEffect.Infeed)
allowed_effects.add(lax.InOutFeedEffect.Outfeed)
allowed_effects.add_type(lax.InOutFeedEffect)
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
# TODO(mattjj): could do more checking here...
del in_avals, jvp_jaxpr_thunk, num_consts
disallowed_effects = {eff for eff in call_jaxpr.effects if eff not in
allowed_effects}
disallowed_effects = allowed_effects.filter_not_in(call_jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_jvp`: {disallowed_effects}')
@ -714,8 +712,7 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
return core.jaxpr_as_fun(fun_jaxpr)(*args)
def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in
allowed_effects}
disallowed_effects = allowed_effects.filter_not_in(fun_jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `custom_vjp`: {disallowed_effects}')

View File

@ -13,7 +13,6 @@
# limitations under the License.
"""Module for JAX debugging primitives and related functionality."""
import enum
import functools
import string
import sys
@ -25,24 +24,21 @@ import numpy as np
import jax.numpy as jnp
from jax import tree_util
from jax import lax
from jax.config import config
from jax.experimental import pjit
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax._src import ad_checkpoint
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import control_flow as lcf
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding, GSPMDSharding, NamedSharding
from jax.interpreters import partial_eval as pe
from jax.config import config
# pytype: disable=import-error
try:
@ -58,17 +54,23 @@ except:
RICH_ENABLED = False
# pytype: enable=import-error
DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])
class DebugEffect(effects.Effect):
pass
debug_effect = DebugEffect()
core.ordered_effects.add(DebugEffect.ORDERED_PRINT)
mlir.lowerable_effects.add(DebugEffect.PRINT)
mlir.lowerable_effects.add(DebugEffect.ORDERED_PRINT)
lcf.allowed_effects.add(DebugEffect.PRINT)
lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT)
ad_checkpoint.remat_allowed_effects.add(DebugEffect.PRINT)
ad_checkpoint.remat_allowed_effects.add(DebugEffect.ORDERED_PRINT)
custom_derivatives.allowed_effects.add(DebugEffect.PRINT)
custom_derivatives.allowed_effects.add(DebugEffect.ORDERED_PRINT)
class OrderedDebugEffect(effects.Effect):
pass
ordered_debug_effect = OrderedDebugEffect()
effects.ordered_effects.add_type(OrderedDebugEffect)
effects.lowerable_effects.add_type(DebugEffect)
effects.lowerable_effects.add_type(OrderedDebugEffect)
effects.control_flow_allowed_effects.add_type(DebugEffect)
effects.control_flow_allowed_effects.add_type(OrderedDebugEffect)
effects.remat_allowed_effects.add_type(DebugEffect)
effects.remat_allowed_effects.add_type(OrderedDebugEffect)
effects.custom_derivatives_allowed_effects.add_type(DebugEffect)
effects.custom_derivatives_allowed_effects.add_type(OrderedDebugEffect)
# `debug_callback_p` is the main primitive for staging out Python callbacks.
debug_callback_p = core.Primitive('debug_callback')
@ -133,7 +135,7 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
return tuple(
debug_callback_p.impl(
*flat_args, effect=effect, callback=callback, **params))
if effect in core.ordered_effects:
if effects.ordered_effects.contains(effect):
token = ctx.tokens_in.get(effect)[0]
result, token, keepalive = mlir.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
@ -209,7 +211,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
The value of `callback(*args, **kwargs)`.
"""
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT
effect = ordered_debug_effect if ordered else debug_effect
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
callback(*args, **kwargs)
@ -273,7 +275,7 @@ inspect_sharding_p.def_impl(_inspect_sharding_impl)
def _inspect_sharding_abstract_eval(aval, **_):
del aval
# Effectful abstract avoids DCE
return [], {DebugEffect.PRINT}
return [], {debug_effect}
inspect_sharding_p.def_effectful_abstract_eval(_inspect_sharding_abstract_eval)
def _inspect_sharding_batching_rule(args, _, *, callback):

View File

@ -42,6 +42,7 @@ from jax._src import array
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import path
from jax._src import profiler
@ -507,10 +508,10 @@ def lower_xla_callable(
if type(d) is pe.InDBIdx else d for d in a.shape))
if type(a) is core.DShapedArray else a, b) for a, b in out_type]
module_name = f"jit_{fun.__name__}"
unordered_effects = [eff for eff in closed_jaxpr.effects
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
if eff in core.ordered_effects]
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(
effects.ordered_effects.filter_in(closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, unordered_effects,
ordered_effects, backend, backend.platform,

46
jax/_src/effects.py Normal file
View File

@ -0,0 +1,46 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Type, Set
class Effect:
"""A generic side-effect."""
Effects = Set[Effect]
class EffectTypeSet:
def __init__(self):
self._effect_types: Set[Type[Effect]] = set()
def add_type(self, effect_type: Type[Effect]):
self._effect_types.add(effect_type)
def contains(self, eff: Effect) -> bool:
return any(isinstance(eff, eff_type) for eff_type in self._effect_types)
def filter_in(self, effects: Effects) -> Effects:
return {eff for eff in effects if self.contains(eff)}
def filter_not_in(self, effects: Effects) -> Effects:
return {eff for eff in effects if not self.contains(eff)}
no_effects: Effects = set()
ordered_effects: EffectTypeSet = EffectTypeSet()
lowerable_effects: EffectTypeSet = EffectTypeSet()
control_flow_allowed_effects: EffectTypeSet = EffectTypeSet()
custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet()
remat_allowed_effects: EffectTypeSet = EffectTypeSet()

View File

@ -39,6 +39,7 @@ from jax._src import ad_util
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import source_info_util
from jax._src import util
from jax._src.lib import xla_bridge as xb
@ -58,7 +59,7 @@ Value = Any # = ir.Value
# mypy implicitly sets this variable to true when type checking.
MYPY = False
lowerable_effects: Set[core.Effect] = set()
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
# IR Helpers
@ -695,7 +696,8 @@ def lower_jaxpr_to_module(
if platform in _platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args)
if any(eff not in lowerable_effects for eff in jaxpr.effects):
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
if any(donated_args):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
@ -731,8 +733,7 @@ def lower_jaxpr_to_module(
module_name = _module_name_regex.sub("_", module_name)
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(
module_name)
unlowerable_effects = {eff for eff in jaxpr.effects
if eff not in lowerable_effects}
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
if unlowerable_effects:
raise ValueError(
f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}')
@ -1140,7 +1141,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
f"found for platform {ctx.platform}")
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
effects = [eff for eff in eqn.effects if eff in core.ordered_effects]
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
tokens_in = tokens.subset(effects)
avals_in = map(aval, eqn.invars)
rule_ctx = LoweringRuleContext(

View File

@ -57,6 +57,7 @@ from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import profiler
from jax._src import sharding as sharding_internal
@ -1487,12 +1488,12 @@ def lower_parallel_callable(
backend.platform)
module_name = f"pmap_{fun.__name__}"
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
ordered_effects = list(
effects.ordered_effects.filter_in(closed_jaxpr.effects))
if ordered_effects:
raise ValueError("Ordered effects not supported in `pmap`.")
unordered_effects = [eff for eff in closed_jaxpr.effects
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
if eff in core.ordered_effects]
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -3007,12 +3008,12 @@ def lower_sharding_computation(
module_name = f"{api_name}_{fun_name}"
if len(device_assignment) > 1:
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
if any(effects.ordered_effects.contains(eff) for eff
in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
unordered_effects = [eff for eff in closed_jaxpr.effects
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
if eff in core.ordered_effects]
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
@ -3180,12 +3181,13 @@ def lower_mesh_computation(
module: Union[str, xc.XlaComputation]
module_name = f"{api_name}_{fun_name}"
with core.extend_axis_env_nd(mesh.shape.items()):
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
if any(effects.ordered_effects.contains(eff) for eff
in closed_jaxpr.effects):
raise ValueError("Ordered effects not supported in mesh computations.")
unordered_effects = [eff for eff in closed_jaxpr.effects
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
if eff in core.ordered_effects]
unordered_effects = list(effects.ordered_effects.filter_not_in(
closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(
closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,

View File

@ -19,8 +19,8 @@ from typing import Callable, Optional, Sequence
from jax._src import core
from jax._src import linear_util as lu
from jax._src.core import control_flow_allowed_effects as allowed_effects
from jax._src.lax import lax
from jax._src.effects import control_flow_allowed_effects as allowed_effects
from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
@ -30,8 +30,7 @@ from jax.tree_util import tree_map, tree_unflatten
map, unsafe_map = safe_map, map
allowed_effects.add(lax.InOutFeedEffect.Infeed)
allowed_effects.add(lax.InOutFeedEffect.Outfeed)
allowed_effects.add_type(lax.InOutFeedEffect)
def _abstractify(x):

View File

@ -31,6 +31,7 @@ from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import util
@ -136,7 +137,7 @@ def switch(index, branches: Sequence[Callable], *operands,
out_trees[0], jaxprs[0].out_avals,
out_tree, jaxpr.out_avals)
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
disallowed_effects = joined_effects - allowed_effects
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `switch`: {disallowed_effects}')
@ -245,7 +246,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
out_tree, true_jaxpr.out_avals,
false_out_tree, false_jaxpr.out_avals)
joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
disallowed_effects = joined_effects - allowed_effects
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
@ -304,7 +305,7 @@ def _cond_with_per_branch_args(pred,
def _cond_abstract_eval(*avals, branches, **_):
joined_effects = core.join_effects(*(b.effects for b in branches))
disallowed_effects = joined_effects - allowed_effects
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
@ -751,7 +752,7 @@ def _cond_typecheck(*in_atoms, branches, linear):
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
joined_effects = core.join_effects(*(b.effects for b in branches))
disallowed_effects = joined_effects - allowed_effects
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
@ -818,8 +819,7 @@ pe.dce_rules[cond_p] = _cond_dce_rule
def _cond_lowering(ctx, index, *args, branches, linear):
del linear # Unused.
joined_effects = core.join_effects(*(branch.effects for branch in branches))
ordered_effects = [eff for eff in joined_effects
if eff in core.ordered_effects]
ordered_effects = list(effects.ordered_effects.filter_in(joined_effects))
num_tokens = len(ordered_effects)
tokens_in = ctx.tokens_in.subset(ordered_effects)
output_token_types = [mlir.token_type() for _ in ordered_effects]

View File

@ -35,6 +35,7 @@ from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import dtypes
from jax._src import effects
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
@ -266,7 +267,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
# Extract the subtree and avals for the first element of the return tuple
out_tree_children[0], carry_avals_out,
init_tree, carry_avals)
disallowed_effects = jaxpr.effects - allowed_effects
disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `scan`: {disallowed_effects}')
@ -1098,7 +1099,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
body_tree, body_jaxpr.out_avals,
in_tree_children[0], init_avals)
effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
disallowed_effects = effects - allowed_effects
disallowed_effects = allowed_effects.filter_not_in(effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
@ -1110,7 +1111,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
del args, kwargs
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
disallowed_effects = joined_effects - allowed_effects
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {disallowed_effects}')
@ -1422,8 +1423,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
pred_aval = cond_jaxpr.out_avals[0]
batched = bool(pred_aval.shape)
cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
core.ordered_effects]
cond_ordered_effects = effects.ordered_effects.filter_in(cond_jaxpr.effects)
if cond_ordered_effects:
def cond(args):
# Pred can be batched
@ -1448,8 +1448,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
return mlir.lower_fun(fun)(ctx, *args)
loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
body_effects = [eff for eff in body_jaxpr.effects
if eff in core.ordered_effects]
body_effects = effects.ordered_effects.filter_in(body_jaxpr.effects)
num_tokens = len(body_effects)
tokens = [ctx.tokens_in.get(eff) for eff in body_effects]
token_types = [mlir.token_type() for _ in tokens]
@ -1533,9 +1532,10 @@ def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
# TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
if joined_effects - allowed_effects:
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {joined_effects - allowed_effects}')
f'Effects not supported in `while`: {disallowed_effects}')
return body_jaxpr.out_avals, joined_effects
while_p = core.AxisPrimitive('while')

View File

@ -39,6 +39,7 @@ from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import source_info_util
@ -4162,7 +4163,10 @@ def _after_all_lowering(ctx, *operands):
mlir.register_lowering(after_all_p, _after_all_lowering)
InOutFeedEffect = enum.Enum('InOutFeedEffect', ['Infeed', 'Outfeed'])
class InOutFeedEffect(effects.Effect):
pass
infeed_effect = InOutFeedEffect()
outfeed_effect = InOutFeedEffect()
def infeed(token, shape=None, partitions=None):
@ -4190,14 +4194,14 @@ def infeed(token, shape=None, partitions=None):
def _infeed_abstract_eval(token, *, shapes, partitions):
if token is not abstract_token:
raise TypeError("First argument to infeed must be a token")
return (*shapes, abstract_token), {InOutFeedEffect.Infeed}
return (*shapes, abstract_token), {infeed_effect}
infeed_p = Primitive("infeed")
infeed_p.multiple_results = True
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
mlir.lowerable_effects.add(InOutFeedEffect.Infeed)
mlir.lowerable_effects.add_type(InOutFeedEffect)
def _infeed_lowering(ctx, token, *, shapes, partitions):
@ -4243,12 +4247,12 @@ def outfeed(token, xs, partitions = None):
def _outfeed_abstract_eval(token, *xs, partitions):
if token is not abstract_token:
raise TypeError("First argument to outfeed must be a token")
return abstract_token, {InOutFeedEffect.Outfeed}
return abstract_token, {outfeed_effect}
outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
mlir.lowerable_effects.add(InOutFeedEffect.Outfeed)
mlir.lowerable_effects.add_type(InOutFeedEffect)
def _outfeed_lowering(ctx, token, *xs, partitions):

View File

@ -25,6 +25,7 @@ from jax._src import core
from jax._src import linear_util as lu
from jax import stages
from jax._src import dispatch
from jax._src import effects
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
treedef_tuple)
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
@ -1375,7 +1376,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
if any(effects.ordered_effects.contains(eff) for eff
in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(),
const_nodes, *tiled_ins,
@ -1442,7 +1444,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
if any(effects.ordered_effects.contains(eff) for eff
in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
mlir.TokenSet(), const_nodes, *sharded_global_in_nodes,
@ -1494,7 +1497,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')),
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
if any(effects.ordered_effects.contains(eff) for eff
in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes),

View File

@ -1465,9 +1465,9 @@ pe.custom_staging_rules[pjit_p] = pjit_staging_rule
def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
out_positional_semantics, **_):
disallowed_effects = jaxpr.effects - mlir.lowerable_effects
disallowed_effects = mlir.lowerable_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise ValueError('Effects not supported in `pjit`.')
raise ValueError(f'Effects not supported in `pjit`: {disallowed_effects}.')
if config.jax_array:
return jaxpr.out_avals, jaxpr.effects
return global_to_local(out_positional_semantics, jaxpr.out_avals,

View File

@ -17,9 +17,9 @@ from __future__ import annotations
from typing import Any, List, Optional, Sequence, Set, Union
from jax._src import core
from jax._src import effects
from jax._src.lib import xla_bridge, xla_client
from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod
from jax._src.lax.control_flow import common
xc = xla_client
xb = xla_bridge
@ -31,10 +31,9 @@ zip, unsafe_zip = safe_zip, zip
Array = Any
class RefEffect:
class RefEffect(effects.Effect):
def __init__(self, ref_aval: ShapedArrayRef):
self.ref_aval = ref_aval
common.allowed_effects.add(self)
def __eq__(self, other):
if not isinstance(other, self.__class__):
@ -61,6 +60,8 @@ class AccumEffect(RefEffect):
def __str__(self):
return f"Accum<{self.ref_aval}>"
effects.control_flow_allowed_effects.add_type(RefEffect)
StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]
# ## `Ref`s

View File

@ -156,7 +156,6 @@ from jax._src.core import (
np as np,
opaque_dtypes as opaque_dtypes,
operator as operator,
ordered_effects as ordered_effects,
outfeed_primitives as outfeed_primitives,
partial as partial,
partialmethod as partialmethod,

View File

@ -33,19 +33,20 @@ from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import util
from jax._src import ad_util
from jax._src.lax import control_flow as lax_control_flow
from jax._src import core
from jax._src import ad_checkpoint
from jax._src import custom_derivatives
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import core
from jax._src import ad_util
from jax._src import effects
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
from jax.interpreters import mlir
from jax.interpreters import xla
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -296,13 +297,16 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc
return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
# Mark the effectful instancess of call_tf
CallTfEffect = enum.Enum('CallTfEffect', ['EFFECT'])
# Mark the effectful instances of call_tf
class CallTfEffect(effects.Effect):
__str__ = lambda _: "CallTfEffect"
mlir.lowerable_effects.add(CallTfEffect.EFFECT)
lax_control_flow.allowed_effects.add(CallTfEffect.EFFECT)
ad_checkpoint.remat_allowed_effects.add(CallTfEffect.EFFECT)
custom_derivatives.allowed_effects.add(CallTfEffect.EFFECT)
call_tf_effect = CallTfEffect()
effects.lowerable_effects.add_type(CallTfEffect)
effects.control_flow_allowed_effects.add_type(CallTfEffect)
effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
def _call_tf_abstract_eval(*_,
@ -316,7 +320,7 @@ def _call_tf_abstract_eval(*_,
def is_fully_known_shape(s):
return s.rank is not None and all([d is not None for d in s])
effects = {CallTfEffect.EFFECT} if has_side_effects else set()
effects = {call_tf_effect} if has_side_effects else set()
if all([is_fully_known_shape(s)
for s in concrete_function_flat_tf.output_shapes]):
return (

View File

@ -30,6 +30,7 @@ from jax._src import linear_util as lu
from jax.config import config
from jax._src import api_util
from jax._src import core
from jax._src import effects
from jax._src import dtypes
from jax._src import profiler
from jax._src import source_info_util
@ -1782,7 +1783,7 @@ class DynamicJaxprTrace(core.Trace):
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name))
ordered_effects = jaxpr.effects & core.ordered_effects
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
f"map primitives: {ordered_effects}")

View File

@ -21,6 +21,7 @@ import jax
import jax.numpy as jnp
from jax._src import core
from jax import lax
from jax._src import effects
from jax._src import linear_util as lu
from jax.config import config
from jax.interpreters import ad
@ -32,7 +33,6 @@ from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_bridge
import numpy as np
@ -49,22 +49,34 @@ def effect_jvp_rule(primals, tangents, effect):
return effect_p.bind(*primals, effect=effect), tangents
ad.primitive_jvps[effect_p] = effect_jvp_rule
mlir.lowerable_effects.add('foo')
mlir.lowerable_effects.add('foo2')
mlir.lowerable_effects.add('bar')
mlir.lowerable_effects.add('while')
mlir.lowerable_effects.add('while1')
mlir.lowerable_effects.add('while2')
core.ordered_effects.add('foo')
core.ordered_effects.add('foo2')
core.ordered_effects.add('while1')
core.ordered_effects.add('while2')
class BasicEffect(effects.Effect):
def __init__(self, name):
self.name = name
lcf.allowed_effects.add('while')
lcf.allowed_effects.add('while1')
lcf.allowed_effects.add('while2')
__repr__ = lambda self: self.name
ad_checkpoint.remat_allowed_effects.add('remat')
class OrderedEffect(BasicEffect): pass
class UnlowerableEffect(effects.Effect): pass
class WhileEffect(effects.Effect): pass
class RematEffect(effects.Effect): pass
foo_effect = OrderedEffect("foo")
foo2_effect = OrderedEffect("foo2")
bar_effect = BasicEffect("bar")
baz_effect = UnlowerableEffect()
while_effect = WhileEffect()
while1_effect = WhileEffect()
while2_effect = WhileEffect()
log_effect = OrderedEffect("log")
unordered_log_effect = BasicEffect("unordered_log")
effects.lowerable_effects.add_type(BasicEffect)
effects.lowerable_effects.add_type(WhileEffect)
effects.ordered_effects.add_type(OrderedEffect)
effects.ordered_effects.add_type(WhileEffect)
effects.control_flow_allowed_effects.add_type(WhileEffect)
effects.remat_allowed_effects.add_type(RematEffect)
def trivial_effect_lowering(ctx, *, effect):
@ -92,10 +104,6 @@ def function_effect_lowering(ctx, *, effect):
callback_p = core.Primitive('callback')
callback_p.multiple_results = True
mlir.lowerable_effects.add('log')
mlir.lowerable_effects.add('unordered_log')
core.ordered_effects.add('log')
@callback_p.def_impl
def _(*args, callback, out_avals, effect):
del out_avals, effect
@ -110,7 +118,7 @@ def _(*avals, callback, out_avals, effect):
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
del out_avals
token_in = None
if effect in core.ordered_effects:
if effects.ordered_effects.contains(effect):
token_in = ctx.tokens_in.get(effect)[0]
out_op, token_out, keep_alive = mlir.emit_python_callback(
@ -149,31 +157,31 @@ class JaxprEffectsTest(jtu.JaxTestCase):
def test_effectful_primitive_in_jaxpr_creates_effects(self):
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
jaxpr = jax.make_jaxpr(f)(2.)
self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects)
self.assertEqual({'foo'}, jaxpr.effects)
self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects)
self.assertEqual({foo_effect}, jaxpr.effects)
def test_different_effects_in_jaxpr(self):
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x + 1.
jaxpr = jax.make_jaxpr(f)(2.)
self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects)
self.assertEqual({'bar'}, jaxpr.jaxpr.eqns[1].effects)
self.assertEqual({'foo', 'bar'}, jaxpr.effects)
self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects)
self.assertEqual({bar_effect}, jaxpr.jaxpr.eqns[1].effects)
self.assertEqual({foo_effect, bar_effect}, jaxpr.effects)
def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self):
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x + 1.
jaxpr = jax.make_jaxpr(f)(2.).jaxpr
# Edit jaxpr to make its type wrong
jaxpr = jaxpr.replace(effects={'foo'})
jaxpr = jaxpr.replace(effects={foo_effect})
with self.assertRaisesRegex(core.JaxprTypeError,
'Equation effects are not subset of Jaxpr effects.'):
@ -186,48 +194,52 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
def f(x):
@lu.wrap_init
def f_(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return [x]
return core.call(f_, x)[0]
jax.make_jaxpr(f)(2.)
jaxpr = jax.make_jaxpr(f)(2.)
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
self.assertIn(bar_effect, jaxpr.jaxpr.effects)
def test_xla_call_primitive_inherits_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x
jax.make_jaxpr(f)(2.)
jaxpr = jax.make_jaxpr(f)(2.)
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
self.assertIn(bar_effect, jaxpr.jaxpr.effects)
@jtu.sample_product(flavor=["old", "new"])
def test_remat_call_primitive_inherits_effects(self, flavor):
remat = jax.remat if flavor == "old" else ad_checkpoint.checkpoint
def test_remat_call_primitive_inherits_effects(self):
@remat
@jax.checkpoint
def f(x):
x, = effect_p.bind(x, effect='foo')
x, = effect_p.bind(x, effect='bar')
x, = effect_p.bind(x, effect=foo_effect)
x, = effect_p.bind(x, effect=bar_effect)
return x
jax.make_jaxpr(f)(2.)
with self.assertRaisesRegex(NotImplementedError, "Effects not supported"):
jax.make_jaxpr(lambda x: jax.linearize(f, x)[1](x))(2.)
def test_new_remat_allows_certain_effects(self):
remat_effect = RematEffect()
@ad_checkpoint.checkpoint
def f(x):
x, = effect_p.bind(x, effect='remat')
x, = effect_p.bind(x, effect=remat_effect)
return x
jaxpr = jax.make_jaxpr(f)(2.)
self.assertSetEqual(jaxpr.effects, {"remat"})
self.assertSetEqual(jaxpr.effects, {remat_effect})
def test_custom_jvp_primitive_inherits_effects(self):
@jax.custom_jvp
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x
f.defjvp(lambda x, t: (x, t))
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
@ -237,8 +249,8 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
@jax.custom_vjp
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x
f.defvjp(
fwd=lambda x: (x, ()),
@ -250,27 +262,27 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
@jax.pmap
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x
with self.assertRaisesRegex(
ValueError,
"Ordered effects not supported for map primitives: {'foo'}"):
"Ordered effects not supported for map primitives: {.*}"):
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
def test_xmap_inherits_effects(self):
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x
f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
self.assertSetEqual(jaxpr.effects, {"foo", "bar"})
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
def test_pjit_inherits_effects(self):
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x
mesh = jax.sharding.Mesh(np.array(jax.devices()), ['x'])
if config.jax_array:
@ -280,7 +292,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
with mesh:
jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count()))
self.assertSetEqual(jaxpr.effects, {"foo", "bar"})
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@ -291,7 +303,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
config.update('jax_enable_x64', False)
self._old_lowering = mlir._lowerings[effect_p]
def _effect_lowering(ctx, *, effect):
if effect in core.ordered_effects:
if effects.ordered_effects.contains(effect):
expected_effects = [effect]
else:
expected_effects = []
@ -308,12 +320,20 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
config.update('jax_enable_x64', self.old_x64)
mlir.register_lowering(effect_p, self._old_lowering)
def test_can_lower_lowerable_effect(self):
@jax.jit
def f(x):
effect_p.bind(effect=foo_effect)
return x + 1.
f.lower(2.)
def test_cannot_lower_unlowerable_effect(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=baz_effect)
return x + 1.
f.lower(2.)
with self.assertRaisesRegex(ValueError, "Cannot lower jaxpr with effects"):
f.lower(2.)
def test_should_not_pass_tokens_into_unordered_effect(self):
@ -324,7 +344,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='bar')
effect_p.bind(effect=bar_effect)
return x + 1.
f.lower(2.)
@ -337,7 +357,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` needs to '
'set `tokens_out`'):
@ -346,13 +366,13 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_lowering_that_sets_wrong_tokens_should_cause_error(self):
def bad_effect_lowering(ctx, *, effect):
ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get('foo')))
ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get(foo_effect)))
return []
mlir.register_lowering(effect_p, bad_effect_lowering)
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` returns '
'incorrect set of output token.'):
@ -367,7 +387,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
@ -376,8 +396,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=foo2_effect)
return x + 1.
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
@ -388,7 +408,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
@ -397,8 +417,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=foo2_effect)
return x + 1.
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
@ -413,7 +433,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
@ -434,7 +454,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
def f(x):
effect_p.bind(effect='bar')
effect_p.bind(effect=bar_effect)
return x + 1.
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
@ -450,7 +470,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_lowered_jaxpr_without_ordered_effects_takes_no_dummy_inputs(self):
@jax.jit
def f(x):
effect_p.bind(effect='bar')
effect_p.bind(effect=bar_effect)
return x + 1.
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
@ -465,7 +485,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
@ -481,8 +501,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_dummy_inputs(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=foo2_effect)
return x + 1.
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
@ -500,14 +520,14 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_can_lower_and_run_jaxpr_with_ordered_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
self.assertEqual(f(2.), 3.)
def test_can_lower_and_run_jaxpr_with_unordered_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='bar')
effect_p.bind(effect=bar_effect)
return x + 1.
self.assertEqual(f(2.), 3.)
@ -517,7 +537,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
@jax.pmap
def f(x):
effect_p.bind(effect='bar')
effect_p.bind(effect=bar_effect)
return x + 1
with self.assertRaisesRegex(
NotImplementedError,
@ -530,48 +550,48 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
@jax.jit
@jax.pmap
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1
with self.assertRaisesRegex(
ValueError,
"Ordered effects not supported for map primitives: {'foo'}"):
"Ordered effects not supported for map primitives: {foo}"):
f(jnp.arange(jax.device_count()))
def test_runtime_tokens_should_update_after_running_effectful_function(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens)
f(2.)
prev_token = dispatch.runtime_tokens.tokens['foo']
prev_token = dispatch.runtime_tokens.tokens[foo_effect]
f(2.)
curr_token = dispatch.runtime_tokens.tokens['foo']
curr_token = dispatch.runtime_tokens.tokens[foo_effect]
self.assertIsNot(prev_token, curr_token)
def test_can_lower_multiple_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=foo2_effect)
return x + 1.
@jax.jit
def g(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x + 1.
self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
self.assertNotIn('foo2', dispatch.runtime_tokens.tokens)
self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens)
self.assertNotIn(foo2_effect, dispatch.runtime_tokens.tokens)
f(2.).block_until_ready()
foo_token = dispatch.runtime_tokens.tokens[foo_effect][0]
foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0]
f(2.)
foo_token = dispatch.runtime_tokens.tokens['foo'][0]
foo2_token = dispatch.runtime_tokens.tokens['foo'][0]
f(2.)
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
foo_token = dispatch.runtime_tokens.tokens['foo'][0]
foo2_token = dispatch.runtime_tokens.tokens['foo2'][0]
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0])
self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0])
foo_token = dispatch.runtime_tokens.tokens[foo_effect][0]
foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0]
g(2.)
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0])
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0])
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
class EffectOrderingTest(jtu.JaxTestCase):
@ -589,7 +609,7 @@ class EffectOrderingTest(jtu.JaxTestCase):
@jax.jit
def f(x):
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
f(2.)
jax.effects_barrier()
@ -618,11 +638,11 @@ class EffectOrderingTest(jtu.JaxTestCase):
# Expensive computation
x = x.dot(x)
x = jnp.log(x.sum())
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
@functools.partial(jax.jit, device=jax.devices()[1])
def g(x):
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
f(jnp.ones((500, 500)))
g(3.)
@ -641,12 +661,12 @@ class EffectOrderingTest(jtu.JaxTestCase):
raise unittest.SkipTest("Test requires >= 2 devices.")
tokens = []
def _noop(_):
tokens.append(dispatch.runtime_tokens.tokens['log'][0])
tokens.append(dispatch.runtime_tokens.tokens[log_effect][0])
return ()
@functools.partial(jax.jit, device=jax.devices()[0])
def f(x):
return callback_p.bind(x, callback=_noop, effect='log', out_avals=[])
return callback_p.bind(x, callback=_noop, effect=log_effect, out_avals=[])
t1 = threading.Thread(target=lambda: f(2.))
t2 = threading.Thread(target=lambda: f(3.))
@ -673,7 +693,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
def f(x):
# foo is lowerable and ordered
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x
with self.assertRaisesRegex(
ValueError, "Ordered effects not supported in `pmap`."):
@ -683,7 +703,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
def f(x):
# bar is lowerable and unordered
effect_p.bind(effect='bar')
effect_p.bind(effect=bar_effect)
return x
jax.pmap(f)(jnp.arange(jax.local_device_count()))
@ -705,7 +725,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
@jax.pmap
def f(x):
callback_p.bind(
x, callback=log_value, effect='unordered_log', out_avals=[])
x, callback=log_value, effect=unordered_log_effect, out_avals=[])
return x + 1
f(jnp.arange(2)).block_until_ready()
jax.effects_barrier()
@ -716,7 +736,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_effects_disallowed_in_cond(self):
def f1(x):
def true_fun(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x
def false_fun(x):
return x
@ -728,10 +748,10 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_allowed_effect_in_cond(self):
def f(x):
def true_fun(x):
effect_p.bind(effect='while')
effect_p.bind(effect=while_effect)
return x
def false_fun(x):
effect_p.bind(effect='while')
effect_p.bind(effect=while_effect)
return x
return lax.cond(x, true_fun, false_fun, x)
f(2)
@ -739,16 +759,16 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_allowed_effect_in_cond_jvp(self):
def f(x):
def true_fun(x):
effect_p.bind(effect='while')
effect_p.bind(effect=while_effect)
return x
def false_fun(x):
effect_p.bind(effect='while')
effect_p.bind(effect=while_effect)
return x
return lax.cond(True, true_fun, false_fun, x)
# test primal side gets effect
primal_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(f, x)[0])(2.)
self.assertEqual(primal_jaxpr.effects, {'while'})
self.assertEqual(primal_jaxpr.effects, {while_effect})
# and tangent side does not
_, f_lin = jax.linearize(f, 2.)
lin_jaxpr = f_lin.func.fun.args[0]
@ -763,8 +783,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
x, = primals
t, = tangents
# TODO(mattjj,sharadmv): don't require data dependence for jax.linearize!
# effect_p.bind(t, effect='while')
t, = effect_p.bind(t, effect='while') # data dep only on tangents
# effect_p.bind(t, effect=while_effect)
t, = effect_p.bind(t, effect=while_effect) # data dep only on tangents
return x, t
def f(x):
@ -780,15 +800,15 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
# and tangent side does
_, f_lin = jax.linearize(f, 2.)
lin_jaxpr = f_lin.func.fun.args[0]
self.assertEqual(lin_jaxpr.effects, {'while'})
self.assertEqual(lin_jaxpr.effects, {while_effect})
def test_allowed_ordered_effect_in_cond(self):
def f(x):
def true_fun(x):
effect_p.bind(effect='while1')
effect_p.bind(effect=while1_effect)
return x
def false_fun(x):
effect_p.bind(effect='while1')
effect_p.bind(effect=while1_effect)
return x
return lax.cond(x, true_fun, false_fun, x)
f(2)
@ -796,12 +816,12 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_multiple_allowed_ordered_effect_in_cond(self):
def f(x):
def true_fun(x):
effect_p.bind(effect='while1')
effect_p.bind(effect='while2')
effect_p.bind(effect=while1_effect)
effect_p.bind(effect=while2_effect)
return x
def false_fun(x):
effect_p.bind(effect='while1')
effect_p.bind(effect='while2')
effect_p.bind(effect=while1_effect)
effect_p.bind(effect=while2_effect)
return x
return lax.cond(x, true_fun, false_fun, x)
f(2)
@ -810,7 +830,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def true_fun(x):
return x
def false_fun(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x
return lax.cond(True, true_fun, false_fun, x)
@ -822,7 +842,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def cond_fun(x):
return False
def body_fun(x):
effect_p.bind(effect='while')
effect_p.bind(effect=while_effect)
return x
return lax.while_loop(cond_fun, body_fun, x)
f(2)
@ -830,7 +850,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_allowed_effect_in_cond_body(self):
def f(x):
def cond_fun(x):
effect_p.bind(effect='while')
effect_p.bind(effect=while_effect)
return False
def body_fun(x):
return x
@ -842,7 +862,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def cond_fun(x):
return False
def body_fun(x):
effect_p.bind(effect='while1')
effect_p.bind(effect=while1_effect)
return x
return lax.while_loop(cond_fun, body_fun, x)
f(2)
@ -852,8 +872,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def cond_fun(x):
return False
def body_fun(x):
effect_p.bind(effect='while1')
effect_p.bind(effect='while2')
effect_p.bind(effect=while1_effect)
effect_p.bind(effect=while2_effect)
return x
return lax.while_loop(cond_fun, body_fun, x)
f(2)
@ -861,7 +881,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_effects_disallowed_in_while(self):
def f1(x):
def cond_fun(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return False
def body_fun(x):
return x
@ -874,7 +894,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def cond_fun(x):
return False
def body_fun(x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return x
return lax.while_loop(cond_fun, body_fun, x)
@ -884,7 +904,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_allowed_effect_in_scan(self):
def f(x):
def body_fun(carry, x):
effect_p.bind(effect='while')
effect_p.bind(effect=while_effect)
return carry, x
return lax.scan(body_fun, x, jnp.arange(5))
f(2)
@ -892,7 +912,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_allowed_ordered_effect_in_scan(self):
def f(x):
def body_fun(carry, x):
effect_p.bind(effect='while1')
effect_p.bind(effect=while1_effect)
return carry, x
return lax.scan(body_fun, x, jnp.arange(5))
f(2)
@ -900,8 +920,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def test_multiple_allowed_ordered_effect_in_scan(self):
def f(x):
def body_fun(carry, x):
effect_p.bind(effect='while1')
effect_p.bind(effect='while2')
effect_p.bind(effect=while1_effect)
effect_p.bind(effect=while2_effect)
return carry, x
return lax.scan(body_fun, x, jnp.arange(5))
f(2)
@ -910,7 +930,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
def f(x):
def body(carry, x):
effect_p.bind(effect='foo')
effect_p.bind(effect=foo_effect)
return carry, x
return lax.scan(body, x, jnp.arange(4))

View File

@ -24,6 +24,7 @@ from jax import tree_util
from jax._src import core
from jax._src import debugging
from jax._src import dispatch
from jax._src import effects
from jax._src import sharding
from jax._src import test_util as jtu
from jax._src import util
@ -91,8 +92,7 @@ def callback(f, result_shape, *args, ordered: bool = False, **kwargs):
core.ShapedArray(s.shape, s.dtype) for s in flat_result_shapes
]
effect = (
debugging.DebugEffect.ORDERED_PRINT
if ordered else debugging.DebugEffect.PRINT)
debugging.ordered_debug_effect if ordered else debugging.debug_effect)
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
@ -111,7 +111,7 @@ def callback_lowering(ctx, *args, effect, callback, **params):
return tuple(
callback_p.impl(*flat_args, effect=effect, callback=callback, **params))
if effect in core.ordered_effects:
if effects.ordered_effects.contains(effect):
token = ctx.tokens_in.get(effect)[0]
result, token, keepalive = mlir.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)