diff --git a/docs/external-callbacks.md b/docs/external-callbacks.md index 7089e17a1..85ec64670 100644 --- a/docs/external-callbacks.md +++ b/docs/external-callbacks.md @@ -74,7 +74,7 @@ In earlier versions of JAX, there was only one kind of callback available, imple (The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`). -From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. All three of them must **not** include any calls back into JAX. +From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. |callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution | |-------------------------------------|----|----|----|----|----|----| diff --git a/jax/_src/callback.py b/jax/_src/callback.py index d5b8f145b..3aadb478e 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -349,8 +349,7 @@ def pure_callback( ``pure_callback`` enables calling a Python function in JIT-ed JAX functions. The input ``callback`` will be passed JAX arrays placed on a local CPU, and - it should also return JAX arrays on CPU. The ``callback`` function must not - include any calls back into JAX. + it should also return JAX arrays on CPU. The callback is treated as functionally pure, meaning it has no side-effects and its output value depends only on its argument values. As a consequence, it @@ -389,9 +388,8 @@ def pure_callback( Args: callback: function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it - may behave in unexpected ways, particularly under transformation. - Furthermore, the callback must not call into JAX. The callable will - be passed PyTrees of arrays as arguments, and should return a PyTree of + may behave in unexpected ways, particularly under transformation. The callable + will be passed PyTrees of arrays as arguments, and should return a PyTree of arrays that matches ``result_shape_dtypes``. result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes, whose structure matches the expected output of the callback function at runtime. @@ -630,15 +628,14 @@ def io_callback( ordered: bool = False, **kwargs: Any, ): - """Calls an impure Python callback. The callback function must not include any - calls back into JAX. + """Calls an impure Python callback. For more explanation, see `External Callbacks`_. Args: callback: function to execute on the host. It is assumed to be an impure function. If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to - more efficient execution. The ``callback`` must not call into JAX. + more efficient execution. result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes, whose structure matches the expected output of the callback function at runtime. :class:`jax.ShapeDtypeStruct` is often used to define leaf values. diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index fab9ab296..7685ac2bf 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -249,11 +249,8 @@ def debug_callback(callback: Callable[..., None], *args: Any, possible while revealing as much about them as possible, such as which parts of the computation are duplicated or dropped. - Inside of the ``callback`` function there should not be a call back into JAX. - Args: - callback: A Python callable returning None. The ``callback`` must not call - into JAX. + callback: A Python callable returning None. *args: The positional arguments to the callback. ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this callback w.r.t.