diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4151b23ef..1c341d875 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,7 +15,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
     the `concrete` option, following the previous version's deprecation; see
     [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
 * Changes
-  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from functions compiled with `jax.jit` or `jax.pmap`.
+  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
 
 ## jax 0.3.16
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 9b391ce62..60b58d5a0 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -3233,37 +3233,45 @@ def block_until_ready(x):
 
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
                   *args: Any, **kwargs: Any):
-  """Calls a pure Python callback function from staged out JAX programs.
+  """Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.
 
   ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
   The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
-  should also return NumPy arrays. Execution takes place on the CPU host.
+  should also return NumPy arrays. Execution takes place on CPU, like any
+  Python+NumPy function.
 
-  The callback is treated as "pure" meaning it can be called multiple times when
-  transformed (for example in a ``vmap`` or ``pmap``), and it can also
-  potentially be removed from JAX programs via dead-code elimination. Pure
-  callbacks can also be reordered if data-dependence allows.
+  The callback is treated as functionally pure, meaning it has no side-effects
+  and its output value depends only on its argument values. As a consequence, it
+  is safe to be called multiple times (e.g. when transformed by ``vmap`` or
+  ``pmap``), or not to be called at all when e.g. the output of a
+  `jit`-decorated function has no data dependence on its value. Pure callbacks
+  may also be reordered if data-dependence allows.
 
-  When ``pmap``-ed, the pure callback will be called several times (one on each axis
-  of the map). When `vmap`-ed the behavior will depend on the value of the
-  ``rank_polymorphic`` keyword argument. If the callback is indicated as rank
-  polymorphic, the callback will be called directly on batched inputs (where the
-  batch axis is the leading dimension). Additionally, the callbacks should
-  return outputs that also have a leading batch axis. If not rank polymorphic,
-  ``callback`` will be mapped sequentially across the batched axis.
+  When ``pmap``-ed, the pure callback will be called several times (one on each
+  axis of the map). When `vmap`-ed the behavior will depend on the value of the
+  ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
+  is assumed to obey
+  ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
+  Therefore, the callback will be called directly on batched inputs (where the
+  batch axes are the leading dimensions). Additionally, the callbacks should
+  return outputs that have corresponding leading batch axes. If not vectorized,
+  ``callback`` will be mapped sequentially across the batched axis.
+  For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
+  to set ``vectorized=True`` because the ``np.matmul`` function handles
+  arbitrary leading batch dimensions.
 
   Args:
-    callback: A Python callable. The callable will be passed in NumPy arrays and
-      should return a PyTree of NumPy arrays that matches
-      ``result_shape_dtypes``.
-    result_shape_dtypes: A PyTree of Python objects that have ``shape`` and
-      ``dtype`` properties that correspond to the shape and dtypes of the
-      outputs of ``callback``.
+    callback: A Python callable. The callable will be passed PyTrees of NumPy
+      arrays as arguments, and should return a PyTree of NumPy arrays that
+      matches ``result_shape_dtypes``.
+    result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
+      and ``dtype`` attributes which correspond to the shapes and dtypes of the
+      value of ``callback`` applied to ``args`` and ``kwargs``.
     *args: The positional arguments to the callback. Must be PyTrees of JAX
       types.
-    rank_polymorphic: A boolean that indicates whether or not ``callback`` is
-      rank polymorphic, meaning it can handle arrays with additional leading
-      dimensions. If ``rank_polymorphic`` is `True`, when the callback is mapped
+    vectorized: A boolean that indicates whether or not ``callback`` is
+      vectorized, meaning it can handle arrays with additional leading
+      dimensions. If ``vectorized`` is `True`, when the callback is mapped
       via `jax.vmap`, it will be called directly on inputs with leading batch
       dimensions instead of executing ``callback`` on each mapped input
       individually. The callback should also return outputs batched across the
diff --git a/jax/_src/callback.py b/jax/_src/callback.py
index 84394a25f..7731eaa3e 100644
--- a/jax/_src/callback.py
+++ b/jax/_src/callback.py
@@ -36,15 +36,15 @@ map, unsafe_map = util.safe_map, map
 
 @pure_callback_p.def_impl
 def pure_callback_impl(*args, result_avals, callback: Callable[..., Any],
-                       rank_polymorphic: bool):
-  del rank_polymorphic, result_avals
+                       vectorized: bool):
+  del vectorized, result_avals
   return callback(*args)
 
 
 @pure_callback_p.def_abstract_eval
 def pure_callback_abstract_eval(*avals, callback: Callable[..., Any],
-                                result_avals, rank_polymorphic: bool):
-  del avals, callback, rank_polymorphic
+                                result_avals, vectorized: bool):
+  del avals, callback, vectorized
   return result_avals
 
 
@@ -67,26 +67,29 @@ def pure_callback_transpose_rule(*args, **kwargs):
 ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule
 
 
-def pure_callback_batching_rule(args, dims, *, callback, rank_polymorphic: bool,
+def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool,
                                 result_avals: Sequence[core.ShapedArray]):
   axis_size = next(a.shape[0] for a, d in zip(args, dims)
                    if d is not batching.not_mapped)
-  new_args = []
-  for arg, dim in zip(args, dims):
-    new_args.append(batching.moveaxis(arg, dim, 0))
-  if rank_polymorphic:
+  new_args = [arg if dim is batching.not_mapped else
+              batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
+  if vectorized:
     result_avals = tuple(
         core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)  # type: ignore
         for aval in result_avals)
     outvals = pure_callback_p.bind(
-        *new_args, callback=callback, rank_polymorphic=rank_polymorphic,
+        *new_args, callback=callback, vectorized=vectorized,
        result_avals=result_avals)
   else:
+    is_batched = [d is not batching.not_mapped for d in dims]
+    unbatched_args, batched_args = util.partition_list(is_batched, new_args)
+    def _batch_fun(*batched_args):
+      merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
+      return pure_callback_p.bind(
+          *merged_args, callback=callback, result_avals=result_avals,
+          vectorized=vectorized)
     from jax._src.lax.control_flow import map as lax_map
-    outvals = lax_map(
-        functools.partial(pure_callback_p.bind, callback=callback,
-                          rank_polymorphic=rank_polymorphic, result_avals=result_avals),
-        *new_args)
+    outvals = lax_map(_batch_fun, *batched_args)
   return tuple(outvals), (0,) * len(outvals)
 
 
@@ -115,7 +118,7 @@ mlir.register_lowering(pure_callback_p, pure_callback_lowering)
 
 
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
-                  *args: Any, rank_polymorphic: bool = False, **kwargs: Any):
+                  *args: Any, vectorized: bool = False, **kwargs: Any):
   def _flat_callback(*flat_args):
     args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
     return tree_util.tree_leaves(callback(*args, **kwargs))
@@ -126,5 +129,5 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
   flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
   out_flat = pure_callback_p.bind(
       *flat_args, callback=_flat_callback,
-      result_avals=tuple(flat_result_avals), rank_polymorphic=rank_polymorphic)
+      result_avals=tuple(flat_result_avals), vectorized=vectorized)
   return tree_util.tree_unflatten(out_tree, out_flat)
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 0bd23170c..4a037001b 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -152,6 +152,7 @@ class RuntimeTokenSet(threading.local):
       token[0].block_until_ready()
     for token in self.output_runtime_tokens.values():
       token.block_until_ready()
+    self.clear()
 
 runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
 
@@ -711,7 +712,7 @@ def _check_special(name, xla_shape, buf):
       raise FloatingPointError(f"invalid value (inf) encountered in {name}")
 
 def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
-                device: Device, input_bufs):
+                has_host_callbacks: bool, device: Device, input_bufs):
   tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
   tokens_flat = flatten(tokens)
   input_bufs = [*tokens_flat, *input_bufs]
@@ -720,7 +721,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
     num_output_tokens = len(ordered_effects) + (not can_execute_with_token and
                                                 has_unordered_effects)
     token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
-    if has_unordered_effects:
+    if has_unordered_effects or has_host_callbacks:
      if can_execute_with_token:
        runtime_tokens.set_output_runtime_token(device, runtime_token)
      else:
@@ -738,14 +739,15 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
                       result_handler: Callable,
                       has_unordered_effects: bool,
                       ordered_effects: List[core.Effect],
-                      kept_var_idx, host_callbacks, *args):
+                      kept_var_idx, has_host_callbacks: bool, *args):
   device, = compiled.local_devices()
   args, env = input_handler(args) if input_handler else (args, None)
   in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
                     if i in kept_var_idx)
-  if has_unordered_effects or ordered_effects:
+  if has_unordered_effects or ordered_effects or has_host_callbacks:
     in_flat, token_handler = _add_tokens(
-        has_unordered_effects, ordered_effects, device, in_flat)
+        has_unordered_effects, ordered_effects, has_host_callbacks, device,
+        in_flat)
     if can_execute_with_token:
       out_flat, runtime_token = compiled.execute_with_token(in_flat)
     else:
@@ -755,7 +757,7 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
     out_flat = compiled.execute(in_flat)
   check_special(name, out_flat)
   out_bufs = unflatten(out_flat, output_buffer_counts)
-  if ordered_effects or has_unordered_effects:
+  if ordered_effects or has_unordered_effects or has_host_callbacks:
     out_bufs = token_handler(out_bufs, runtime_token)
   return result_handler(env, out_bufs)
 
@@ -766,7 +768,7 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
                         result_handler: Callable,
                         has_unordered_effects: bool,
                         ordered_effects: List[core.Effect],
-                        kept_var_idx, host_callbacks, *args):
+                        kept_var_idx, has_host_callbacks: bool, *args):
   if has_unordered_effects or ordered_effects:
     # TODO(sharadmv): support jit-of-pmap with effects
     raise NotImplementedError(
@@ -965,7 +967,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
     execute = _execute_compiled if nreps == 1 else _execute_replicated
     unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,  # type: ignore  # noqa: F811
                           result_handler, has_unordered_effects,
-                          ordered_effects, kept_var_idx, host_callbacks)
+                          ordered_effects, kept_var_idx, bool(host_callbacks))
     return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
                                   keepalive)
 
@@ -988,7 +990,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
   result_handlers = map(partial(aval_to_result_handler, device), out_avals)
   unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
                         result_handlers, has_unordered_effects,
-                        ordered_effects, kept_var_idx, host_callbacks)
+                        ordered_effects, kept_var_idx, bool(host_callbacks))
   return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
                                 keepalive)
 
diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py
index 9e7e0bdca..1856e0459 100644
--- a/tests/python_callback_test.py
+++ b/tests/python_callback_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import functools
 import io
 import textwrap
 import unittest
@@ -25,6 +26,7 @@ from jax import core
 from jax import lax
 from jax import tree_util
 from jax._src import debugging
+from jax._src import dispatch
 from jax._src import lib as jaxlib
 from jax._src import test_util as jtu
 from jax._src import util
@@ -459,6 +461,10 @@ class PythonCallbackTest(jtu.JaxTestCase):
 
 
 class PurePythonCallbackTest(jtu.JaxTestCase):
 
+  def tearDown(self):
+    super().tearDown()
+    dispatch.runtime_tokens.clear()
+
   @jtu.skip_on_devices(*disabled_backends)
   def test_simple_pure_callback(self):
@@ -504,8 +510,8 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
 
     @jax.jit
     def f(x):
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x)
+      # Calling a function with two return values that expects one return value
+      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x, x)
 
     with self.assertRaises(RuntimeError):
       f(2.)
@@ -513,8 +519,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
 
     @jax.jit
     def g():
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda: None, (
+      return jax.pure_callback(lambda: (), (
         core.ShapedArray((1,), np.float32),
         core.ShapedArray((2,), np.float32)))
     with self.assertRaises(RuntimeError):
@@ -564,7 +569,14 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
     np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)
 
-  def test_vmap_rank_polymorphic_callback(self):
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y)
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
+
+  def test_vmap_vectorized_callback(self):
 
     def cb(x):
       self.assertTupleEqual(x.shape, ())
@@ -575,7 +587,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     def f(x):
       return jax.pure_callback(cb, x, x)
 
-    _ = f(jnp.arange(4.))
+    np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
 
     def cb2(x):
       self.assertTupleEqual(x.shape, (4,))
@@ -585,11 +597,19 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def g(x):
-      return jax.pure_callback(cb2, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb2, x, x, vectorized=True)
 
-    _ = g(jnp.arange(4.))
+    np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
 
-  def test_vmap_rank_polymorphic_callback_errors_if_returns_wrong_shape(self):
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
+                               vectorized=True)
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
+
+  def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
 
     def cb(x):
       # Reduces over all dimension when it shouldn't
@@ -598,7 +618,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def f(x):
-      return jax.pure_callback(cb, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb, x, x, vectorized=True)
 
     with self.assertRaises(RuntimeError):
       f(jnp.arange(4.))
@@ -743,13 +763,13 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     np.testing.assert_allclose(out, jnp.arange(1., 41.))
 
   @jtu.skip_on_devices(*disabled_backends)
-  def test_rank_polymorphic_callback_inside_xmap(self):
+  def test_vectorized_callback_inside_xmap(self):
 
     def _callback(x):
       return (x + 1.).astype(x.dtype)
 
     def f(x):
-      return jax.pure_callback(_callback, x, x, rank_polymorphic=True)
+      return jax.pure_callback(_callback, x, x, vectorized=True)
 
     f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                   axis_resources={'a': 'dev'})
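Below is a minimal usage sketch of the `vectorized` keyword introduced by this change. It is not part of the patch; it assumes `jax.ShapeDtypeStruct` as the result-shape specification and reuses the `np.matmul` example from the new docstring.

```python
# Illustrative sketch (not part of the diff): calling jax.pure_callback with
# the renamed `vectorized` keyword. Assumes jax.ShapeDtypeStruct is used for
# result_shape_dtypes.
import jax
import jax.numpy as jnp
import numpy as np

def matmul_host(x, y):
  # Runs on the host with NumPy inputs. np.matmul already handles arbitrary
  # leading batch dimensions, so declaring vectorized=True is safe.
  return np.matmul(x, y)

@jax.jit
def f(x, y):
  out_spec = jax.ShapeDtypeStruct((3, 3), x.dtype)
  return jax.pure_callback(matmul_host, out_spec, x, y, vectorized=True)

x = jnp.ones((3, 3), jnp.float32)
y = jnp.ones((3, 3), jnp.float32)
print(f(x, y).shape)              # (3, 3): callback sees unbatched inputs
xs = jnp.stack([x, x])
ys = jnp.stack([y, y])
print(jax.vmap(f)(xs, ys).shape)  # (2, 3, 3): callback called once on batched inputs
```

With `vectorized=False` (the default), the same `vmap` would instead map the callback sequentially over the leading axis via `lax_map`, as implemented in `pure_callback_batching_rule` above.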