diff --git a/docs/external-callbacks.md b/docs/external-callbacks.md index 3c17e7547..7089e17a1 100644 --- a/docs/external-callbacks.md +++ b/docs/external-callbacks.md @@ -24,7 +24,7 @@ kernelspec: -This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`. +This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are `jax.pure_callback`, `jax.experimental.io_callback` and `jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`. ## Why callbacks? @@ -66,8 +66,11 @@ This works by passing the runtime value of `y` as a CPU {class}`jax.Array` back In earlier versions of JAX, there was only one kind of callback available, implemented in {func}`jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations: - {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effects. + See {ref}`external-callbacks-exploring-pure-callback`. - {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk. + See {ref}`external-callbacks-exploring-io-callback`. - {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler. + See {ref}`external-callbacks-exploring-debug-callback`. (The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`). @@ -85,6 +88,7 @@ From the user perspective, these three flavors of callback are mainly distinguis ³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases. +(external-callbacks-exploring-pure-callback)= ### Exploring `pure_callback` {func}`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.). @@ -163,6 +167,41 @@ f2(); In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output. In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects. +#### `pure_callback` and exceptions + +In the context of JAX transformations, Python runtime exceptions should be considered side-effects: +this means that intentionally raising an error within a `pure_callback` breaks the API contract, +and the behavior of the resulting program is undefined. In particular, the manner in which +such a program halts will generally depend on the backend, and the details of that behavior may +change in future releases. + +Additionally, passing impure functions to `pure_callback` may result in unexpected behavior during +transformations like {func}`jax.jit` or {func}`jax.vmap`, because the transformation rules for +`pure_callback` are defined under the assumption that the callback function is pure. Here's one +simple example of an impure callback behaving unexpectedly under `vmap`: +```python +import jax +import jax.numpy as jnp + +def raise_via_callback(x): + def _raise(x): + raise ValueError(f"value of x is {x}") + return jax.pure_callback(_raise, x, x) + +def raise_if_negative(x): + return jax.lax.cond(x < 0, raise_via_callback, lambda x: x, x) + +x_batch = jnp.arange(4) + +[raise_if_negative(x) for x in x_batch] # does not raise + +jax.vmap(raise_if_negative)(x_batch) # ValueError: value of x is 0 +``` +To avoid this and similar unexpected behavior, we recommend not attempting to use +`pure_callback` to raise runtime errors. + + +(external-callbacks-exploring-io-callback)= ### Exploring `io_callback` In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects. diff --git a/jax/_src/callback.py b/jax/_src/callback.py index a91cc24f9..d5b8f145b 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -359,6 +359,13 @@ def pure_callback( `jit`-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data-dependence allows. + .. warning:: + + In the context of JAX transformations, Python exceptions should be + considered side-effects: this means that intentionally raising an error + within a `pure_callback` breaks the API contract, and the behavior of + the resulting program is undefined. + When `vmap`-ed the behavior will depend on the value of the ``vmap_method``. * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` @@ -440,7 +447,7 @@ def pure_callback( (4,) (4,) Array([1., 2., 3., 4.], dtype=float32) - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://jax.readthedocs.io/en/latest/external-callbacks.html """ if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: deprecations.warn(