mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14549 from sharadmv:dbidx-effects
PiperOrigin-RevId: 510608031
This commit is contained in:
commit
c0107cc836
@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
|
||||
FrozenSet)
|
||||
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
|
||||
Union)
|
||||
import types
|
||||
|
||||
import jax
|
||||
@ -25,6 +25,7 @@ from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import effects
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
@ -44,6 +45,7 @@ traceback_util.register_exclusion(__file__)
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
allowed_effects: effects.EffectTypeSet = effects.remat_allowed_effects
|
||||
|
||||
### Policies
|
||||
|
||||
@ -451,14 +453,11 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
|
||||
return out_primals, out_tangents
|
||||
ad.primitive_jvps[remat_p] = remat_jvp
|
||||
|
||||
remat_allowed_effects: Set[core.Effect] = set()
|
||||
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Infeed)
|
||||
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Outfeed)
|
||||
allowed_effects.add_type(lax_internal.InOutFeedEffect)
|
||||
|
||||
def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
assert not jaxpr.constvars
|
||||
disallowed_effects = {eff for eff in jaxpr.effects
|
||||
if eff not in remat_allowed_effects}
|
||||
disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
'Effects not supported in partial-eval of `checkpoint`/`remat`: '
|
||||
|
@ -43,6 +43,7 @@ from jax._src import callback as jcb
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import effects
|
||||
from jax._src import array
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
@ -1071,10 +1072,10 @@ def xla_computation(fun: Callable,
|
||||
else:
|
||||
out_parts_flat = tuple(flatten_axes(
|
||||
"xla_computation out_parts", out_tree(), out_parts))
|
||||
unordered_effects = [eff for eff in jaxpr.effects
|
||||
if eff not in core.ordered_effects]
|
||||
ordered_effects = [eff for eff in jaxpr.effects
|
||||
if eff in core.ordered_effects]
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(jaxpr.effects))
|
||||
ordered_effects = list(
|
||||
effects.ordered_effects.filter_in(jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
f"xla_computation_{fun_name}",
|
||||
core.ClosedJaxpr(jaxpr, consts),
|
||||
|
@ -24,6 +24,7 @@ from jax.interpreters import mlir
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import util
|
||||
from jax._src import dispatch
|
||||
from jax._src.interpreters import ad
|
||||
@ -158,17 +159,19 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
|
||||
io_callback_p = core.Primitive("io_callback")
|
||||
io_callback_p.multiple_results = True
|
||||
|
||||
class IOEffect:
|
||||
class IOEffect(effects.Effect):
|
||||
__str__ = lambda _: "IO"
|
||||
class OrderedIOEffect:
|
||||
|
||||
class OrderedIOEffect(effects.Effect):
|
||||
__str__ = lambda _: "OrderedIO"
|
||||
|
||||
_IOEffect = IOEffect()
|
||||
_OrderedIOEffect = OrderedIOEffect()
|
||||
mlir.lowerable_effects.add(_IOEffect)
|
||||
mlir.lowerable_effects.add(_OrderedIOEffect)
|
||||
core.control_flow_allowed_effects.add(_IOEffect)
|
||||
core.control_flow_allowed_effects.add(_OrderedIOEffect)
|
||||
core.ordered_effects.add(_OrderedIOEffect)
|
||||
effects.lowerable_effects.add_type(IOEffect)
|
||||
effects.lowerable_effects.add_type(OrderedIOEffect)
|
||||
effects.control_flow_allowed_effects.add_type(IOEffect)
|
||||
effects.control_flow_allowed_effects.add_type(OrderedIOEffect)
|
||||
effects.ordered_effects.add_type(OrderedIOEffect)
|
||||
|
||||
|
||||
def io_callback_impl(*args, result_avals, callback: Callable[..., Any],
|
||||
|
@ -36,13 +36,13 @@ from jax.tree_util import tree_unflatten
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import effects
|
||||
from jax._src import prng
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax import control_flow as cf
|
||||
from jax._src.sharding import GSPMDSharding
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
||||
@ -100,20 +100,18 @@ class JaxException(Exception):
|
||||
|
||||
@functools.total_ordering
|
||||
@dataclasses.dataclass(eq=True, frozen=True)
|
||||
class ErrorEffect:
|
||||
class ErrorEffect(effects.Effect):
|
||||
error_type: Type[JaxException]
|
||||
shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]
|
||||
|
||||
def __post_init__(self):
|
||||
cf.allowed_effects.add(self)
|
||||
mlir.lowerable_effects.add(self)
|
||||
|
||||
def __lt__(self, other: 'ErrorEffect'):
|
||||
shape_dtypes = lambda x: tuple((sd.shape, str(sd.dtype)) # dtype is not comparable
|
||||
for sd in x.shape_dtypes)
|
||||
unpack = lambda x: (str(x.error_type), shape_dtypes(x))
|
||||
return (unpack(self) < unpack(other))
|
||||
|
||||
effects.control_flow_allowed_effects.add_type(ErrorEffect)
|
||||
effects.lowerable_effects.add_type(ErrorEffect)
|
||||
|
||||
class DivisionByZeroError(JaxException):
|
||||
|
||||
|
@ -37,6 +37,7 @@ import numpy as np
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import effects
|
||||
from jax._src.config import FLAGS, config
|
||||
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
TracerIntegerConversionError, UnexpectedTracerError)
|
||||
@ -59,12 +60,10 @@ map, unsafe_map = safe_map, map
|
||||
|
||||
# -------------------- jaxprs --------------------
|
||||
|
||||
Effect = Hashable
|
||||
Effects = Set[Effect]
|
||||
no_effects: Effects = set()
|
||||
ordered_effects: Set[Effect] = set()
|
||||
control_flow_allowed_effects: Set[Effect] = set()
|
||||
|
||||
Effect = effects.Effect
|
||||
Effects = effects.Effects
|
||||
EffectTypeSet = effects.EffectTypeSet
|
||||
no_effects: Effects = effects.no_effects
|
||||
|
||||
class Jaxpr:
|
||||
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects']
|
||||
@ -2829,7 +2828,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
|
||||
if "call_jaxpr" not in params:
|
||||
raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
|
||||
call_jaxpr = params["call_jaxpr"]
|
||||
ordered_effects_ = call_jaxpr.effects & ordered_effects
|
||||
ordered_effects_ = effects.ordered_effects.filter_in(call_jaxpr.effects)
|
||||
if ordered_effects_:
|
||||
raise JaxprTypeError(
|
||||
f"Map primitive {prim} mapping ordered effects: {ordered_effects_}")
|
||||
|
@ -14,8 +14,7 @@
|
||||
|
||||
from functools import update_wrapper, reduce, partial
|
||||
import inspect
|
||||
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
|
||||
Any)
|
||||
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Any)
|
||||
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
@ -29,6 +28,7 @@ from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import traceback_util
|
||||
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
|
||||
@ -47,6 +47,8 @@ traceback_util.register_exclusion(__file__)
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
allowed_effects: effects.EffectTypeSet = (
|
||||
effects.custom_derivatives_allowed_effects)
|
||||
|
||||
### util
|
||||
|
||||
@ -380,18 +382,14 @@ def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
|
||||
yield outs, tuple(todo) # Ensure the aux output is immutable
|
||||
|
||||
|
||||
allowed_effects: Set[core.Effect] = set()
|
||||
allowed_effects.add(lax.InOutFeedEffect.Infeed)
|
||||
allowed_effects.add(lax.InOutFeedEffect.Outfeed)
|
||||
|
||||
allowed_effects.add_type(lax.InOutFeedEffect)
|
||||
|
||||
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
|
||||
|
||||
def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
|
||||
# TODO(mattjj): could do more checking here...
|
||||
del in_avals, jvp_jaxpr_thunk, num_consts
|
||||
disallowed_effects = {eff for eff in call_jaxpr.effects if eff not in
|
||||
allowed_effects}
|
||||
disallowed_effects = allowed_effects.filter_not_in(call_jaxpr.effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `custom_jvp`: {disallowed_effects}')
|
||||
@ -714,8 +712,7 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
|
||||
return core.jaxpr_as_fun(fun_jaxpr)(*args)
|
||||
|
||||
def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
|
||||
disallowed_effects = {eff for eff in fun_jaxpr.effects if eff not in
|
||||
allowed_effects}
|
||||
disallowed_effects = allowed_effects.filter_not_in(fun_jaxpr.effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `custom_vjp`: {disallowed_effects}')
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
"""Module for JAX debugging primitives and related functionality."""
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import string
|
||||
import sys
|
||||
@ -25,24 +24,21 @@ import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import control_flow as lcf
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import Sharding, GSPMDSharding, NamedSharding
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.config import config
|
||||
|
||||
# pytype: disable=import-error
|
||||
try:
|
||||
@ -58,17 +54,23 @@ except:
|
||||
RICH_ENABLED = False
|
||||
# pytype: enable=import-error
|
||||
|
||||
DebugEffect = enum.Enum('DebugEffect', ['PRINT', 'ORDERED_PRINT'])
|
||||
class DebugEffect(effects.Effect):
|
||||
pass
|
||||
debug_effect = DebugEffect()
|
||||
|
||||
core.ordered_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
mlir.lowerable_effects.add(DebugEffect.PRINT)
|
||||
mlir.lowerable_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
lcf.allowed_effects.add(DebugEffect.PRINT)
|
||||
lcf.allowed_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
ad_checkpoint.remat_allowed_effects.add(DebugEffect.PRINT)
|
||||
ad_checkpoint.remat_allowed_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
custom_derivatives.allowed_effects.add(DebugEffect.PRINT)
|
||||
custom_derivatives.allowed_effects.add(DebugEffect.ORDERED_PRINT)
|
||||
class OrderedDebugEffect(effects.Effect):
|
||||
pass
|
||||
ordered_debug_effect = OrderedDebugEffect()
|
||||
|
||||
effects.ordered_effects.add_type(OrderedDebugEffect)
|
||||
effects.lowerable_effects.add_type(DebugEffect)
|
||||
effects.lowerable_effects.add_type(OrderedDebugEffect)
|
||||
effects.control_flow_allowed_effects.add_type(DebugEffect)
|
||||
effects.control_flow_allowed_effects.add_type(OrderedDebugEffect)
|
||||
effects.remat_allowed_effects.add_type(DebugEffect)
|
||||
effects.remat_allowed_effects.add_type(OrderedDebugEffect)
|
||||
effects.custom_derivatives_allowed_effects.add_type(DebugEffect)
|
||||
effects.custom_derivatives_allowed_effects.add_type(OrderedDebugEffect)
|
||||
|
||||
# `debug_callback_p` is the main primitive for staging out Python callbacks.
|
||||
debug_callback_p = core.Primitive('debug_callback')
|
||||
@ -133,7 +135,7 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
|
||||
return tuple(
|
||||
debug_callback_p.impl(
|
||||
*flat_args, effect=effect, callback=callback, **params))
|
||||
if effect in core.ordered_effects:
|
||||
if effects.ordered_effects.contains(effect):
|
||||
token = ctx.tokens_in.get(effect)[0]
|
||||
result, token, keepalive = mlir.emit_python_callback(
|
||||
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
|
||||
@ -209,7 +211,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
|
||||
The value of `callback(*args, **kwargs)`.
|
||||
"""
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT
|
||||
effect = ordered_debug_effect if ordered else debug_effect
|
||||
def _flat_callback(*flat_args):
|
||||
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
|
||||
callback(*args, **kwargs)
|
||||
@ -273,7 +275,7 @@ inspect_sharding_p.def_impl(_inspect_sharding_impl)
|
||||
def _inspect_sharding_abstract_eval(aval, **_):
|
||||
del aval
|
||||
# Effectful abstract avoids DCE
|
||||
return [], {DebugEffect.PRINT}
|
||||
return [], {debug_effect}
|
||||
inspect_sharding_p.def_effectful_abstract_eval(_inspect_sharding_abstract_eval)
|
||||
|
||||
def _inspect_sharding_batching_rule(args, _, *, callback):
|
||||
|
@ -42,6 +42,7 @@ from jax._src import array
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import path
|
||||
from jax._src import profiler
|
||||
@ -507,10 +508,10 @@ def lower_xla_callable(
|
||||
if type(d) is pe.InDBIdx else d for d in a.shape))
|
||||
if type(a) is core.DShapedArray else a, b) for a, b in out_type]
|
||||
module_name = f"jit_{fun.__name__}"
|
||||
unordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff not in core.ordered_effects]
|
||||
ordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff in core.ordered_effects]
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
ordered_effects = list(
|
||||
effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, unordered_effects,
|
||||
ordered_effects, backend, backend.platform,
|
||||
|
46
jax/_src/effects.py
Normal file
46
jax/_src/effects.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Type, Set
|
||||
|
||||
|
||||
class Effect:
  """Base class for objects that represent a side effect.

  The class defines no state and no methods of its own, so instances use the
  default identity-based equality and hashing inherited from `object`.
  """


# Type alias for a collection of effects, expressed as a set of `Effect`s.
Effects = Set[Effect]
|
||||
|
||||
class EffectTypeSet:
  """A registry of effect *types* that answers membership for effect instances.

  Types are registered with `add_type`. An effect instance is a member when it
  is an instance of at least one registered type; because the check is done
  with `isinstance`, instances of subclasses of a registered type are members
  too.
  """

  def __init__(self):
    # Registered effect classes; membership of instances is derived from these.
    self._effect_types: Set[Type[Effect]] = set()

  def add_type(self, effect_type: Type[Effect]):
    """Registers `effect_type`. Registering the same type twice is a no-op."""
    self._effect_types.add(effect_type)

  def contains(self, eff: Effect) -> bool:
    """Returns True iff `eff` is an instance of any registered effect type.

    An empty registry contains nothing.
    """
    return any(isinstance(eff, eff_type) for eff_type in self._effect_types)

  def __contains__(self, eff: object) -> bool:
    """Supports `eff in type_set`, with the same meaning as `contains(eff)`."""
    # Delegates (rather than aliasing) so subclasses overriding `contains`
    # automatically change the behavior of the `in` operator as well.
    return self.contains(eff)

  def filter_in(self, effects: Effects) -> Effects:
    """Returns the subset of `effects` whose members are in this type set."""
    return {eff for eff in effects if self.contains(eff)}

  def filter_not_in(self, effects: Effects) -> Effects:
    """Returns the subset of `effects` whose members are NOT in this type set."""
    return {eff for eff in effects if not self.contains(eff)}
|
||||
|
||||
# The empty collection of effects.
# NOTE(review): this is one shared, mutable `set` object; callers should
# presumably treat it as read-only -- confirm no caller mutates it.
no_effects: Effects = set()

# Module-level registries of effect types. Each starts empty; other modules
# populate them by calling `add_type` with their own effect classes.
ordered_effects: EffectTypeSet = EffectTypeSet()
lowerable_effects: EffectTypeSet = EffectTypeSet()
control_flow_allowed_effects: EffectTypeSet = EffectTypeSet()
custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet()
remat_allowed_effects: EffectTypeSet = EffectTypeSet()
|
@ -39,6 +39,7 @@ from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects as effects_lib
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
@ -58,7 +59,7 @@ Value = Any # = ir.Value
|
||||
# mypy implicitly sets this variable to true when type checking.
|
||||
MYPY = False
|
||||
|
||||
lowerable_effects: Set[core.Effect] = set()
|
||||
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
|
||||
|
||||
|
||||
# IR Helpers
|
||||
@ -695,7 +696,8 @@ def lower_jaxpr_to_module(
|
||||
if platform in _platforms_with_donation:
|
||||
input_output_aliases, donated_args = _set_up_aliases(
|
||||
in_avals, out_avals, donated_args)
|
||||
if any(eff not in lowerable_effects for eff in jaxpr.effects):
|
||||
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
|
||||
if unlowerable_effects:
|
||||
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
|
||||
if any(donated_args):
|
||||
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
|
||||
@ -731,8 +733,7 @@ def lower_jaxpr_to_module(
|
||||
module_name = _module_name_regex.sub("_", module_name)
|
||||
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(
|
||||
module_name)
|
||||
unlowerable_effects = {eff for eff in jaxpr.effects
|
||||
if eff not in lowerable_effects}
|
||||
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
|
||||
if unlowerable_effects:
|
||||
raise ValueError(
|
||||
f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}')
|
||||
@ -1140,7 +1141,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
f"found for platform {ctx.platform}")
|
||||
|
||||
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
|
||||
effects = [eff for eff in eqn.effects if eff in core.ordered_effects]
|
||||
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
|
||||
tokens_in = tokens.subset(effects)
|
||||
avals_in = map(aval, eqn.invars)
|
||||
rule_ctx = LoweringRuleContext(
|
||||
|
@ -57,6 +57,7 @@ from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import profiler
|
||||
from jax._src import sharding as sharding_internal
|
||||
@ -1487,12 +1488,12 @@ def lower_parallel_callable(
|
||||
backend.platform)
|
||||
module_name = f"pmap_{fun.__name__}"
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
|
||||
ordered_effects = list(
|
||||
effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
if ordered_effects:
|
||||
raise ValueError("Ordered effects not supported in `pmap`.")
|
||||
unordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff not in core.ordered_effects]
|
||||
ordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff in core.ordered_effects]
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -3007,12 +3008,12 @@ def lower_sharding_computation(
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
|
||||
if len(device_assignment) > 1:
|
||||
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in closed_jaxpr.effects):
|
||||
raise ValueError("Ordered effects are not supported for more than 1 device.")
|
||||
unordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff not in core.ordered_effects]
|
||||
ordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff in core.ordered_effects]
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
@ -3180,12 +3181,13 @@ def lower_mesh_computation(
|
||||
module: Union[str, xc.XlaComputation]
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in closed_jaxpr.effects):
|
||||
raise ValueError("Ordered effects not supported in mesh computations.")
|
||||
unordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff not in core.ordered_effects]
|
||||
ordered_effects = [eff for eff in closed_jaxpr.effects
|
||||
if eff in core.ordered_effects]
|
||||
unordered_effects = list(effects.ordered_effects.filter_not_in(
|
||||
closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(
|
||||
closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
|
@ -19,8 +19,8 @@ from typing import Callable, Optional, Sequence
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.core import control_flow_allowed_effects as allowed_effects
|
||||
from jax._src.lax import lax
|
||||
from jax._src.effects import control_flow_allowed_effects as allowed_effects
|
||||
from jax._src import ad_util
|
||||
from jax._src import util
|
||||
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
|
||||
@ -30,8 +30,7 @@ from jax.tree_util import tree_map, tree_unflatten
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
|
||||
allowed_effects.add(lax.InOutFeedEffect.Infeed)
|
||||
allowed_effects.add(lax.InOutFeedEffect.Outfeed)
|
||||
allowed_effects.add_type(lax.InOutFeedEffect)
|
||||
|
||||
|
||||
def _abstractify(x):
|
||||
|
@ -31,6 +31,7 @@ from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
@ -136,7 +137,7 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
out_trees[0], jaxprs[0].out_avals,
|
||||
out_tree, jaxpr.out_avals)
|
||||
joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
|
||||
disallowed_effects = joined_effects - allowed_effects
|
||||
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `switch`: {disallowed_effects}')
|
||||
@ -245,7 +246,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
out_tree, true_jaxpr.out_avals,
|
||||
false_out_tree, false_jaxpr.out_avals)
|
||||
joined_effects = core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
|
||||
disallowed_effects = joined_effects - allowed_effects
|
||||
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `cond`: {disallowed_effects}')
|
||||
@ -304,7 +305,7 @@ def _cond_with_per_branch_args(pred,
|
||||
|
||||
def _cond_abstract_eval(*avals, branches, **_):
|
||||
joined_effects = core.join_effects(*(b.effects for b in branches))
|
||||
disallowed_effects = joined_effects - allowed_effects
|
||||
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `cond`: {disallowed_effects}')
|
||||
@ -751,7 +752,7 @@ def _cond_typecheck(*in_atoms, branches, linear):
|
||||
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
|
||||
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
|
||||
joined_effects = core.join_effects(*(b.effects for b in branches))
|
||||
disallowed_effects = joined_effects - allowed_effects
|
||||
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `cond`: {disallowed_effects}')
|
||||
@ -818,8 +819,7 @@ pe.dce_rules[cond_p] = _cond_dce_rule
|
||||
def _cond_lowering(ctx, index, *args, branches, linear):
|
||||
del linear # Unused.
|
||||
joined_effects = core.join_effects(*(branch.effects for branch in branches))
|
||||
ordered_effects = [eff for eff in joined_effects
|
||||
if eff in core.ordered_effects]
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(joined_effects))
|
||||
num_tokens = len(ordered_effects)
|
||||
tokens_in = ctx.tokens_in.subset(ordered_effects)
|
||||
output_token_types = [mlir.token_type() for _ in ordered_effects]
|
||||
|
@ -35,6 +35,7 @@ from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax
|
||||
@ -266,7 +267,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
# Extract the subtree and avals for the first element of the return tuple
|
||||
out_tree_children[0], carry_avals_out,
|
||||
init_tree, carry_avals)
|
||||
disallowed_effects = jaxpr.effects - allowed_effects
|
||||
disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `scan`: {disallowed_effects}')
|
||||
@ -1098,7 +1099,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
body_tree, body_jaxpr.out_avals,
|
||||
in_tree_children[0], init_avals)
|
||||
effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
|
||||
disallowed_effects = effects - allowed_effects
|
||||
disallowed_effects = allowed_effects.filter_not_in(effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `while`: {disallowed_effects}')
|
||||
@ -1110,7 +1111,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
|
||||
del args, kwargs
|
||||
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
|
||||
disallowed_effects = joined_effects - allowed_effects
|
||||
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `while`: {disallowed_effects}')
|
||||
@ -1422,8 +1423,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
body_nconsts):
|
||||
pred_aval = cond_jaxpr.out_avals[0]
|
||||
batched = bool(pred_aval.shape)
|
||||
cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
|
||||
core.ordered_effects]
|
||||
cond_ordered_effects = effects.ordered_effects.filter_in(cond_jaxpr.effects)
|
||||
if cond_ordered_effects:
|
||||
def cond(args):
|
||||
# Pred can be batched
|
||||
@ -1448,8 +1448,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
return mlir.lower_fun(fun)(ctx, *args)
|
||||
|
||||
loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
|
||||
body_effects = [eff for eff in body_jaxpr.effects
|
||||
if eff in core.ordered_effects]
|
||||
body_effects = effects.ordered_effects.filter_in(body_jaxpr.effects)
|
||||
num_tokens = len(body_effects)
|
||||
tokens = [ctx.tokens_in.get(eff) for eff in body_effects]
|
||||
token_types = [mlir.token_type() for _ in tokens]
|
||||
@ -1533,9 +1532,10 @@ def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
body_nconsts):
|
||||
# TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
|
||||
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
|
||||
if joined_effects - allowed_effects:
|
||||
disallowed_effects = allowed_effects.filter_not_in(joined_effects)
|
||||
if disallowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `while`: {joined_effects - allowed_effects}')
|
||||
f'Effects not supported in `while`: {disallowed_effects}')
|
||||
return body_jaxpr.out_avals, joined_effects
|
||||
|
||||
while_p = core.AxisPrimitive('while')
|
||||
|
@ -39,6 +39,7 @@ from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
@ -4162,7 +4163,10 @@ def _after_all_lowering(ctx, *operands):
|
||||
mlir.register_lowering(after_all_p, _after_all_lowering)
|
||||
|
||||
|
||||
InOutFeedEffect = enum.Enum('InOutFeedEffect', ['Infeed', 'Outfeed'])
|
||||
class InOutFeedEffect(effects.Effect):
|
||||
pass
|
||||
infeed_effect = InOutFeedEffect()
|
||||
outfeed_effect = InOutFeedEffect()
|
||||
|
||||
|
||||
def infeed(token, shape=None, partitions=None):
|
||||
@ -4190,14 +4194,14 @@ def infeed(token, shape=None, partitions=None):
|
||||
def _infeed_abstract_eval(token, *, shapes, partitions):
|
||||
if token is not abstract_token:
|
||||
raise TypeError("First argument to infeed must be a token")
|
||||
return (*shapes, abstract_token), {InOutFeedEffect.Infeed}
|
||||
return (*shapes, abstract_token), {infeed_effect}
|
||||
|
||||
|
||||
infeed_p = Primitive("infeed")
|
||||
infeed_p.multiple_results = True
|
||||
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
|
||||
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
|
||||
mlir.lowerable_effects.add(InOutFeedEffect.Infeed)
|
||||
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
||||
|
||||
|
||||
def _infeed_lowering(ctx, token, *, shapes, partitions):
|
||||
@ -4243,12 +4247,12 @@ def outfeed(token, xs, partitions = None):
|
||||
def _outfeed_abstract_eval(token, *xs, partitions):
|
||||
if token is not abstract_token:
|
||||
raise TypeError("First argument to outfeed must be a token")
|
||||
return abstract_token, {InOutFeedEffect.Outfeed}
|
||||
return abstract_token, {outfeed_effect}
|
||||
|
||||
outfeed_p = Primitive("outfeed")
|
||||
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
|
||||
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
|
||||
mlir.lowerable_effects.add(InOutFeedEffect.Outfeed)
|
||||
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
||||
|
||||
|
||||
def _outfeed_lowering(ctx, token, *xs, partitions):
|
||||
|
@ -25,6 +25,7 @@ from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax import stages
|
||||
from jax._src import dispatch
|
||||
from jax._src import effects
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
|
||||
treedef_tuple)
|
||||
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
|
||||
@ -1375,7 +1376,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(),
|
||||
const_nodes, *tiled_ins,
|
||||
@ -1442,7 +1444,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
|
||||
mlir.TokenSet(), const_nodes, *sharded_global_in_nodes,
|
||||
@ -1494,7 +1497,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')),
|
||||
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
|
||||
if any(eff in core.ordered_effects for eff in vectorized_jaxpr.effects):
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
|
||||
mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes),
|
||||
|
@ -1465,9 +1465,9 @@ pe.custom_staging_rules[pjit_p] = pjit_staging_rule
|
||||
|
||||
def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
|
||||
out_positional_semantics, **_):
|
||||
disallowed_effects = jaxpr.effects - mlir.lowerable_effects
|
||||
disallowed_effects = mlir.lowerable_effects.filter_not_in(jaxpr.effects)
|
||||
if disallowed_effects:
|
||||
raise ValueError('Effects not supported in `pjit`.')
|
||||
raise ValueError(f'Effects not supported in `pjit`: {disallowed_effects}.')
|
||||
if config.jax_array:
|
||||
return jaxpr.out_avals, jaxpr.effects
|
||||
return global_to_local(out_positional_semantics, jaxpr.out_avals,
|
||||
|
@ -17,9 +17,9 @@ from __future__ import annotations
|
||||
from typing import Any, List, Optional, Sequence, Set, Union
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import effects
|
||||
from jax._src.lib import xla_bridge, xla_client
|
||||
from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod
|
||||
from jax._src.lax.control_flow import common
|
||||
|
||||
xc = xla_client
|
||||
xb = xla_bridge
|
||||
@ -31,10 +31,9 @@ zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
Array = Any
|
||||
|
||||
class RefEffect:
|
||||
class RefEffect(effects.Effect):
|
||||
def __init__(self, ref_aval: ShapedArrayRef):
|
||||
self.ref_aval = ref_aval
|
||||
common.allowed_effects.add(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
@ -61,6 +60,8 @@ class AccumEffect(RefEffect):
|
||||
def __str__(self):
|
||||
return f"Accum<{self.ref_aval}>"
|
||||
|
||||
effects.control_flow_allowed_effects.add_type(RefEffect)
|
||||
|
||||
StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]
|
||||
|
||||
# ## `Ref`s
|
||||
|
@ -156,7 +156,6 @@ from jax._src.core import (
|
||||
np as np,
|
||||
opaque_dtypes as opaque_dtypes,
|
||||
operator as operator,
|
||||
ordered_effects as ordered_effects,
|
||||
outfeed_primitives as outfeed_primitives,
|
||||
partial as partial,
|
||||
partialmethod as partialmethod,
|
||||
|
@ -33,19 +33,20 @@ from jax import dlpack
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import ad_util
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src import core
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import custom_derivatives
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import effects
|
||||
from jax._src import util
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import xla_client
|
||||
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
@ -296,13 +297,16 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc
|
||||
return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
|
||||
|
||||
|
||||
# Mark the effectful instancess of call_tf
|
||||
CallTfEffect = enum.Enum('CallTfEffect', ['EFFECT'])
|
||||
# Mark the effectful instances of call_tf
|
||||
class CallTfEffect(effects.Effect):
|
||||
__str__ = lambda _: "CallTfEffect"
|
||||
|
||||
mlir.lowerable_effects.add(CallTfEffect.EFFECT)
|
||||
lax_control_flow.allowed_effects.add(CallTfEffect.EFFECT)
|
||||
ad_checkpoint.remat_allowed_effects.add(CallTfEffect.EFFECT)
|
||||
custom_derivatives.allowed_effects.add(CallTfEffect.EFFECT)
|
||||
call_tf_effect = CallTfEffect()
|
||||
|
||||
effects.lowerable_effects.add_type(CallTfEffect)
|
||||
effects.control_flow_allowed_effects.add_type(CallTfEffect)
|
||||
effects.remat_allowed_effects.add_type(CallTfEffect)
|
||||
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
|
||||
|
||||
|
||||
def _call_tf_abstract_eval(*_,
|
||||
@ -316,7 +320,7 @@ def _call_tf_abstract_eval(*_,
|
||||
|
||||
def is_fully_known_shape(s):
|
||||
return s.rank is not None and all([d is not None for d in s])
|
||||
effects = {CallTfEffect.EFFECT} if has_side_effects else set()
|
||||
effects = {call_tf_effect} if has_side_effects else set()
|
||||
if all([is_fully_known_shape(s)
|
||||
for s in concrete_function_flat_tf.output_shapes]):
|
||||
return (
|
||||
|
@ -30,6 +30,7 @@ from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import effects
|
||||
from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
from jax._src import source_info_util
|
||||
@ -1782,7 +1783,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
with core.new_sublevel():
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, reduced_in_avals, debug_info=debug_info_final(f, map_primitive.name))
|
||||
ordered_effects = jaxpr.effects & core.ordered_effects
|
||||
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
||||
if ordered_effects:
|
||||
raise ValueError("Ordered effects not supported for "
|
||||
f"map primitives: {ordered_effects}")
|
||||
|
@ -21,6 +21,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.interpreters import ad
|
||||
@ -32,7 +33,6 @@ from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.lax import control_flow as lcf
|
||||
from jax._src.lib import xla_bridge
|
||||
import numpy as np
|
||||
|
||||
@ -49,22 +49,34 @@ def effect_jvp_rule(primals, tangents, effect):
|
||||
return effect_p.bind(*primals, effect=effect), tangents
|
||||
ad.primitive_jvps[effect_p] = effect_jvp_rule
|
||||
|
||||
mlir.lowerable_effects.add('foo')
|
||||
mlir.lowerable_effects.add('foo2')
|
||||
mlir.lowerable_effects.add('bar')
|
||||
mlir.lowerable_effects.add('while')
|
||||
mlir.lowerable_effects.add('while1')
|
||||
mlir.lowerable_effects.add('while2')
|
||||
core.ordered_effects.add('foo')
|
||||
core.ordered_effects.add('foo2')
|
||||
core.ordered_effects.add('while1')
|
||||
core.ordered_effects.add('while2')
|
||||
class BasicEffect(effects.Effect):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
lcf.allowed_effects.add('while')
|
||||
lcf.allowed_effects.add('while1')
|
||||
lcf.allowed_effects.add('while2')
|
||||
__repr__ = lambda self: self.name
|
||||
|
||||
ad_checkpoint.remat_allowed_effects.add('remat')
|
||||
class OrderedEffect(BasicEffect): pass
|
||||
class UnlowerableEffect(effects.Effect): pass
|
||||
class WhileEffect(effects.Effect): pass
|
||||
class RematEffect(effects.Effect): pass
|
||||
|
||||
foo_effect = OrderedEffect("foo")
|
||||
foo2_effect = OrderedEffect("foo2")
|
||||
bar_effect = BasicEffect("bar")
|
||||
baz_effect = UnlowerableEffect()
|
||||
while_effect = WhileEffect()
|
||||
while1_effect = WhileEffect()
|
||||
while2_effect = WhileEffect()
|
||||
log_effect = OrderedEffect("log")
|
||||
unordered_log_effect = BasicEffect("unordered_log")
|
||||
|
||||
effects.lowerable_effects.add_type(BasicEffect)
|
||||
effects.lowerable_effects.add_type(WhileEffect)
|
||||
effects.ordered_effects.add_type(OrderedEffect)
|
||||
effects.ordered_effects.add_type(WhileEffect)
|
||||
effects.control_flow_allowed_effects.add_type(WhileEffect)
|
||||
|
||||
effects.remat_allowed_effects.add_type(RematEffect)
|
||||
|
||||
|
||||
def trivial_effect_lowering(ctx, *, effect):
|
||||
@ -92,10 +104,6 @@ def function_effect_lowering(ctx, *, effect):
|
||||
callback_p = core.Primitive('callback')
|
||||
callback_p.multiple_results = True
|
||||
|
||||
mlir.lowerable_effects.add('log')
|
||||
mlir.lowerable_effects.add('unordered_log')
|
||||
core.ordered_effects.add('log')
|
||||
|
||||
@callback_p.def_impl
|
||||
def _(*args, callback, out_avals, effect):
|
||||
del out_avals, effect
|
||||
@ -110,7 +118,7 @@ def _(*avals, callback, out_avals, effect):
|
||||
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
|
||||
del out_avals
|
||||
token_in = None
|
||||
if effect in core.ordered_effects:
|
||||
if effects.ordered_effects.contains(effect):
|
||||
token_in = ctx.tokens_in.get(effect)[0]
|
||||
|
||||
out_op, token_out, keep_alive = mlir.emit_python_callback(
|
||||
@ -149,31 +157,31 @@ class JaxprEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_effectful_primitive_in_jaxpr_creates_effects(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects)
|
||||
self.assertEqual({'foo'}, jaxpr.effects)
|
||||
self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects)
|
||||
self.assertEqual({foo_effect}, jaxpr.effects)
|
||||
|
||||
def test_different_effects_in_jaxpr(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1.
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertEqual({'foo'}, jaxpr.jaxpr.eqns[0].effects)
|
||||
self.assertEqual({'bar'}, jaxpr.jaxpr.eqns[1].effects)
|
||||
self.assertEqual({'foo', 'bar'}, jaxpr.effects)
|
||||
self.assertEqual({foo_effect}, jaxpr.jaxpr.eqns[0].effects)
|
||||
self.assertEqual({bar_effect}, jaxpr.jaxpr.eqns[1].effects)
|
||||
self.assertEqual({foo_effect, bar_effect}, jaxpr.effects)
|
||||
|
||||
def test_jaxpr_typecheck_should_verify_eqn_effects_are_subset(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1.
|
||||
jaxpr = jax.make_jaxpr(f)(2.).jaxpr
|
||||
|
||||
# Edit jaxpr to make its type wrong
|
||||
jaxpr = jaxpr.replace(effects={'foo'})
|
||||
jaxpr = jaxpr.replace(effects={foo_effect})
|
||||
|
||||
with self.assertRaisesRegex(core.JaxprTypeError,
|
||||
'Equation effects are not subset of Jaxpr effects.'):
|
||||
@ -186,48 +194,52 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
@lu.wrap_init
|
||||
def f_(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return [x]
|
||||
return core.call(f_, x)[0]
|
||||
jax.make_jaxpr(f)(2.)
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
|
||||
self.assertIn(bar_effect, jaxpr.jaxpr.effects)
|
||||
|
||||
def test_xla_call_primitive_inherits_effects(self):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x
|
||||
jax.make_jaxpr(f)(2.)
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
|
||||
self.assertIn(bar_effect, jaxpr.jaxpr.effects)
|
||||
|
||||
@jtu.sample_product(flavor=["old", "new"])
|
||||
def test_remat_call_primitive_inherits_effects(self, flavor):
|
||||
remat = jax.remat if flavor == "old" else ad_checkpoint.checkpoint
|
||||
def test_remat_call_primitive_inherits_effects(self):
|
||||
|
||||
@remat
|
||||
@jax.checkpoint
|
||||
def f(x):
|
||||
x, = effect_p.bind(x, effect='foo')
|
||||
x, = effect_p.bind(x, effect='bar')
|
||||
x, = effect_p.bind(x, effect=foo_effect)
|
||||
x, = effect_p.bind(x, effect=bar_effect)
|
||||
return x
|
||||
jax.make_jaxpr(f)(2.)
|
||||
with self.assertRaisesRegex(NotImplementedError, "Effects not supported"):
|
||||
jax.make_jaxpr(lambda x: jax.linearize(f, x)[1](x))(2.)
|
||||
|
||||
def test_new_remat_allows_certain_effects(self):
|
||||
remat_effect = RematEffect()
|
||||
@ad_checkpoint.checkpoint
|
||||
def f(x):
|
||||
x, = effect_p.bind(x, effect='remat')
|
||||
x, = effect_p.bind(x, effect=remat_effect)
|
||||
return x
|
||||
jaxpr = jax.make_jaxpr(f)(2.)
|
||||
self.assertSetEqual(jaxpr.effects, {"remat"})
|
||||
self.assertSetEqual(jaxpr.effects, {remat_effect})
|
||||
|
||||
def test_custom_jvp_primitive_inherits_effects(self):
|
||||
|
||||
@jax.custom_jvp
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x
|
||||
f.defjvp(lambda x, t: (x, t))
|
||||
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
|
||||
@ -237,8 +249,8 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.custom_vjp
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x
|
||||
f.defvjp(
|
||||
fwd=lambda x: (x, ()),
|
||||
@ -250,27 +262,27 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Ordered effects not supported for map primitives: {'foo'}"):
|
||||
"Ordered effects not supported for map primitives: {.*}"):
|
||||
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
def test_xmap_inherits_effects(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x
|
||||
f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
|
||||
jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
|
||||
self.assertSetEqual(jaxpr.effects, {"foo", "bar"})
|
||||
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
|
||||
|
||||
def test_pjit_inherits_effects(self):
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x
|
||||
mesh = jax.sharding.Mesh(np.array(jax.devices()), ['x'])
|
||||
if config.jax_array:
|
||||
@ -280,7 +292,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
|
||||
with mesh:
|
||||
jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count()))
|
||||
self.assertSetEqual(jaxpr.effects, {"foo", "bar"})
|
||||
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
|
||||
|
||||
|
||||
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
@ -291,7 +303,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
config.update('jax_enable_x64', False)
|
||||
self._old_lowering = mlir._lowerings[effect_p]
|
||||
def _effect_lowering(ctx, *, effect):
|
||||
if effect in core.ordered_effects:
|
||||
if effects.ordered_effects.contains(effect):
|
||||
expected_effects = [effect]
|
||||
else:
|
||||
expected_effects = []
|
||||
@ -308,12 +320,20 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
config.update('jax_enable_x64', self.old_x64)
|
||||
mlir.register_lowering(effect_p, self._old_lowering)
|
||||
|
||||
def test_can_lower_lowerable_effect(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
f.lower(2.)
|
||||
|
||||
def test_cannot_lower_unlowerable_effect(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=baz_effect)
|
||||
return x + 1.
|
||||
f.lower(2.)
|
||||
with self.assertRaisesRegex(ValueError, "Cannot lower jaxpr with effects"):
|
||||
f.lower(2.)
|
||||
|
||||
def test_should_not_pass_tokens_into_unordered_effect(self):
|
||||
|
||||
@ -324,7 +344,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1.
|
||||
f.lower(2.)
|
||||
|
||||
@ -337,7 +357,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` needs to '
|
||||
'set `tokens_out`'):
|
||||
@ -346,13 +366,13 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def test_lowering_that_sets_wrong_tokens_should_cause_error(self):
|
||||
|
||||
def bad_effect_lowering(ctx, *, effect):
|
||||
ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get('foo')))
|
||||
ctx.set_tokens_out(mlir.TokenSet(bar=ctx.tokens_in.get(foo_effect)))
|
||||
return []
|
||||
mlir.register_lowering(effect_p, bad_effect_lowering)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
with self.assertRaisesRegex(ValueError, 'Lowering rule for `effect` returns '
|
||||
'incorrect set of output token.'):
|
||||
@ -367,7 +387,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
module = f.lower(2.).compiler_ir()
|
||||
main = module.body.operations[0]
|
||||
@ -376,8 +396,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='foo2')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=foo2_effect)
|
||||
return x + 1.
|
||||
module = f.lower(2.).compiler_ir()
|
||||
main = module.body.operations[0]
|
||||
@ -388,7 +408,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
module = f.lower(2.).compiler_ir()
|
||||
main = module.body.operations[0]
|
||||
@ -397,8 +417,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='foo2')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=foo2_effect)
|
||||
return x + 1.
|
||||
module = f.lower(2.).compiler_ir()
|
||||
main = module.body.operations[0]
|
||||
@ -413,7 +433,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
module = f.lower(2.).compiler_ir()
|
||||
main = module.body.operations[0]
|
||||
@ -434,7 +454,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1.
|
||||
module = f.lower(2.).compiler_ir()
|
||||
main = module.body.operations[0]
|
||||
@ -450,7 +470,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def test_lowered_jaxpr_without_ordered_effects_takes_no_dummy_inputs(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1.
|
||||
module = f.lower(1.).compiler_ir()
|
||||
input_types = module.body.operations[0].type.inputs
|
||||
@ -465,7 +485,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
module = f.lower(1.).compiler_ir()
|
||||
input_types = module.body.operations[0].type.inputs
|
||||
@ -481,8 +501,8 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_dummy_inputs(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='foo2')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=foo2_effect)
|
||||
return x + 1.
|
||||
module = f.lower(1.).compiler_ir()
|
||||
input_types = module.body.operations[0].type.inputs
|
||||
@ -500,14 +520,14 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def test_can_lower_and_run_jaxpr_with_ordered_effects(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
self.assertEqual(f(2.), 3.)
|
||||
|
||||
def test_can_lower_and_run_jaxpr_with_unordered_effects(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1.
|
||||
self.assertEqual(f(2.), 3.)
|
||||
|
||||
@ -517,7 +537,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
@ -530,48 +550,48 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Ordered effects not supported for map primitives: {'foo'}"):
|
||||
"Ordered effects not supported for map primitives: {foo}"):
|
||||
f(jnp.arange(jax.device_count()))
|
||||
|
||||
def test_runtime_tokens_should_update_after_running_effectful_function(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
|
||||
self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens)
|
||||
f(2.)
|
||||
prev_token = dispatch.runtime_tokens.tokens['foo']
|
||||
prev_token = dispatch.runtime_tokens.tokens[foo_effect]
|
||||
f(2.)
|
||||
curr_token = dispatch.runtime_tokens.tokens['foo']
|
||||
curr_token = dispatch.runtime_tokens.tokens[foo_effect]
|
||||
self.assertIsNot(prev_token, curr_token)
|
||||
|
||||
def test_can_lower_multiple_effects(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect='foo2')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
effect_p.bind(effect=foo2_effect)
|
||||
return x + 1.
|
||||
@jax.jit
|
||||
def g(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x + 1.
|
||||
self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
|
||||
self.assertNotIn('foo2', dispatch.runtime_tokens.tokens)
|
||||
self.assertNotIn(foo_effect, dispatch.runtime_tokens.tokens)
|
||||
self.assertNotIn(foo2_effect, dispatch.runtime_tokens.tokens)
|
||||
f(2.).block_until_ready()
|
||||
foo_token = dispatch.runtime_tokens.tokens[foo_effect][0]
|
||||
foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0]
|
||||
f(2.)
|
||||
foo_token = dispatch.runtime_tokens.tokens['foo'][0]
|
||||
foo2_token = dispatch.runtime_tokens.tokens['foo'][0]
|
||||
f(2.)
|
||||
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
|
||||
self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
|
||||
foo_token = dispatch.runtime_tokens.tokens['foo'][0]
|
||||
foo2_token = dispatch.runtime_tokens.tokens['foo2'][0]
|
||||
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0])
|
||||
self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0])
|
||||
foo_token = dispatch.runtime_tokens.tokens[foo_effect][0]
|
||||
foo2_token = dispatch.runtime_tokens.tokens[foo2_effect][0]
|
||||
g(2.)
|
||||
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
|
||||
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
|
||||
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens[foo_effect][0])
|
||||
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens[foo2_effect][0])
|
||||
|
||||
@jtu.pytest_mark_if_available('pjrt_c_api_unimplemented') # host callback
|
||||
class EffectOrderingTest(jtu.JaxTestCase):
|
||||
@ -589,7 +609,7 @@ class EffectOrderingTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
|
||||
return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
|
||||
|
||||
f(2.)
|
||||
jax.effects_barrier()
|
||||
@ -618,11 +638,11 @@ class EffectOrderingTest(jtu.JaxTestCase):
|
||||
# Expensive computation
|
||||
x = x.dot(x)
|
||||
x = jnp.log(x.sum())
|
||||
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
|
||||
return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
|
||||
|
||||
@functools.partial(jax.jit, device=jax.devices()[1])
|
||||
def g(x):
|
||||
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
|
||||
return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[])
|
||||
|
||||
f(jnp.ones((500, 500)))
|
||||
g(3.)
|
||||
@ -641,12 +661,12 @@ class EffectOrderingTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
tokens = []
|
||||
def _noop(_):
|
||||
tokens.append(dispatch.runtime_tokens.tokens['log'][0])
|
||||
tokens.append(dispatch.runtime_tokens.tokens[log_effect][0])
|
||||
return ()
|
||||
|
||||
@functools.partial(jax.jit, device=jax.devices()[0])
|
||||
def f(x):
|
||||
return callback_p.bind(x, callback=_noop, effect='log', out_avals=[])
|
||||
return callback_p.bind(x, callback=_noop, effect=log_effect, out_avals=[])
|
||||
|
||||
t1 = threading.Thread(target=lambda: f(2.))
|
||||
t2 = threading.Thread(target=lambda: f(3.))
|
||||
@ -673,7 +693,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def f(x):
|
||||
# foo is lowerable and ordered
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Ordered effects not supported in `pmap`."):
|
||||
@ -683,7 +703,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def f(x):
|
||||
# bar is lowerable and unordered
|
||||
effect_p.bind(effect='bar')
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x
|
||||
jax.pmap(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
@ -705,7 +725,7 @@ class ParallelEffectsTest(jtu.JaxTestCase):
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
callback_p.bind(
|
||||
x, callback=log_value, effect='unordered_log', out_avals=[])
|
||||
x, callback=log_value, effect=unordered_log_effect, out_avals=[])
|
||||
return x + 1
|
||||
f(jnp.arange(2)).block_until_ready()
|
||||
jax.effects_barrier()
|
||||
@ -716,7 +736,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_effects_disallowed_in_cond(self):
|
||||
def f1(x):
|
||||
def true_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x
|
||||
def false_fun(x):
|
||||
return x
|
||||
@ -728,10 +748,10 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_allowed_effect_in_cond(self):
|
||||
def f(x):
|
||||
def true_fun(x):
|
||||
effect_p.bind(effect='while')
|
||||
effect_p.bind(effect=while_effect)
|
||||
return x
|
||||
def false_fun(x):
|
||||
effect_p.bind(effect='while')
|
||||
effect_p.bind(effect=while_effect)
|
||||
return x
|
||||
return lax.cond(x, true_fun, false_fun, x)
|
||||
f(2)
|
||||
@ -739,16 +759,16 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_allowed_effect_in_cond_jvp(self):
|
||||
def f(x):
|
||||
def true_fun(x):
|
||||
effect_p.bind(effect='while')
|
||||
effect_p.bind(effect=while_effect)
|
||||
return x
|
||||
def false_fun(x):
|
||||
effect_p.bind(effect='while')
|
||||
effect_p.bind(effect=while_effect)
|
||||
return x
|
||||
return lax.cond(True, true_fun, false_fun, x)
|
||||
|
||||
# test primal side gets effect
|
||||
primal_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(f, x)[0])(2.)
|
||||
self.assertEqual(primal_jaxpr.effects, {'while'})
|
||||
self.assertEqual(primal_jaxpr.effects, {while_effect})
|
||||
# and tangent side does not
|
||||
_, f_lin = jax.linearize(f, 2.)
|
||||
lin_jaxpr = f_lin.func.fun.args[0]
|
||||
@ -763,8 +783,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
x, = primals
|
||||
t, = tangents
|
||||
# TODO(mattjj,sharadmv): don't require data dependence for jax.linearize!
|
||||
# effect_p.bind(t, effect='while')
|
||||
t, = effect_p.bind(t, effect='while') # data dep only on tangents
|
||||
# effect_p.bind(t, effect=while_effect)
|
||||
t, = effect_p.bind(t, effect=while_effect) # data dep only on tangents
|
||||
return x, t
|
||||
|
||||
def f(x):
|
||||
@ -780,15 +800,15 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
# and tangent side does
|
||||
_, f_lin = jax.linearize(f, 2.)
|
||||
lin_jaxpr = f_lin.func.fun.args[0]
|
||||
self.assertEqual(lin_jaxpr.effects, {'while'})
|
||||
self.assertEqual(lin_jaxpr.effects, {while_effect})
|
||||
|
||||
def test_allowed_ordered_effect_in_cond(self):
|
||||
def f(x):
|
||||
def true_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
return x
|
||||
def false_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
return x
|
||||
return lax.cond(x, true_fun, false_fun, x)
|
||||
f(2)
|
||||
@ -796,12 +816,12 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_multiple_allowed_ordered_effect_in_cond(self):
|
||||
def f(x):
|
||||
def true_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect='while2')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
effect_p.bind(effect=while2_effect)
|
||||
return x
|
||||
def false_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect='while2')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
effect_p.bind(effect=while2_effect)
|
||||
return x
|
||||
return lax.cond(x, true_fun, false_fun, x)
|
||||
f(2)
|
||||
@ -810,7 +830,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def true_fun(x):
|
||||
return x
|
||||
def false_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x
|
||||
return lax.cond(True, true_fun, false_fun, x)
|
||||
|
||||
@ -822,7 +842,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='while')
|
||||
effect_p.bind(effect=while_effect)
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
f(2)
|
||||
@ -830,7 +850,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_allowed_effect_in_cond_body(self):
|
||||
def f(x):
|
||||
def cond_fun(x):
|
||||
effect_p.bind(effect='while')
|
||||
effect_p.bind(effect=while_effect)
|
||||
return False
|
||||
def body_fun(x):
|
||||
return x
|
||||
@ -842,7 +862,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
f(2)
|
||||
@ -852,8 +872,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect='while2')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
effect_p.bind(effect=while2_effect)
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
f(2)
|
||||
@ -861,7 +881,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_effects_disallowed_in_while(self):
|
||||
def f1(x):
|
||||
def cond_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return False
|
||||
def body_fun(x):
|
||||
return x
|
||||
@ -874,7 +894,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def cond_fun(x):
|
||||
return False
|
||||
def body_fun(x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return x
|
||||
return lax.while_loop(cond_fun, body_fun, x)
|
||||
|
||||
@ -884,7 +904,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_allowed_effect_in_scan(self):
|
||||
def f(x):
|
||||
def body_fun(carry, x):
|
||||
effect_p.bind(effect='while')
|
||||
effect_p.bind(effect=while_effect)
|
||||
return carry, x
|
||||
return lax.scan(body_fun, x, jnp.arange(5))
|
||||
f(2)
|
||||
@ -892,7 +912,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_allowed_ordered_effect_in_scan(self):
|
||||
def f(x):
|
||||
def body_fun(carry, x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
return carry, x
|
||||
return lax.scan(body_fun, x, jnp.arange(5))
|
||||
f(2)
|
||||
@ -900,8 +920,8 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
def test_multiple_allowed_ordered_effect_in_scan(self):
|
||||
def f(x):
|
||||
def body_fun(carry, x):
|
||||
effect_p.bind(effect='while1')
|
||||
effect_p.bind(effect='while2')
|
||||
effect_p.bind(effect=while1_effect)
|
||||
effect_p.bind(effect=while2_effect)
|
||||
return carry, x
|
||||
return lax.scan(body_fun, x, jnp.arange(5))
|
||||
f(2)
|
||||
@ -910,7 +930,7 @@ class ControlFlowEffectsTest(jtu.JaxTestCase):
|
||||
|
||||
def f(x):
|
||||
def body(carry, x):
|
||||
effect_p.bind(effect='foo')
|
||||
effect_p.bind(effect=foo_effect)
|
||||
return carry, x
|
||||
return lax.scan(body, x, jnp.arange(4))
|
||||
|
||||
|
@ -24,6 +24,7 @@ from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src import debugging
|
||||
from jax._src import dispatch
|
||||
from jax._src import effects
|
||||
from jax._src import sharding
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
@ -91,8 +92,7 @@ def callback(f, result_shape, *args, ordered: bool = False, **kwargs):
|
||||
core.ShapedArray(s.shape, s.dtype) for s in flat_result_shapes
|
||||
]
|
||||
effect = (
|
||||
debugging.DebugEffect.ORDERED_PRINT
|
||||
if ordered else debugging.DebugEffect.PRINT)
|
||||
debugging.ordered_debug_effect if ordered else debugging.debug_effect)
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
def _flat_callback(*flat_args):
|
||||
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
|
||||
@ -111,7 +111,7 @@ def callback_lowering(ctx, *args, effect, callback, **params):
|
||||
return tuple(
|
||||
callback_p.impl(*flat_args, effect=effect, callback=callback, **params))
|
||||
|
||||
if effect in core.ordered_effects:
|
||||
if effects.ordered_effects.contains(effect):
|
||||
token = ctx.tokens_in.get(effect)[0]
|
||||
result, token, keepalive = mlir.emit_python_callback(
|
||||
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user