Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions.

This commit is contained in:
Dan Foreman-Mackey 2025-01-29 11:12:32 -05:00
parent 8720e95570
commit e2eff1f8d5
3 changed files with 7 additions and 13 deletions

View File

@ -74,7 +74,7 @@ In earlier versions of JAX, there was only one kind of callback available, imple
(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. All three of them must **not** include any calls back into JAX.
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|-------------------------------------|----|----|----|----|----|----|

View File

@ -349,8 +349,7 @@ def pure_callback(
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
it should also return JAX arrays on CPU. The ``callback`` function must not
include any calls back into JAX.
it should also return JAX arrays on CPU.
The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
@ -389,9 +388,8 @@ def pure_callback(
Args:
callback: function to execute on the host. The callback is assumed to be a pure
function (i.e. one without side-effects): if an impure function is passed, it
may behave in unexpected ways, particularly under transformation.
Furthermore, the callback must not call into JAX. The callable will
be passed PyTrees of arrays as arguments, and should return a PyTree of
may behave in unexpected ways, particularly under transformation. The callable
will be passed PyTrees of arrays as arguments, and should return a PyTree of
arrays that matches ``result_shape_dtypes``.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
@ -630,15 +628,14 @@ def io_callback(
ordered: bool = False,
**kwargs: Any,
):
"""Calls an impure Python callback. The callback function must not include any
calls back into JAX.
"""Calls an impure Python callback.
For more explanation, see `External Callbacks`_.
Args:
callback: function to execute on the host. It is assumed to be an impure function.
If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to
more efficient execution. The ``callback`` must not call into JAX.
more efficient execution.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.

View File

@ -249,11 +249,8 @@ def debug_callback(callback: Callable[..., None], *args: Any,
possible while revealing as much about them as possible, such as which parts
of the computation are duplicated or dropped.
Inside of the ``callback`` function there should not be a call back into JAX.
Args:
callback: A Python callable returning None. The ``callback`` must not call
into JAX.
callback: A Python callable returning None.
*args: The positional arguments to the callback.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this callback w.r.t.