Add in runtime tokens for effectful jaxprs

Sharad Vikram 2022-04-14 14:18:31 -07:00
parent 37ea024d39
commit 8031eee7ee
6 changed files with 359 additions and 20 deletions

View File

@ -476,18 +476,21 @@ def _cpp_jit(
# outputs that could be tracers (if f is capturing `Tracer` by closure).
execute: Optional[functools.partial] = (
dispatch._xla_callable.most_recent_entry())
# TODO(sharadmv): Enable fast path for effectful jaxprs
use_fastpath = (
# If this code path has already been executed, the most-recent entry has been
# reset to None, so the fast path is not available.
execute is not None and
execute.func is dispatch._execute_compiled and # not trivial, not pmap
# No effects in computation
not execute.args[5] and
# Not supported: ShardedDeviceArray
all(device_array.type_is_device_array(x) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes)
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
_, xla_executable, _, _, result_handlers, kept_var_idx = execute.args
_, xla_executable, _, _, result_handlers, _, kept_var_idx = execute.args
sticky_device = None
avals = []
lazy_exprs = [None] * len(result_handlers)
@ -811,9 +814,11 @@ def xla_computation(fun: Callable,
else:
out_parts_flat = tuple(flatten_axes(
"xla_computation out_parts", out_tree(), out_parts))
effects = list(jaxpr.effects)
m = mlir.lower_jaxpr_to_module(
f"xla_computation_{fun_name}",
core.ClosedJaxpr(jaxpr, consts),
effects=effects,
platform=backend,
axis_context=mlir.ReplicaAxisContext(axis_env_),
name_stack=new_name_stack(wrap_name(fun_name, "xla_computation")),

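For context on the `effects` lists now threaded into `mlir.lower_jaxpr_to_module` above: they come from tracing, where an effectful primitive's abstract eval returns a set of effects that is recorded on the jaxpr (and which also disables the `_cpp_jit` fast path, via `execute.args[5]`). A minimal sketch of that flow, not part of this commit; `example_p` and the 'example' effect are hypothetical stand-ins mirroring the `effect_p` helper in the tests at the end of this diff:

import jax
from jax import core

example_p = core.Primitive('example')
example_p.multiple_results = True

@example_p.def_effectful_abstract_eval
def _(*avals, effect):
  # Returning (out_avals, {effect}) is what records the effect on the jaxpr.
  return [], {effect}

closed_jaxpr = jax.make_jaxpr(
    lambda x: [example_p.bind(effect='example'), x + 1.][1])(2.)
print(closed_jaxpr.effects)  # expected: {'example'}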
View File

@ -15,6 +15,7 @@
# Primitive dispatch and jit dispatch.
from __future__ import annotations
import atexit
import contextlib
from functools import partial
import itertools
@ -24,6 +25,7 @@ from typing import (
from typing_extensions import Protocol
import os
import re
import threading
import warnings
from absl import logging
@ -100,6 +102,37 @@ def apply_primitive(prim, *args, **params):
# TODO(phawkins): update code referring to xla.apply_primitive to point here.
xla.apply_primitive = apply_primitive
RuntimeToken = Any
class RuntimeTokenSet(threading.local):
tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
def __init__(self):
self.tokens = {}
def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
if eff not in self.tokens:
self.tokens[eff] = device_put(np.zeros(0, np.bool_), device), device
elif self.tokens[eff][1] != device:
(old_token,), _ = self.tokens[eff]
old_token.aval = core.ShapedArray((0,), np.bool_)
self.tokens[eff] = device_put(old_token, device), device
return self.tokens[eff][0]
def update_token(self, eff: core.Effect, token: RuntimeToken):
self.tokens[eff] = token, self.tokens[eff][1]
def clear(self):
self.tokens = {}
def block_until_ready(self):
[t[0].block_until_ready() for t, _ in self.tokens.values()]
runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
@atexit.register
def wait_for_tokens():
runtime_tokens.block_until_ready()
@util.cache()
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
@ -256,7 +289,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
if not jaxpr.eqns and not always_lower:
return XlaComputation(
name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
in_avals=abstract_args, out_avals=out_avals, effects=jaxpr.effects,
kept_var_idx=kept_var_idx)
if not _on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@ -291,13 +325,15 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module_name = f"jit_{fun.__name__}"
effects = [eff for eff in closed_jaxpr.effects if eff in core.ordered_effects]
module = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, backend.platform,
module_name, closed_jaxpr, effects, backend.platform,
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
return XlaComputation(
name, module, False, donated_invars, which_explicit, nreps=nreps,
device=device, backend=backend, tuple_args=tuple_args,
in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
in_avals=abstract_args, out_avals=out_avals, effects=effects,
kept_var_idx=kept_var_idx)
def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
@ -544,27 +580,48 @@ def _check_special(name, xla_shape, buf):
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
def _add_tokens(effects: List[core.Effect], device, input_bufs):
tokens = [runtime_tokens.get_token(eff, device) for eff in effects]
tokens_flat = flatten(tokens)
input_bufs = [*tokens_flat, *input_bufs]
def _remove_tokens(output_bufs):
token_bufs, output_bufs = util.split_list(output_bufs, [len(effects)])
for eff, token_buf in zip(effects, token_bufs):
runtime_tokens.update_token(eff, token_buf)
return output_bufs
return input_bufs, _remove_tokens
def _execute_compiled(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
result_handlers,
effects: List[core.Effect],
kept_var_idx, *args):
device, = compiled.local_devices()
args = input_handler(args) if input_handler else args
input_bufs_flat = flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
if effects:
input_bufs_flat, token_handler = _add_tokens(effects, device, input_bufs_flat)
out_bufs_flat = compiled.execute(input_bufs_flat)
check_special(name, out_bufs_flat)
if output_buffer_counts is None:
return (result_handlers[0](*out_bufs_flat),)
out_bufs = unflatten(out_bufs_flat, output_buffer_counts)
if effects:
out_bufs = token_handler(out_bufs)
return tuple(h(*bs) for h, bs in unsafe_zip(result_handlers, out_bufs))
def _execute_replicated(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
result_handlers,
effects: List[core.Effect],
kept_var_idx, *args):
if effects:
raise NotImplementedError('Cannot execute replicated computation with effects.')
if input_handler: raise NotImplementedError # TODO(mattjj, dougalm)
input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
@ -580,7 +637,7 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
kept_var_idx, *args):
_: List[core.Effect], kept_var_idx, *args):
env = {core.unitvar: core.unit}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
@ -721,6 +778,7 @@ class XlaCompiledComputation(stages.Executable):
tuple_args: bool,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
effects: List[core.Effect],
kept_var_idx: Set[int]) -> XlaCompiledComputation:
sticky_device = device
input_handler = _input_handler(backend, explicit_args, in_avals)
@ -735,9 +793,13 @@ class XlaCompiledComputation(stages.Executable):
compiled = compile_or_get_cached(backend, xla_computation, options)
buffer_counts = (None if len(out_avals) == 1 and not config.jax_dynamic_shapes
else [aval_to_num_buffers(aval) for aval in out_avals])
if effects:
if buffer_counts is None:
buffer_counts = [1]
buffer_counts = ([1] * len(effects)) + buffer_counts
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,
result_handlers, kept_var_idx)
result_handlers, effects, kept_var_idx)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
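The `buffer_counts` adjustment above just accounts for the extra leading token buffer the executable now returns for each ordered effect. A small worked example with hypothetical values:

# Two ordered effects and two real outputs needing 1 and 2 buffers each.
effects = ['foo', 'foo2']
buffer_counts = [1, 2]
buffer_counts = [1] * len(effects) + buffer_counts
assert buffer_counts == [1, 1, 1, 2]
# unflatten(out_bufs_flat, buffer_counts) then yields the token groups first,
# and _remove_tokens strips them before the result handlers run.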
def is_trivial(self):
@ -751,11 +813,11 @@ class XlaCompiledComputation(stages.Executable):
return self._xla_executable
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals, effects,
kept_var_idx) -> XlaCompiledComputation:
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, kept_var_idx)
out_avals, result_handlers, effects, kept_var_idx)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)
# -- stages.Executable protocol

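Taken together, the execution-side changes in this file thread tokens around each dispatch roughly as follows. This is a schematic sketch, illustrative only, where `get_token`, `update_token`, and `run_executable` stand in for the `runtime_tokens` methods and `compiled.execute` above:

from typing import Callable, List, Sequence

def execute_with_tokens(effects: Sequence[str],
                        get_token: Callable[[str], object],
                        update_token: Callable[[str, object], None],
                        run_executable: Callable[[List[object]], List[object]],
                        input_bufs: List[object]) -> List[object]:
  # Prepend one runtime token per ordered effect (mirrors _add_tokens).
  tokens = [get_token(eff) for eff in effects]
  out_bufs = run_executable([*tokens, *input_bufs])
  # Peel the refreshed tokens off the front of the outputs and store them
  # back (mirrors the _remove_tokens closure).
  new_tokens, out_bufs = out_bufs[:len(effects)], out_bufs[len(effects):]
  for eff, tok in zip(effects, new_tokens):
    update_token(eff, tok)
  return out_bufs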
View File

@ -473,8 +473,11 @@ def sharded_aval(aval: core.ShapedArray,
sharded_shape.append((aval.shape[i] + partitions - 1) // partitions)
return aval.update(tuple(sharded_shape))
def lower_jaxpr_to_module(
module_name: str, jaxpr: core.ClosedJaxpr, platform: str,
module_name: str, jaxpr: core.ClosedJaxpr,
effects: List[core.Effect],
platform: str,
axis_context: AxisContext,
name_stack: NameStack, donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
@ -504,6 +507,8 @@ def lower_jaxpr_to_module(
if platform in platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args)
if any(eff not in lowerable_effects for eff in jaxpr.effects):
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
if any(donated_args):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(a) for a, d in zip(in_avals, donated_args)
@ -528,7 +533,6 @@ def lower_jaxpr_to_module(
if unlowerable_effects:
raise ValueError(
f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}')
effects = [eff for eff in jaxpr.effects if eff in core.ordered_effects]
lower_jaxpr_to_fun(
ctx, "main", jaxpr, effects, public=True, create_tokens=True,
replace_units_with_dummy=True,
@ -623,6 +627,12 @@ class TokenSet:
new_tokens.append(self._tokens[eff])
return TokenSet(zip(self.effects(), new_tokens))
def dummy_token_type() -> Sequence[ir.Type]:
return aval_to_ir_types(core.ShapedArray((0,), np.bool_))
def dummy_token() -> Sequence[ir.Value]:
return ir_constants(np.zeros(0, np.bool_))
def lower_jaxpr_to_fun(
ctx: ModuleContext,
name: str,
@ -650,6 +660,7 @@ def lower_jaxpr_to_fun(
jaxpr: the jaxpr to lower.
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
that will be created in or used by the lowered function.
create_tokens: if true, the lowered MHLO function will create its own tokens and ignore any dummy input tokens.
public: if true, the function's visibility is set to "public".
replace_units_with_dummy: if true, unit arguments/return values are
replaced with bool arrays of size [0].
@ -676,10 +687,10 @@ def lower_jaxpr_to_fun(
input_types = map(aval_to_types, jaxpr.in_avals)
output_types = map(aval_to_types, jaxpr.out_avals)
num_tokens = len(effects)
if create_tokens:
# If we create the tokens they won't be inputs to the MLIR function.
num_tokens = 0
token_types = []
token_types = [dummy_token_type() for _ in effects]
else:
# If we aren't creating tokens they will be the initial inputs to the
# MLIR function.
@ -789,9 +800,8 @@ def lower_jaxpr_to_fun(
*args)
outs = []
if create_tokens:
# If we created the tokens in this function, we are done with them and can
# ignore `tokens_out`.
pass
for _ in effects:
outs.append(dummy_token())
else:
for token in tokens_out.tokens():
outs.append(token)

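End to end, the `create_tokens` path above means the public `main` function takes and returns one dummy `tensor<0xi1>` value per ordered effect instead of a real MHLO token, which is exactly what the dispatch-side token machinery feeds it. A hedged sketch of how that shows up, mirroring the test helpers later in this diff; `demo_p` and the 'demo' effect are hypothetical and assume a JAX tree containing this commit:

import jax
from jax import core
from jax.interpreters import mlir

demo_p = core.Primitive('demo')
demo_p.multiple_results = True

@demo_p.def_effectful_abstract_eval
def _(*avals, effect):
  return [], {effect}

def _demo_lowering(ctx, *, effect):
  ctx.set_tokens_out(ctx.tokens_in)  # pass the tokens through unchanged
  return []
mlir.register_lowering(demo_p, _demo_lowering)

mlir.lowerable_effects.add('demo')
core.ordered_effects.add('demo')

@jax.jit
def f(x):
  demo_p.bind(effect='demo')
  return x + 1.

print(f.lower(2.).compiler_ir(dialect='mhlo'))
# Expect main to carry a leading tensor<0xi1> argument and result for the
# 'demo' effect, alongside the real input and output.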
View File

@ -1053,8 +1053,9 @@ def lower_parallel_callable(
tuple_args = should_tuple_args(shards)
module_name = f"pmap_{fun.__name__}"
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
effects = list(closed_jaxpr.effects)
module = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
module_name, closed_jaxpr, effects, backend.platform, mlir.ReplicaAxisContext(axis_env),
name_stack, donated_invars, replicated_args=replicated_args,
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
@ -2236,8 +2237,9 @@ def lower_mesh_computation(
module: Union[str, xc.XlaComputation]
module_name = f"{api_name}_{fun_name}"
with core.extend_axis_env_nd(mesh.shape.items()):
effects = list(closed_jaxpr.effects)
module = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
module_name, closed_jaxpr, effects, backend.platform, axis_ctx, name_stack,
donated_invars, replicated_args=replicated_args,
arg_shardings=in_partitions, result_shardings=out_partitions)

View File

@ -140,9 +140,11 @@ def _sharded_callable(
fun.__name__, nparts, global_abstract_args)
axis_env = xla.AxisEnv(nrep, (), ())
effects = list(jaxpr.effects)
module = mlir.lower_jaxpr_to_module(
"spjit_{}".format(fun.__name__),
core.ClosedJaxpr(jaxpr, consts),
effects,
platform=platform,
axis_context=mlir.ReplicaAxisContext(axis_env),
name_stack=new_name_stack(wrap_name(name, "sharded_jit")),

View File

@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import threading
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -20,10 +23,12 @@ from jax import ad_checkpoint
from jax import core
from jax import lax
from jax import linear_util as lu
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.config import config
from jax.interpreters import mlir
from jax._src import lib as jaxlib
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
import numpy as np
@ -47,6 +52,7 @@ core.ordered_effects.add('foo2')
def trivial_effect_lowering(ctx, *, effect):
ctx.set_tokens_out(ctx.tokens_in)
return []
mlir.register_lowering(effect_p, trivial_effect_lowering)
def function_effect_lowering(ctx, *, effect):
def _f(ctx):
@ -65,6 +71,58 @@ def function_effect_lowering(ctx, *, effect):
ctx.set_tokens_out(mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)))
return out
callback_p = core.Primitive('callback')
callback_p.multiple_results = True
mlir.lowerable_effects.add('log')
core.ordered_effects.add('log')
@callback_p.def_impl
def _(*args, callback, out_avals, effect):
del out_avals, effect
callback(*args)
return []
@callback_p.def_effectful_abstract_eval
def _(*avals, callback, out_avals, effect):
del avals, callback
return out_avals, {effect}
# TODO(sharadmv): Attach keep alive to executable
leak = [].append
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
del out_avals
def _token_callback(token, *args):
out = callback(*args)
flat_out = jax.tree_util.tree_leaves(out)
return (token, *flat_out)
token_in = ctx.tokens_in.get(effect)[0]
(token_out, *out_op), keep_alive = mlir.emit_python_callback(
ctx.module_context.platform, _token_callback,
[token_in, *args], [core.abstract_token, *ctx.avals_in],
[core.abstract_token, *ctx.avals_out], True)
leak(keep_alive)
ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect:
token_out})))
return out_op
mlir.register_lowering(callback_p, callback_effect_lowering)
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
# This controls the number of CPU devices. On TPU we always have 2 devices.
prev_xla_flags = jtu.set_host_platform_device_count(2)
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
prev_xla_flags()
class JaxprEffectsTest(jtu.JaxTestCase):
@ -210,6 +268,18 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
return []
mlir.register_lowering(effect_p, _effect_lowering)
def setUp(self):
super().setUp()
self.old_x64 = config.jax_enable_x64
config.update('jax_enable_x64', False)
dispatch.runtime_tokens.clear()
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
config.update('jax_enable_x64', self.old_x64)
def test_cannot_lower_unlowerable_effect(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
@ -350,6 +420,194 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
self.assertLen(list(func.type.inputs), 0)
self.assertLen(list(func.type.results), 0)
def test_lowered_jaxpr_without_ordered_effects_takes_no_dummy_inputs(self):
@jax.jit
def f(x):
effect_p.bind(effect='bar')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
# The only input should be the original argument; no dummy token is added
self.assertLen(list(input_types), 1)
self.assertEqual(str(input_types[0]), 'tensor<f32>')
# The only result should be the original output; no dummy token is returned
result_types = mhlo.body.operations[0].type.results
self.assertLen(list(result_types), 1)
self.assertEqual(str(result_types[0]), 'tensor<f32>')
def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
# The first argument should be a dummy token
self.assertLen(list(input_types), 2)
self.assertEqual(str(input_types[0]), 'tensor<0xi1>')
# The first result should be a dummy token
result_types = mhlo.body.operations[0].type.results
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_dummy_inputs(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
# First two arguments should be dummy values
self.assertLen(list(input_types), 3)
self.assertEqual(str(input_types[0]), 'tensor<0xi1>')
self.assertEqual(str(input_types[1]), 'tensor<0xi1>')
# First two outputs should be dummy values
result_types = mhlo.body.operations[0].type.results
self.assertLen(list(result_types), 3)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
self.assertEqual(str(result_types[1]), 'tensor<0xi1>')
def test_can_lower_and_run_jaxpr_with_ordered_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
return x + 1.
self.assertEqual(f(2.), 3.)
def test_can_lower_and_run_jaxpr_with_unordered_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='bar')
return x + 1.
self.assertEqual(f(2.), 3.)
def test_runtime_tokens_should_update_after_running_effectful_function(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
return x + 1.
self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
f(2.)
prev_token = dispatch.runtime_tokens.tokens['foo']
f(2.)
curr_token = dispatch.runtime_tokens.tokens['foo']
self.assertIsNot(prev_token, curr_token)
def test_can_lower_multiple_effects(self):
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
@jax.jit
def g(x):
effect_p.bind(effect='foo')
return x + 1.
self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
self.assertNotIn('foo2', dispatch.runtime_tokens.tokens)
f(2.)
foo_token = dispatch.runtime_tokens.tokens['foo'][0]
foo2_token = dispatch.runtime_tokens.tokens['foo2'][0]
f(2.)
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
foo_token = dispatch.runtime_tokens.tokens['foo'][0]
foo2_token = dispatch.runtime_tokens.tokens['foo2'][0]
g(2.)
self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
self.assertIs(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
class EffectOrderingTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_can_execute_python_callback(self):
# TODO(sharadmv): remove jaxlib check when minimum version is bumped
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
log = []
def log_value(x):
log.append(x)
return ()
@jax.jit
def f(x):
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
f(2.)
self.assertListEqual(log, [2.])
f(3.)
self.assertListEqual(log, [2., 3.])
dispatch.runtime_tokens.block_until_ready()
@jtu.skip_on_devices("tpu", "gpu")
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
# TODO(sharadmv): remove jaxlib check when minimum version is bumped
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
log = []
def log_value(x):
log.append(x)
return ()
@functools.partial(jax.jit, device=jax.devices()[0])
def f(x):
# Expensive computation
x = x.dot(x)
x = jnp.log(x.sum())
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
@functools.partial(jax.jit, device=jax.devices()[1])
def g(x):
return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
f(jnp.ones((500, 500)))
g(3.)
f(jnp.ones((500, 500)))
g(3.)
f(jnp.ones((500, 500)))
g(3.)
dispatch.runtime_tokens.block_until_ready()
x_, y_ = float(jnp.log(1.25e8)), 3.
expected_log = [x_, y_, x_, y_, x_, y_]
self.assertListEqual(log, expected_log)
@jtu.skip_on_devices("tpu", "gpu")
def test_different_threads_get_different_tokens(self):
# TODO(sharadmv): remove jaxlib check when minimum version is bumped
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
tokens = []
def _noop(_):
tokens.append(dispatch.runtime_tokens.tokens['log'][0])
return ()
@functools.partial(jax.jit, device=jax.devices()[0])
def f(x):
return callback_p.bind(x, callback=_noop, effect='log', out_avals=[])
t1 = threading.Thread(target=lambda: f(2.))
t2 = threading.Thread(target=lambda: f(3.))
t1.start()
t2.start()
t1.join()
t2.join()
token1, token2 = tokens
self.assertIsNot(token1, token2)
class ControlFlowEffectsTest(jtu.JaxTestCase):