From 3de5c2b7165e7e4374089fc5a6824cc9c20390d2 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 10 Nov 2022 12:00:21 -0800 Subject: [PATCH] Add IO callback --- jax/_src/api.py | 7 +- jax/_src/callback.py | 105 +++++++++++++- jax/_src/core.py | 1 + jax/_src/lax/control_flow/common.py | 10 +- jax/_src/lax/control_flow/conditionals.py | 9 +- jax/_src/lax/control_flow/loops.py | 6 + jax/experimental/__init__.py | 3 + tests/BUILD | 3 + tests/python_callback_test.py | 167 +++++++++++++++++++++- 9 files changed, 294 insertions(+), 17 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index e909759d1..16bc42a0d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3347,7 +3347,7 @@ def block_until_ready(x): return jax.tree_util.tree_map(try_to_block, x) def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any, - *args: Any, **kwargs: Any): + *args: Any, vectorized: bool = False, **kwargs: Any): """Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc. ``pure_callback`` enables calling a Python function in JIT-ed JAX functions. @@ -3390,14 +3390,15 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any, via `jax.vmap`, it will be called directly on inputs with leading batch dimensions instead of executing ``callback`` on each mapped input individually. The callback should also return outputs batched across the - leading axis. + leading axis. By default, ``vectorized`` is ``False``. **kwargs: The keyword arguments to the callback. Must be PyTrees of JAX types. Returns: The value of ``callback(*args, **kwargs)``. """ - return jcb.pure_callback(callback, result_shape_dtypes, *args, **kwargs) + return jcb.pure_callback(callback, result_shape_dtypes, *args, + vectorized=vectorized, **kwargs) def clear_backends(): """ diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 286241a6a..7f551045a 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -23,11 +23,11 @@ from jax._src import core from jax._src import dtypes from jax._src import util from jax._src import dispatch +from jax._src.lib import xla_client as xc from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir import numpy as np -from jax._src.lib import xla_client as xc # `pure_callback_p` is the main primitive for staging out Python pure callbacks. 
pure_callback_p = core.Primitive("pure_callback") @@ -150,3 +150,106 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any, *flat_args, callback=_flat_callback, result_avals=tuple(flat_result_avals), vectorized=vectorized) return tree_util.tree_unflatten(out_tree, out_flat) + + +# IO Callback + +io_callback_p = core.Primitive("io_callback") +io_callback_p.multiple_results = True + +class IOEffect: + __str__ = lambda _: "IO" +class OrderedIOEffect: + __str__ = lambda _: "OrderedIO" +_IOEffect = IOEffect() +_OrderedIOEffect = OrderedIOEffect() +mlir.lowerable_effects.add(_IOEffect) +mlir.lowerable_effects.add(_OrderedIOEffect) +core.control_flow_allowed_effects.add(_IOEffect) +core.control_flow_allowed_effects.add(_OrderedIOEffect) +core.ordered_effects.add(_OrderedIOEffect) + + +def io_callback_impl(*args, result_avals, callback: Callable[..., Any], + ordered: bool): + del result_avals, ordered + return callback(*args) +io_callback_p.def_impl(functools.partial(dispatch.apply_primitive, + io_callback_p)) + +@io_callback_p.def_effectful_abstract_eval +def io_callback_abstract_eval(*avals, callback: Callable[..., Any], + result_avals, ordered: bool): + del avals, callback + effect = _OrderedIOEffect if ordered else _IOEffect + return result_avals, {effect} + +def io_callback_jvp_rule(*args, **kwargs): + del args, kwargs + raise ValueError("IO callbacks do not support JVP.") +ad.primitive_jvps[io_callback_p] = io_callback_jvp_rule + + +def io_callback_transpose_rule(*args, **kwargs): + del args, kwargs + raise ValueError("IO callbacks do not support transpose.") +ad.primitive_transposes[io_callback_p] = io_callback_transpose_rule + + +def io_callback_batching_rule(args, dims, callback, result_avals, ordered): + if ordered: + raise ValueError("Cannot `vmap` ordered IO callback.") + return pure_callback_batching_rule(args, dims, callback=callback, + vectorized=False, result_avals=result_avals) +batching.primitive_batchers[io_callback_p] = io_callback_batching_rule + +def io_callback_lowering(ctx, *args, callback, ordered, **params): + + def _callback(*flat_args): + return tuple(io_callback_impl(*flat_args, callback=callback, + ordered=ordered, **params)) + + # TODO(sharadmv): figure out the best API for sharding callbacks. For now, we + # can only safely maximally shard. Should we allow device_index to be passed + # in like host_callback? + if isinstance(ctx.module_context.axis_context, + (mlir.SPMDAxisContext, mlir.ShardingContext)): + # Apply maximal sharding so pjit only executes the callback on device 0. 
+ sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MAXIMAL + sharding.tile_assignment_dimensions = [1] + sharding.tile_assignment_devices = [0] + else: + sharding = None + + if ordered: + token = ctx.tokens_in.get(_OrderedIOEffect)[0] + result, token, keepalive = mlir.emit_python_callback( + ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True, + sharding=sharding) + ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)})) + else: + result, token, keepalive = mlir.emit_python_callback( + ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True, + sharding=sharding) + ctx.module_context.add_keepalive(keepalive) + return result +mlir.register_lowering(io_callback_p, io_callback_lowering) + +def io_callback(callback: Callable[..., Any], result_shape_dtypes: Any, + *args: Any, ordered: bool = False, **kwargs: Any): + def _flat_callback(*flat_args): + args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) + return tree_util.tree_leaves(callback(*args, **kwargs)) + + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) + flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) + flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype), + flat_shape_dtypes) + flat_args = map(core.raise_as_much_as_possible, flat_args) + out_flat = io_callback_p.bind( + *flat_args, callback=_flat_callback, + result_avals=tuple(flat_result_avals), + ordered=ordered) + return tree_util.tree_unflatten(out_tree, out_flat) diff --git a/jax/_src/core.py b/jax/_src/core.py index 86f4263e4..d9a60f33b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -63,6 +63,7 @@ Effect = Hashable Effects = Set[Effect] no_effects: Effects = set() ordered_effects: Set[Effect] = set() +control_flow_allowed_effects: Set[Effect] = set() class Jaxpr: diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index f5f183f30..d8e15442d 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -15,21 +15,21 @@ import os from functools import partial -from typing import Callable, Optional, Sequence, Set +from typing import Callable, Optional, Sequence -from jax import core +from jax._src import core from jax._src import linear_util as lu -from jax.api_util import flatten_fun_nokwargs -from jax.interpreters import partial_eval as pe +from jax._src.core import control_flow_allowed_effects as allowed_effects from jax._src.lax import lax from jax._src import ad_util from jax._src import util from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3 +from jax.api_util import flatten_fun_nokwargs +from jax.interpreters import partial_eval as pe from jax.tree_util import tree_map, tree_unflatten map, unsafe_map = safe_map, map -allowed_effects: Set[core.Effect] = set() allowed_effects.add(lax.InOutFeedEffect.Infeed) allowed_effects.add(lax.InOutFeedEffect.Outfeed) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 479bc5440..be7c1a659 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -337,10 +337,17 @@ def _bcast_select_n(pred, *cases): def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear): index, *ops = args index_dim, *op_dims = dims + # TODO(sharadmv): clean this up by adding a specific blocklist if any(isinstance(eff, state.RefEffect) for branch in branches for eff in 
branch.jaxpr.effects): raise NotImplementedError( - "State effect not supported in cond vmap.") + "State effect not supported in vmap-of-cond.") + from jax._src.callback import _IOEffect, _OrderedIOEffect + if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect] + for branch in branches): + raise NotImplementedError( + "IO effect not supported in vmap-of-cond.") + if index_dim is not batching.not_mapped: # Convert to a lax.select. While we could get away with not broadcasting diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index cbdb9c9de..8e1212bd1 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1143,6 +1143,12 @@ def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs): def _while_loop_batching_rule(axis_size, axis_name, main_type, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): + from jax._src.callback import _IOEffect, _OrderedIOEffect + if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect] + for branch in [body_jaxpr, cond_jaxpr]): + raise NotImplementedError( + "IO effect not supported in vmap-of-while.") + orig_batched = [d is not batching.not_mapped for d in dims] cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts]) cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts]) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 29e37f079..da89aaf3a 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -22,3 +22,6 @@ from jax.experimental.x64_context import ( enable_x64 as enable_x64, disable_x64 as disable_x64, ) +from jax._src.callback import ( + io_callback as io_callback +) diff --git a/tests/BUILD b/tests/BUILD index 214d6753f..6182cfb8c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -972,6 +972,9 @@ jax_test( jax_test( name = "python_callback_test", srcs = ["python_callback_test.py"], + deps = [ + "//jax:experimental", + ], ) jax_test( diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index f2c609b05..81c14855b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -24,17 +24,18 @@ from jax import lax from jax import tree_util from jax._src import debugging from jax._src import dispatch +from jax._src import sharding from jax._src import test_util as jtu from jax._src import util from jax._src.lib import xla_bridge from jax._src.lib import xla_client -from jax.experimental.maps import Mesh -from jax.experimental.pjit import PartitionSpec as P -from jax.experimental.pjit import pjit -from jax.experimental.maps import xmap from jax.config import config from jax.experimental import maps +from jax.experimental import pjit from jax.interpreters import mlir +from jax.experimental.maps import Mesh +from jax.experimental.maps import xmap +from jax.experimental import io_callback import jax.numpy as jnp import numpy as np @@ -682,7 +683,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase): try: mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) - spec = P('x') + spec = pjit.PartitionSpec('x') def f(x): axis_resources = {v: v for v in mesh.axis_names} @@ -699,7 +700,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase): with mesh: inp = jnp.arange(float(jax.local_device_count())) - out = pjit(f, in_axis_resources=spec, out_axis_resources=spec)(inp) + out = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)(inp) np.testing.assert_allclose( out, np.sin(np.arange(jax.local_device_count())) 
) @@ -708,7 +709,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase): with self.assertRaisesRegex( NotImplementedError, 'when all mesh axes are partitioned manually' ): - pjit( + pjit.pjit( without_xmap_f, in_axis_resources=spec, out_axis_resources=spec )(inp) @@ -879,5 +880,157 @@ class PurePythonCallbackTest(jtu.JaxTestCase): x = np.arange(6, dtype=np.int32).reshape((3, 2)) np.testing.assert_allclose(g(x), x) +class IOPythonCallbackTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if xla_bridge.get_backend().runtime_type == 'stream_executor': + raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') + + def tearDown(self): + super().tearDown() + dispatch.runtime_tokens.clear() + + def test_io_callback_can_mutate_state(self): + x = 0 + def cb(): + nonlocal x + x += 1 + return np.array(x, np.int32) + + def f(): + return io_callback(cb, jax.ShapeDtypeStruct((), jnp.int32)) + f() + jax.effects_barrier() + self.assertEqual(x, 1) + f() + jax.effects_barrier() + self.assertEqual(x, 2) + + def test_io_callback_can_be_batched_if_unordered(self): + _mut = 0 + def cb(x): + nonlocal _mut + _mut += 1 + return x + + x = jnp.arange(4) + def f(x): + return io_callback(cb, jax.ShapeDtypeStruct((), x.dtype), x) + jax.vmap(f)(x) + jax.effects_barrier() + self.assertEqual(_mut, 4) + jax.vmap(f)(x) + jax.effects_barrier() + self.assertEqual(_mut, 8) + + def test_cannot_call_ordered_io_in_pmap(self): + def f(x): + return io_callback( + lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True) + with self.assertRaisesRegex( + ValueError, "Ordered effects not supported in `pmap`"): + jax.pmap(f)(jnp.arange(jax.local_device_count())) + + def test_cannot_call_ordered_io_in_xmap(self): + def f(x): + return io_callback( + lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True) + with self.assertRaisesRegex( + ValueError, "Cannot `vmap` ordered IO callback"): + maps.xmap(f, in_axes=([0],), out_axes=[0])(jnp.arange(16)) + + def test_cannot_call_ordered_io_in_vmap(self): + def f(x): + return io_callback( + lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True) + with self.assertRaisesRegex( + ValueError, "Cannot `vmap` ordered IO callback"): + jax.vmap(f)(jnp.arange(4)) + + def test_cannot_use_io_callback_in_jvp(self): + def f(x): + return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x) + with self.assertRaisesRegex( + ValueError, "IO callbacks do not support JVP."): + jax.jvp(f, (0.,), (1.,)) + + def test_cannot_use_io_callback_in_linearize(self): + def f(x): + return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x) + with self.assertRaisesRegex( + ValueError, "IO callbacks do not support JVP."): + jax.linearize(f, 0.) + + def test_cannot_use_io_callback_in_transpose(self): + x = jnp.array(1.) 
+ + def f(x): + return io_callback(lambda x: x, jax.ShapeDtypeStruct((), x.dtype), x) + with self.assertRaisesRegex( + ValueError, "IO callbacks do not support transpose."): + jax.linear_transpose(f, x)(x) + + def test_cannot_vmap_of_cond_io_callback(self): + def f(pred): + def true_fun(): + io_callback(lambda: print("true"), None) + def false_fun(): + io_callback(lambda: print("false"), None) + return lax.cond(pred, false_fun, true_fun) + with self.assertRaisesRegex(NotImplementedError, + "IO effect not supported in vmap-of-cond."): + jax.vmap(f)(jnp.array([True, True])) + + def test_cannot_vmap_of_while_io_callback(self): + def check(x): + assert np.all(x < 5) + + def f(i): + def cond(i): + return i < 5 + def body(i): + io_callback(check, None, i) + return i + 1 + return lax.while_loop(cond, body, i) + with self.assertRaisesRegex(NotImplementedError, + "IO effect not supported in vmap-of-while."): + jax.vmap(f)(jnp.array([0, 4])) + + def test_cannot_use_io_callback_in_checkpoint(self): + @jax.grad + @jax.checkpoint + def f(x, y): + io_callback(lambda x: x, y, y) + return x + + with self.assertRaisesRegex(NotImplementedError, + "Effects not supported in partial-eval of `checkpoint`"): + f(2., 3.) + + def test_can_use_io_callback_in_pjit(self): + + _mut = 0 + def _cb(x): + nonlocal _mut + _mut = x.sum() + + def f(x): + io_callback(_cb, None, x) + return x + + mesh = maps.Mesh(np.array(jax.devices()), ['dev']) + if config.jax_array: + spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev')) + out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec()) + else: + spec = pjit.PartitionSpec('dev') + out_spec = pjit.PartitionSpec() + f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec) + with mesh: + f(jnp.arange(mesh.size)) + jax.effects_barrier() + self.assertEqual(_mut, jnp.arange(mesh.size).sum()) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())
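
Usage note (not part of the patch): a minimal sketch of the `io_callback` API introduced above, modeled on the new tests (`test_io_callback_can_mutate_state`, `test_can_use_io_callback_in_pjit`). The names `host_log`, `f`, and the `log` list are illustrative only; `jax.effects_barrier()` is used exactly as in the tests to wait for the host-side effect.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

log = []

def host_log(x):
  # Runs on the host and is free to perform side effects.
  log.append(np.asarray(x))

@jax.jit
def f(x):
  y = x * 2.0
  # No results are returned to the device, so result_shape_dtypes is None;
  # ordered=True asks the runtime to preserve the order of callback executions.
  io_callback(host_log, None, y, ordered=True)
  return y + 1.0

f(jnp.arange(4.0))
jax.effects_barrier()  # block until the host callback has run
print(log)             # e.g. [array([0., 2., 4., 6.], dtype=float32)]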
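
The api.py change also makes the `vectorized` flag explicit in `jax.pure_callback`'s signature. A minimal sketch of the batching behaviour described in the docstring (`host_sin` and `f` are illustrative names, not part of the patch): with `vectorized=True`, `jax.vmap` calls the host function once on the whole batch instead of once per element.

import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # With vectorized=True, vmap passes the whole batch here in one call,
  # so this must accept and return arrays with the leading batch axis.
  return np.sin(x)

def f(x):
  return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
                           vectorized=True)

print(jax.vmap(f)(jnp.arange(4.0)))  # host_sin is invoked once with shape (4,)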