Add in runtime tokens for effectful jaxprs

parent 37ea024d39
commit 8031eee7ee
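
Summary (drawn from the diff below): jaxpr effects are now threaded through to runtime. A thread-local RuntimeTokenSet in jax._src.dispatch keeps one dummy token per ordered effect; _execute_compiled prepends these tokens to the input buffers and stores the refreshed tokens taken from the outputs, so effectful executions are serialized across calls, devices, and threads. The effects list is plumbed through lower_xla_callable, the _execute_* callables, and mlir.lower_jaxpr_to_module, and new tests cover ordering across devices and threads. The sketch below shows the intended observable behavior; effect_p and the ordered effect 'foo' are test-local names defined later in this diff, not public JAX API.

# Illustrative usage sketch only; `effect_p` and the ordered effect 'foo' come
# from the tests added in this commit, not from public JAX API.
import jax
from jax._src import dispatch

@jax.jit
def f(x):
  effect_p.bind(effect='foo')   # effectful primitive (test-local)
  return x + 1.

f(2.)
# After the call, dispatch holds a runtime token for 'foo'; the next effectful
# execution consumes and replaces it, which is what enforces ordering.
print(dispatch.runtime_tokens.tokens['foo'])
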
@@ -476,18 +476,21 @@ def _cpp_jit(
     # outputs that could be tracers (if f is capturing `Tracer` by closure).
     execute: Optional[functools.partial] = (
         dispatch._xla_callable.most_recent_entry())
+    # TODO(sharadmv): Enable fast path for effectful jaxprs
     use_fastpath = (
        # This is if we have already executed this code-path (most-recent entry
        # has been reset to None). Thus, we do not support the fast-path.
        execute is not None and
        execute.func is dispatch._execute_compiled and  # not trivial, not pmap
+       # No effects in computation
+       not execute.args[5] and
        # Not supported: ShardedDeviceArray
        all(device_array.type_is_device_array(x) for x in out_flat) and
        # Not supported: dynamic shapes
        not jax.config.jax_dynamic_shapes)
     ### If we can use the fastpath, we return required info to the caller.
     if use_fastpath:
-      _, xla_executable, _, _, result_handlers, kept_var_idx = execute.args
+      _, xla_executable, _, _, result_handlers, _, kept_var_idx = execute.args
       sticky_device = None
       avals = []
       lazy_exprs = [None] * len(result_handlers)
@@ -811,9 +814,11 @@ def xla_computation(fun: Callable,
     else:
       out_parts_flat = tuple(flatten_axes(
           "xla_computation out_parts", out_tree(), out_parts))
+    effects = list(jaxpr.effects)
     m = mlir.lower_jaxpr_to_module(
         f"xla_computation_{fun_name}",
         core.ClosedJaxpr(jaxpr, consts),
+        effects=effects,
         platform=backend,
         axis_context=mlir.ReplicaAxisContext(axis_env_),
         name_stack=new_name_stack(wrap_name(fun_name, "xla_computation")),
@@ -15,6 +15,7 @@
 # Primitive dispatch and jit dispatch.
 from __future__ import annotations

+import atexit
 import contextlib
 from functools import partial
 import itertools
@@ -24,6 +25,7 @@ from typing import (
 from typing_extensions import Protocol
 import os
 import re
+import threading
 import warnings

 from absl import logging
@@ -100,6 +102,37 @@ def apply_primitive(prim, *args, **params):
 # TODO(phawkins): update code referring to xla.apply_primitive to point here.
 xla.apply_primitive = apply_primitive

+RuntimeToken = Any
+
+class RuntimeTokenSet(threading.local):
+  tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
+
+  def __init__(self):
+    self.tokens = {}
+
+  def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
+    if eff not in self.tokens:
+      self.tokens[eff] = device_put(np.zeros(0, np.bool_), device), device
+    elif self.tokens[eff][1] != device:
+      (old_token,), _ = self.tokens[eff]
+      old_token.aval = core.ShapedArray((0,), np.bool_)
+      self.tokens[eff] = device_put(old_token, device), device
+    return self.tokens[eff][0]
+
+  def update_token(self, eff: core.Effect, token: RuntimeToken):
+    self.tokens[eff] = token, self.tokens[eff][1]
+
+  def clear(self):
+    self.tokens = {}
+
+  def block_until_ready(self):
+    [t[0].block_until_ready() for t, _ in self.tokens.values()]
+
+runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
+
+@atexit.register
+def wait_for_tokens():
+  runtime_tokens.block_until_ready()

 @util.cache()
 def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
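
The RuntimeTokenSet above keeps one dummy token per ordered effect per thread, and moves that token when the effect next executes on a different device. A minimal standalone model of just that caching policy follows (plain Python; the string "devices" and "buffers" are stand-ins invented for illustration, so nothing here is JAX API):

import threading

class TokenCache(threading.local):
  # Simplified model of the per-effect, per-device token cache added above.
  def __init__(self):
    self.tokens = {}                     # effect -> (token, device)

  def get_token(self, eff, device):
    if eff not in self.tokens:
      self.tokens[eff] = (f"token[{eff}]@{device}", device)    # fresh dummy token
    elif self.tokens[eff][1] != device:
      old, _ = self.tokens[eff]
      self.tokens[eff] = (f"{old}->moved@{device}", device)    # transfer to new device
    return self.tokens[eff][0]

cache = TokenCache()
print(cache.get_token('log', 'cpu:0'))   # creates a token on cpu:0
print(cache.get_token('log', 'cpu:1'))   # same effect, token carried over to cpu:1
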
@@ -256,7 +289,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   if not jaxpr.eqns and not always_lower:
     return XlaComputation(
         name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
-        in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
+        in_avals=abstract_args, out_avals=out_avals, effects=jaxpr.effects,
+        kept_var_idx=kept_var_idx)

   if not _on_exit:
     log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -291,13 +325,15 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
   module_name = f"jit_{fun.__name__}"
+  effects = [eff for eff in closed_jaxpr.effects if eff in core.ordered_effects]
   module = mlir.lower_jaxpr_to_module(
-      module_name, closed_jaxpr, backend.platform,
+      module_name, closed_jaxpr, effects, backend.platform,
       mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
   return XlaComputation(
       name, module, False, donated_invars, which_explicit, nreps=nreps,
       device=device, backend=backend, tuple_args=tuple_args,
-      in_avals=abstract_args, out_avals=out_avals, kept_var_idx=kept_var_idx)
+      in_avals=abstract_args, out_avals=out_avals, effects=effects,
+      kept_var_idx=kept_var_idx)


 def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
@@ -544,27 +580,48 @@ def _check_special(name, xla_shape, buf):
     if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
       raise FloatingPointError(f"invalid value (inf) encountered in {name}")

+def _add_tokens(effects: List[core.Effect], device, input_bufs):
+  tokens = [runtime_tokens.get_token(eff, device) for eff in effects]
+  tokens_flat = flatten(tokens)
+  input_bufs = [*tokens_flat, *input_bufs]
+  def _remove_tokens(output_bufs):
+    token_bufs, output_bufs = util.split_list(output_bufs, [len(effects)])
+    for eff, token_buf in zip(effects, token_bufs):
+      runtime_tokens.update_token(eff, token_buf)
+    return output_bufs
+  return input_bufs, _remove_tokens
+

 def _execute_compiled(name: str, compiled: XlaExecutable,
                       input_handler: Optional[Callable],
                       output_buffer_counts: Optional[Sequence[int]],
-                      result_handlers, kept_var_idx, *args):
+                      result_handlers,
+                      effects: List[core.Effect],
+                      kept_var_idx, *args):
   device, = compiled.local_devices()
   args = input_handler(args) if input_handler else args
   input_bufs_flat = flatten(device_put(x, device) for i, x in enumerate(args)
                             if i in kept_var_idx)
+  if effects:
+    input_bufs_flat, token_handler = _add_tokens(effects, device, input_bufs_flat)
   out_bufs_flat = compiled.execute(input_bufs_flat)
   check_special(name, out_bufs_flat)
   if output_buffer_counts is None:
     return (result_handlers[0](*out_bufs_flat),)
   out_bufs = unflatten(out_bufs_flat, output_buffer_counts)
+  if effects:
+    out_bufs = token_handler(out_bufs)
   return tuple(h(*bs) for h, bs in unsafe_zip(result_handlers, out_bufs))


 def _execute_replicated(name: str, compiled: XlaExecutable,
                         input_handler: Optional[Callable],
                         output_buffer_counts: Optional[Sequence[int]],
-                        result_handlers, kept_var_idx, *args):
+                        result_handlers,
+                        effects: List[core.Effect],
+                        kept_var_idx, *args):
+  if effects:
+    raise NotImplementedError('Cannot execute replicated computation with effects.')
   if input_handler: raise NotImplementedError  # TODO(mattjj, dougalm)
   input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
                         if i in kept_var_idx)
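
A minimal standalone model of the token threading that _add_tokens and _execute_compiled perform above (plain Python; fake_execute and the string "buffers" are invented for illustration): tokens for each ordered effect are prepended to the input buffers, and the matching leading outputs are peeled off and stored back as the new tokens for the next call.

# Toy model of prepend-tokens / strip-tokens around an execution; not JAX code.
def add_tokens(effects, tokens, input_bufs):
  prepended = [tokens[eff] for eff in effects] + list(input_bufs)
  def remove_tokens(output_bufs):
    token_bufs, rest = output_bufs[:len(effects)], output_bufs[len(effects):]
    for eff, tok in zip(effects, token_bufs):
      tokens[eff] = tok                  # refreshed token feeds the next call
    return rest
  return prepended, remove_tokens

tokens = {'log': 'token0'}
def fake_execute(bufs):                  # pretend executable for one effect:
  return ['token1'] + [b + "'" for b in bufs[1:]]   # new token + transformed outputs

ins, strip = add_tokens(['log'], tokens, ['x'])
outs = strip(fake_execute(ins))
assert outs == ["x'"] and tokens['log'] == 'token1'
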
@@ -580,7 +637,7 @@ def _execute_replicated(name: str, compiled: XlaExecutable,


 def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
-                     kept_var_idx, *args):
+                     _: List[core.Effect], kept_var_idx, *args):
   env = {core.unitvar: core.unit}
   pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
   map(env.setdefault, jaxpr.invars, pruned_args)
@@ -721,6 +778,7 @@ class XlaCompiledComputation(stages.Executable):
                            tuple_args: bool,
                            in_avals: Sequence[core.AbstractValue],
                            out_avals: Sequence[core.AbstractValue],
+                           effects: List[core.Effect],
                            kept_var_idx: Set[int]) -> XlaCompiledComputation:
     sticky_device = device
     input_handler = _input_handler(backend, explicit_args, in_avals)
@@ -735,9 +793,13 @@ class XlaCompiledComputation(stages.Executable):
     compiled = compile_or_get_cached(backend, xla_computation, options)
     buffer_counts = (None if len(out_avals) == 1 and not config.jax_dynamic_shapes
                      else [aval_to_num_buffers(aval) for aval in out_avals])
+    if effects:
+      if buffer_counts is None:
+        buffer_counts = [1]
+      buffer_counts = ([1] * len(effects)) + buffer_counts
     execute = _execute_compiled if nreps == 1 else _execute_replicated
     unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,
-                          result_handlers, kept_var_idx)
+                          result_handlers, effects, kept_var_idx)
     return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)

   def is_trivial(self):
@@ -751,11 +813,11 @@ class XlaCompiledComputation(stages.Executable):
     return self._xla_executable

   @staticmethod
-  def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
+  def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals, effects,
                          kept_var_idx) -> XlaCompiledComputation:
     result_handlers = map(partial(aval_to_result_handler, device), out_avals)
     unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
-                          out_avals, result_handlers, kept_var_idx)
+                          out_avals, result_handlers, effects, kept_var_idx)
     return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)

   # -- stages.Executable protocol
@@ -473,8 +473,11 @@ def sharded_aval(aval: core.ShapedArray,
     sharded_shape.append((aval.shape[i] + partitions - 1) // partitions)
   return aval.update(tuple(sharded_shape))


 def lower_jaxpr_to_module(
-    module_name: str, jaxpr: core.ClosedJaxpr, platform: str,
+    module_name: str, jaxpr: core.ClosedJaxpr,
+    effects: List[core.Effect],
+    platform: str,
     axis_context: AxisContext,
     name_stack: NameStack, donated_args: Sequence[bool],
     replicated_args: Optional[Sequence[bool]] = None,
@@ -504,6 +507,8 @@ def lower_jaxpr_to_module(
   if platform in platforms_with_donation:
     input_output_aliases, donated_args = _set_up_aliases(
         in_avals, out_avals, donated_args)
+  if any(eff not in lowerable_effects for eff in jaxpr.effects):
+    raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
   if any(donated_args):
     # TODO(tomhennigan): At call time we should mark these buffers as deleted.
     unused_donations = [str(a) for a, d in zip(in_avals, donated_args)
@@ -528,7 +533,6 @@ def lower_jaxpr_to_module(
   if unlowerable_effects:
     raise ValueError(
         f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}')
-  effects = [eff for eff in jaxpr.effects if eff in core.ordered_effects]
   lower_jaxpr_to_fun(
       ctx, "main", jaxpr, effects, public=True, create_tokens=True,
       replace_units_with_dummy=True,
@@ -623,6 +627,12 @@ class TokenSet:
         new_tokens.append(self._tokens[eff])
     return TokenSet(zip(self.effects(), new_tokens))

+def dummy_token_type() -> Sequence[ir.Type]:
+  return aval_to_ir_types(core.ShapedArray((0,), np.bool_))
+
+def dummy_token() -> Sequence[ir.Value]:
+  return ir_constants(np.zeros(0, np.bool_))
+
 def lower_jaxpr_to_fun(
     ctx: ModuleContext,
     name: str,
@@ -650,6 +660,7 @@ def lower_jaxpr_to_fun(
     jaxpr: the jaxpr to lower.
     effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
       that will be created in or used by the lowered function.
+    create_tokens: if true, the MHLO will create tokens and ignore dummy input tokens.
     public: if true, the function's visibility is set to "public".
     replace_units_with_dummy: if true, unit arguments/return values are
       replaced with bool arrays of size [0].
@@ -676,10 +687,10 @@ def lower_jaxpr_to_fun(

   input_types = map(aval_to_types, jaxpr.in_avals)
   output_types = map(aval_to_types, jaxpr.out_avals)
   num_tokens = len(effects)
   if create_tokens:
     # If we create the tokens they won't be inputs to the MLIR function.
     num_tokens = 0
-    token_types = []
+    token_types = [dummy_token_type() for _ in effects]
   else:
     # If we aren't creating tokens they will be the initial inputs to the
     # MLIR function.
@@ -789,9 +800,8 @@ def lower_jaxpr_to_fun(
           *args)
      outs = []
      if create_tokens:
-       # If we created the tokens in this function, we are done with them and can
-       # ignore `tokens_out`.
-       pass
+       for _ in effects:
+         outs.append(dummy_token())
      else:
        for token in tokens_out.tokens():
          outs.append(token)
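
With create_tokens=True the entry function materializes its own dummy tokens, so each ordered effect contributes one tensor<0xi1> placeholder operand and result to the lowered module; the new tests near the end of this diff assert exactly that. A quick way to observe it, again relying on the test-local effect_p and ordered effect 'foo' defined in this diff:

# Inspect the lowered signature for a function with one ordered effect.
import jax

@jax.jit
def f(x):
  effect_p.bind(effect='foo')    # test-local effectful primitive
  return x + 1.

main = f.lower(1.).compiler_ir(dialect='mhlo').body.operations[0]
print(list(main.type.inputs))    # expected: [tensor<0xi1>, tensor<f32>]
print(list(main.type.results))   # expected: [tensor<0xi1>, tensor<f32>]
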
@@ -1053,8 +1053,9 @@ def lower_parallel_callable(
   tuple_args = should_tuple_args(shards)
   module_name = f"pmap_{fun.__name__}"
   with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
+    effects = list(closed_jaxpr.effects)
     module = mlir.lower_jaxpr_to_module(
-        module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
+        module_name, closed_jaxpr, effects, backend.platform, mlir.ReplicaAxisContext(axis_env),
         name_stack, donated_invars, replicated_args=replicated_args,
         arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
         result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
@@ -2236,8 +2237,9 @@ def lower_mesh_computation(
   module: Union[str, xc.XlaComputation]
   module_name = f"{api_name}_{fun_name}"
   with core.extend_axis_env_nd(mesh.shape.items()):
+    effects = list(closed_jaxpr.effects)
     module = mlir.lower_jaxpr_to_module(
-        module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
+        module_name, closed_jaxpr, effects, backend.platform, axis_ctx, name_stack,
         donated_invars, replicated_args=replicated_args,
         arg_shardings=in_partitions, result_shardings=out_partitions)
@@ -140,9 +140,11 @@ def _sharded_callable(
       fun.__name__, nparts, global_abstract_args)

   axis_env = xla.AxisEnv(nrep, (), ())
+  effects = list(jaxpr.effects)
   module = mlir.lower_jaxpr_to_module(
       "spjit_{}".format(fun.__name__),
       core.ClosedJaxpr(jaxpr, consts),
+      effects,
       platform=platform,
       axis_context=mlir.ReplicaAxisContext(axis_env),
       name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
+import threading
+import unittest

 from absl.testing import absltest
 from absl.testing import parameterized
@@ -20,10 +23,12 @@ from jax import ad_checkpoint
 from jax import core
 from jax import lax
+from jax import linear_util as lu
-from jax.config import config
 from jax.experimental import maps
 from jax.experimental import pjit
+from jax.config import config
 from jax.interpreters import mlir
+from jax._src import lib as jaxlib
 from jax._src import dispatch
 from jax._src import test_util as jtu
 from jax._src import util
 import numpy as np
@@ -47,6 +52,7 @@ core.ordered_effects.add('foo2')
 def trivial_effect_lowering(ctx, *, effect):
+  ctx.set_tokens_out(ctx.tokens_in)
   return []
 mlir.register_lowering(effect_p, trivial_effect_lowering)

 def function_effect_lowering(ctx, *, effect):
   def _f(ctx):
@@ -65,6 +71,58 @@ def function_effect_lowering(ctx, *, effect):
   ctx.set_tokens_out(mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)))
   return out

+callback_p = core.Primitive('callback')
+callback_p.multiple_results = True
+
+mlir.lowerable_effects.add('log')
+core.ordered_effects.add('log')
+
+@callback_p.def_impl
+def _(*args, callback, out_avals, effect):
+  del out_avals, effect
+  callback(*args)
+  return []
+
+@callback_p.def_effectful_abstract_eval
+def _(*avals, callback, out_avals, effect):
+  del avals, callback
+  return out_avals, {effect}
+
+
+# TODO(sharadmv): Attach keep alive to executable
+leak = [].append
+def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
+  del out_avals
+  def _token_callback(token, *args):
+    out = callback(*args)
+    flat_out = jax.tree_util.tree_leaves(out)
+    return (token, *flat_out)
+  token_in = ctx.tokens_in.get(effect)[0]
+  (token_out, *out_op), keep_alive = mlir.emit_python_callback(
+      ctx.module_context.platform, _token_callback,
+      [token_in, *args], [core.abstract_token, *ctx.avals_in],
+      [core.abstract_token, *ctx.avals_out], True)
+  leak(keep_alive)
+  ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect:
+    token_out})))
+  return out_op
+
+mlir.register_lowering(callback_p, callback_effect_lowering)
+
+
+prev_xla_flags = None
+
+
+def setUpModule():
+  global prev_xla_flags
+  # This will control the CPU devices. On TPU we always have 2 devices
+  prev_xla_flags = jtu.set_host_platform_device_count(2)
+
+
+# Reset to previous configuration in case other test modules will be run.
+def tearDownModule():
+  prev_xla_flags()
+

 class JaxprEffectsTest(jtu.JaxTestCase):

@@ -210,6 +268,18 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
       return []
     mlir.register_lowering(effect_p, _effect_lowering)
+
+  def setUp(self):
+    super().setUp()
+    self.old_x64 = config.jax_enable_x64
+    config.update('jax_enable_x64', False)
+    dispatch.runtime_tokens.clear()
+
+  def tearDown(self):
+    super().tearDown()
+    dispatch.runtime_tokens.clear()
+    config.update('jax_enable_x64', self.old_x64)
+
  def test_cannot_lower_unlowerable_effect(self):
    @jax.jit
    def f(x):
      effect_p.bind(effect='foo')
@@ -350,6 +420,194 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
     self.assertLen(list(func.type.inputs), 0)
     self.assertLen(list(func.type.results), 0)

+  def test_lowered_jaxpr_without_ordered_effects_takes_no_dummy_inputs(self):
+    @jax.jit
+    def f(x):
+      effect_p.bind(effect='bar')
+      return x + 1.
+    mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
+    input_types = mhlo.body.operations[0].type.inputs
+    # First argument should be dummy token
+    self.assertLen(list(input_types), 1)
+    self.assertEqual(str(input_types[0]), 'tensor<f32>')
+
+    # First output should be dummy token
+    result_types = mhlo.body.operations[0].type.results
+    self.assertLen(list(result_types), 1)
+    self.assertEqual(str(result_types[0]), 'tensor<f32>')
+
+  def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self):
+    @jax.jit
+    def f(x):
+      effect_p.bind(effect='foo')
+      return x + 1.
+    mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
+    input_types = mhlo.body.operations[0].type.inputs
+    # First argument should be dummy token
+    self.assertLen(list(input_types), 2)
+    self.assertEqual(str(input_types[0]), 'tensor<0xi1>')
+
+    # First output should be dummy token
+    result_types = mhlo.body.operations[0].type.results
+    self.assertLen(list(result_types), 2)
+    self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
+
+  def test_lowered_jaxpr_with_multiple_ordered_effects_takes_in_dummy_inputs(self):
+    @jax.jit
+    def f(x):
+      effect_p.bind(effect='foo')
+      effect_p.bind(effect='foo2')
+      return x + 1.
+    mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
+    input_types = mhlo.body.operations[0].type.inputs
+    # First two arguments should be dummy values
+    self.assertLen(list(input_types), 3)
+    self.assertEqual(str(input_types[0]), 'tensor<0xi1>')
+    self.assertEqual(str(input_types[1]), 'tensor<0xi1>')
+
+    # First two outputs should be dummy values
+    result_types = mhlo.body.operations[0].type.results
+    self.assertLen(list(result_types), 3)
+    self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
+    self.assertEqual(str(result_types[1]), 'tensor<0xi1>')
+
+  def test_can_lower_and_run_jaxpr_with_ordered_effects(self):
+    @jax.jit
+    def f(x):
+      effect_p.bind(effect='foo')
+      return x + 1.
+    self.assertEqual(f(2.), 3.)
+
+  def test_can_lower_and_run_jaxpr_with_unordered_effects(self):
+    @jax.jit
+    def f(x):
+      effect_p.bind(effect='bar')
+      return x + 1.
+    self.assertEqual(f(2.), 3.)
+
+  def test_runtime_tokens_should_update_after_running_effectful_function(self):
+    @jax.jit
+    def f(x):
+      effect_p.bind(effect='foo')
+      return x + 1.
+    self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
+    f(2.)
+    prev_token = dispatch.runtime_tokens.tokens['foo']
+    f(2.)
+    curr_token = dispatch.runtime_tokens.tokens['foo']
+    self.assertIsNot(prev_token, curr_token)
+
+  def test_can_lower_multiple_effects(self):
+    @jax.jit
+    def f(x):
+      effect_p.bind(effect='foo')
+      effect_p.bind(effect='foo2')
+      return x + 1.
+    @jax.jit
+    def g(x):
+      effect_p.bind(effect='foo')
+      return x + 1.
+    self.assertNotIn('foo', dispatch.runtime_tokens.tokens)
+    self.assertNotIn('foo2', dispatch.runtime_tokens.tokens)
+    f(2.)
+    foo_token = dispatch.runtime_tokens.tokens['foo'][0]
+    foo2_token = dispatch.runtime_tokens.tokens['foo'][0]
+    f(2.)
+    self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
+    self.assertIsNot(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
+    foo_token = dispatch.runtime_tokens.tokens['foo'][0]
+    foo2_token = dispatch.runtime_tokens.tokens['foo2'][0]
+    g(2.)
+    self.assertIsNot(foo_token, dispatch.runtime_tokens.tokens['foo'][0])
+    self.assertIs(foo2_token, dispatch.runtime_tokens.tokens['foo2'][0])
+
+
+class EffectOrderingTest(jtu.JaxTestCase):
+
+  @jtu.skip_on_devices("tpu", "gpu")
+  def test_can_execute_python_callback(self):
+    # TODO(sharadmv): remove jaxlib check when minimum version is bumped
+    # TODO(sharadmv): enable this test on GPU and TPU when backends are
+    # supported
+    if jaxlib.version < (0, 3, 8):
+      raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
+    log = []
+    def log_value(x):
+      log.append(x)
+      return ()
+
+    @jax.jit
+    def f(x):
+      return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
+
+    f(2.)
+    self.assertListEqual(log, [2.])
+    f(3.)
+    self.assertListEqual(log, [2., 3.])
+    dispatch.runtime_tokens.block_until_ready()
+
+  @jtu.skip_on_devices("tpu", "gpu")
+  def test_ordered_effect_remains_ordered_across_multiple_devices(self):
+    # TODO(sharadmv): remove jaxlib check when minimum version is bumped
+    # TODO(sharadmv): enable this test on GPU and TPU when backends are
+    # supported
+    if jaxlib.version < (0, 3, 8):
+      raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
+    if jax.device_count() < 2:
+      raise unittest.SkipTest("Test requires >= 2 devices.")
+    log = []
+    def log_value(x):
+      log.append(x)
+      return ()
+
+    @functools.partial(jax.jit, device=jax.devices()[0])
+    def f(x):
+      # Expensive computation
+      x = x.dot(x)
+      x = jnp.log(x.sum())
+      return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
+
+    @functools.partial(jax.jit, device=jax.devices()[1])
+    def g(x):
+      return callback_p.bind(x, callback=log_value, effect='log', out_avals=[])
+
+    f(jnp.ones((500, 500)))
+    g(3.)
+    f(jnp.ones((500, 500)))
+    g(3.)
+    f(jnp.ones((500, 500)))
+    g(3.)
+    dispatch.runtime_tokens.block_until_ready()
+    x_, y_ = float(jnp.log(1.25e8)), 3.
+    expected_log = [x_, y_, x_, y_, x_, y_]
+    self.assertListEqual(log, expected_log)
+
+  @jtu.skip_on_devices("tpu", "gpu")
+  def test_different_threads_get_different_tokens(self):
+    # TODO(sharadmv): remove jaxlib check when minimum version is bumped
+    # TODO(sharadmv): enable this test on GPU and TPU when backends are
+    # supported
+    if jaxlib.version < (0, 3, 8):
+      raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
+    if jax.device_count() < 2:
+      raise unittest.SkipTest("Test requires >= 2 devices.")
+    tokens = []
+    def _noop(_):
+      tokens.append(dispatch.runtime_tokens.tokens['log'][0])
+      return ()
+
+    @functools.partial(jax.jit, device=jax.devices()[0])
+    def f(x):
+      return callback_p.bind(x, callback=_noop, effect='log', out_avals=[])
+
+    t1 = threading.Thread(target=lambda: f(2.))
+    t2 = threading.Thread(target=lambda: f(3.))
+    t1.start()
+    t2.start()
+    t1.join()
+    t2.join()
+    token1, token2 = tokens
+    self.assertIsNot(token1, token2)
+

 class ControlFlowEffectsTest(jtu.JaxTestCase):
