Add IO callback

Sharad Vikram 2022-11-10 12:00:21 -08:00
parent a58e59d98f
commit 3de5c2b716
9 changed files with 294 additions and 17 deletions

View File

@@ -3347,7 +3347,7 @@ def block_until_ready(x):
return jax.tree_util.tree_map(try_to_block, x)
def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, **kwargs: Any):
*args: Any, vectorized: bool = False, **kwargs: Any):
"""Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
@@ -3390,14 +3390,15 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
via `jax.vmap`, it will be called directly on inputs with leading batch
dimensions instead of executing ``callback`` on each mapped input
individually. The callback should also return outputs batched across the
leading axis.
leading axis. By default, ``vectorized`` is ``False``.
**kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
types.
Returns:
The value of ``callback(*args, **kwargs)``.
"""
return jcb.pure_callback(callback, result_shape_dtypes, *args, **kwargs)
return jcb.pure_callback(callback, result_shape_dtypes, *args,
vectorized=vectorized, **kwargs)
def clear_backends():
"""

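A minimal usage sketch of the new ``vectorized`` argument (illustrative only, not part of this diff), assuming ``pure_callback`` is exposed as ``jax.pure_callback`` as in the hunk above; the NumPy function is just an example:

import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # Runs on the host; with vectorized=True, vmap hands it the whole batch at once.
  return np.sin(x)

def f(x):
  # result_shape_dtypes describes the *unbatched* output of the callback.
  return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
                           vectorized=True)

print(jax.jit(jax.vmap(f))(jnp.arange(4.0)))  # host_sin is called once, on shape (4,)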
View File

@@ -23,11 +23,11 @@ from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src import dispatch
from jax._src.lib import xla_client as xc
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
import numpy as np
from jax._src.lib import xla_client as xc
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
pure_callback_p = core.Primitive("pure_callback")
@@ -150,3 +150,106 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*flat_args, callback=_flat_callback,
result_avals=tuple(flat_result_avals), vectorized=vectorized)
return tree_util.tree_unflatten(out_tree, out_flat)
# IO Callback
io_callback_p = core.Primitive("io_callback")
io_callback_p.multiple_results = True
class IOEffect:
__str__ = lambda _: "IO"
class OrderedIOEffect:
__str__ = lambda _: "OrderedIO"
_IOEffect = IOEffect()
_OrderedIOEffect = OrderedIOEffect()
mlir.lowerable_effects.add(_IOEffect)
mlir.lowerable_effects.add(_OrderedIOEffect)
core.control_flow_allowed_effects.add(_IOEffect)
core.control_flow_allowed_effects.add(_OrderedIOEffect)
core.ordered_effects.add(_OrderedIOEffect)
def io_callback_impl(*args, result_avals, callback: Callable[..., Any],
ordered: bool):
del result_avals, ordered
return callback(*args)
io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
io_callback_p))
@io_callback_p.def_effectful_abstract_eval
def io_callback_abstract_eval(*avals, callback: Callable[..., Any],
result_avals, ordered: bool):
del avals, callback
effect = _OrderedIOEffect if ordered else _IOEffect
return result_avals, {effect}
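As an illustration (not part of this diff): because the abstract eval above attaches an IO effect to the equation, tracing a function that stages out the callback should surface that effect on the resulting jaxpr, roughly as in this sketch:

import jax
import jax.numpy as jnp
from jax.experimental import io_callback  # exported later in this commit

def f(x):
  io_callback(lambda v: None, None, x)  # unordered IO callback, no return value
  return x + 1

# Expected to contain the unordered effect, which prints as "IO".
print(jax.make_jaxpr(f)(jnp.float32(0.0)).effects)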
def io_callback_jvp_rule(*args, **kwargs):
del args, kwargs
raise ValueError("IO callbacks do not support JVP.")
ad.primitive_jvps[io_callback_p] = io_callback_jvp_rule
def io_callback_transpose_rule(*args, **kwargs):
del args, kwargs
raise ValueError("IO callbacks do not support transpose.")
ad.primitive_transposes[io_callback_p] = io_callback_transpose_rule
def io_callback_batching_rule(args, dims, callback, result_avals, ordered):
if ordered:
raise ValueError("Cannot `vmap` ordered IO callback.")
return pure_callback_batching_rule(args, dims, callback=callback,
vectorized=False, result_avals=result_avals)
batching.primitive_batchers[io_callback_p] = io_callback_batching_rule
def io_callback_lowering(ctx, *args, callback, ordered, **params):
def _callback(*flat_args):
return tuple(io_callback_impl(*flat_args, callback=callback,
ordered=ordered, **params))
# TODO(sharadmv): figure out the best API for sharding callbacks. For now, we
# can only safely maximally shard. Should we allow device_index to be passed
# in like host_callback?
if isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext)):
# Apply maximal sharding so pjit only executes the callback on device 0.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MAXIMAL
sharding.tile_assignment_dimensions = [1]
sharding.tile_assignment_devices = [0]
else:
sharding = None
if ordered:
token = ctx.tokens_in.get(_OrderedIOEffect)[0]
result, token, keepalive = mlir.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True,
sharding=sharding)
ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)}))
else:
result, token, keepalive = mlir.emit_python_callback(
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True,
sharding=sharding)
ctx.module_context.add_keepalive(keepalive)
return result
mlir.register_lowering(io_callback_p, io_callback_lowering)
def io_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, ordered: bool = False, **kwargs: Any):
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
return tree_util.tree_leaves(callback(*args, **kwargs))
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype),
flat_shape_dtypes)
flat_args = map(core.raise_as_much_as_possible, flat_args)
out_flat = io_callback_p.bind(
*flat_args, callback=_flat_callback,
result_avals=tuple(flat_result_avals),
ordered=ordered)
return tree_util.tree_unflatten(out_tree, out_flat)
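A minimal end-to-end sketch of ``io_callback`` (illustrative, not part of this diff), using the ``jax.experimental`` export added in this commit and the ``jax.effects_barrier`` call exercised in the tests below:

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log_value(x):
  # Impure host-side side effect; returns nothing, so result_shape_dtypes is None.
  print("saw:", x)

@jax.jit
def f(x):
  # ordered=True preserves the program order of callback executions.
  io_callback(log_value, None, x, ordered=True)
  return x + 1

f(jnp.float32(1.0))
jax.effects_barrier()  # block until the host callback has completed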

View File

@@ -63,6 +63,7 @@ Effect = Hashable
Effects = Set[Effect]
no_effects: Effects = set()
ordered_effects: Set[Effect] = set()
control_flow_allowed_effects: Set[Effect] = set()
class Jaxpr:

View File

@@ -15,21 +15,21 @@
import os
from functools import partial
from typing import Callable, Optional, Sequence, Set
from typing import Callable, Optional, Sequence
from jax import core
from jax._src import core
from jax._src import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src.core import control_flow_allowed_effects as allowed_effects
from jax._src.lax import lax
from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten
map, unsafe_map = safe_map, map
allowed_effects: Set[core.Effect] = set()
allowed_effects.add(lax.InOutFeedEffect.Infeed)
allowed_effects.add(lax.InOutFeedEffect.Outfeed)

View File

@@ -337,10 +337,17 @@ def _bcast_select_n(pred, *cases):
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
index, *ops = args
index_dim, *op_dims = dims
# TODO(sharadmv): clean this up by adding a specific blocklist
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
raise NotImplementedError(
"State effect not supported in cond vmap.")
"State effect not supported in vmap-of-cond.")
from jax._src.callback import _IOEffect, _OrderedIOEffect
if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect]
for branch in branches):
raise NotImplementedError(
"IO effect not supported in vmap-of-cond.")
if index_dim is not batching.not_mapped:
# Convert to a lax.select. While we could get away with not broadcasting

View File

@@ -1143,6 +1143,12 @@ def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
def _while_loop_batching_rule(axis_size, axis_name, main_type, args, dims,
cond_nconsts, cond_jaxpr,
body_nconsts, body_jaxpr):
from jax._src.callback import _IOEffect, _OrderedIOEffect
if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect]
for branch in [body_jaxpr, cond_jaxpr]):
raise NotImplementedError(
"IO effect not supported in vmap-of-while.")
orig_batched = [d is not batching.not_mapped for d in dims]
cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts])

View File

@@ -22,3 +22,6 @@ from jax.experimental.x64_context import (
enable_x64 as enable_x64,
disable_x64 as disable_x64,
)
from jax._src.callback import (
io_callback as io_callback
)

View File

@@ -972,6 +972,9 @@ jax_test(
jax_test(
name = "python_callback_test",
srcs = ["python_callback_test.py"],
deps = [
"//jax:experimental",
],
)
jax_test(

View File

@@ -24,17 +24,18 @@ from jax import lax
from jax import tree_util
from jax._src import debugging
from jax._src import dispatch
from jax._src import sharding
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.pjit import pjit
from jax.experimental.maps import xmap
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.experimental.maps import Mesh
from jax.experimental.maps import xmap
from jax.experimental import io_callback
import jax.numpy as jnp
import numpy as np
@@ -682,7 +683,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
try:
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
spec = P('x')
spec = pjit.PartitionSpec('x')
def f(x):
axis_resources = {v: v for v in mesh.axis_names}
@@ -699,7 +700,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
with mesh:
inp = jnp.arange(float(jax.local_device_count()))
out = pjit(f, in_axis_resources=spec, out_axis_resources=spec)(inp)
out = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)(inp)
np.testing.assert_allclose(
out, np.sin(np.arange(jax.local_device_count()))
)
@@ -708,7 +709,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
NotImplementedError, 'when all mesh axes are partitioned manually'
):
pjit(
pjit.pjit(
without_xmap_f, in_axis_resources=spec, out_axis_resources=spec
)(inp)
@@ -879,5 +880,157 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
x = np.arange(6, dtype=np.int32).reshape((3, 2))
np.testing.assert_allclose(g(x), x)
class IOPythonCallbackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
def test_io_callback_can_mutate_state(self):
x = 0
def cb():
nonlocal x
x += 1
return np.array(x, np.int32)
def f():
return io_callback(cb, jax.ShapeDtypeStruct((), jnp.int32))
f()
jax.effects_barrier()
self.assertEqual(x, 1)
f()
jax.effects_barrier()
self.assertEqual(x, 2)
def test_io_callback_can_be_batched_if_unordered(self):
_mut = 0
def cb(x):
nonlocal _mut
_mut += 1
return x
x = jnp.arange(4)
def f(x):
return io_callback(cb, jax.ShapeDtypeStruct((), x.dtype), x)
jax.vmap(f)(x)
jax.effects_barrier()
self.assertEqual(_mut, 4)
jax.vmap(f)(x)
jax.effects_barrier()
self.assertEqual(_mut, 8)
def test_cannot_call_ordered_io_in_pmap(self):
def f(x):
return io_callback(
lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
with self.assertRaisesRegex(
ValueError, "Ordered effects not supported in `pmap`"):
jax.pmap(f)(jnp.arange(jax.local_device_count()))
def test_cannot_call_ordered_io_in_xmap(self):
def f(x):
return io_callback(
lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
with self.assertRaisesRegex(
ValueError, "Cannot `vmap` ordered IO callback"):
maps.xmap(f, in_axes=([0],), out_axes=[0])(jnp.arange(16))
def test_cannot_call_ordered_io_in_vmap(self):
def f(x):
return io_callback(
lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
with self.assertRaisesRegex(
ValueError, "Cannot `vmap` ordered IO callback"):
jax.vmap(f)(jnp.arange(4))
def test_cannot_use_io_callback_in_jvp(self):
def f(x):
return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
with self.assertRaisesRegex(
ValueError, "IO callbacks do not support JVP."):
jax.jvp(f, (0.,), (1.,))
def test_cannot_use_io_callback_in_linearize(self):
def f(x):
return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
with self.assertRaisesRegex(
ValueError, "IO callbacks do not support JVP."):
jax.linearize(f, 0.)
def test_cannot_use_io_callback_in_transpose(self):
x = jnp.array(1.)
def f(x):
return io_callback(lambda x: x, jax.ShapeDtypeStruct((), x.dtype), x)
with self.assertRaisesRegex(
ValueError, "IO callbacks do not support transpose."):
jax.linear_transpose(f, x)(x)
def test_cannot_vmap_of_cond_io_callback(self):
def f(pred):
def true_fun():
io_callback(lambda: print("true"), None)
def false_fun():
io_callback(lambda: print("false"), None)
return lax.cond(pred, false_fun, true_fun)
with self.assertRaisesRegex(NotImplementedError,
"IO effect not supported in vmap-of-cond."):
jax.vmap(f)(jnp.array([True, True]))
def test_cannot_vmap_of_while_io_callback(self):
def check(x):
assert np.all(x < 5)
def f(i):
def cond(i):
return i < 5
def body(i):
io_callback(check, None, i)
return i + 1
return lax.while_loop(cond, body, i)
with self.assertRaisesRegex(NotImplementedError,
"IO effect not supported in vmap-of-while."):
jax.vmap(f)(jnp.array([0, 4]))
def test_cannot_use_io_callback_in_checkpoint(self):
@jax.grad
@jax.checkpoint
def f(x, y):
io_callback(lambda x: x, y, y)
return x
with self.assertRaisesRegex(NotImplementedError,
"Effects not supported in partial-eval of `checkpoint`"):
f(2., 3.)
def test_can_use_io_callback_in_pjit(self):
_mut = 0
def _cb(x):
nonlocal _mut
_mut = x.sum()
def f(x):
io_callback(_cb, None, x)
return x
mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
else:
spec = pjit.PartitionSpec('dev')
out_spec = pjit.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
with mesh:
f(jnp.arange(mesh.size))
jax.effects_barrier()
self.assertEqual(_mut, jnp.arange(mesh.size).sum())
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())