mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions.
This commit is contained in:
parent
8720e95570
commit
e2eff1f8d5
@ -74,7 +74,7 @@ In earlier versions of JAX, there was only one kind of callback available, imple
|
||||
|
||||
(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).
|
||||
|
||||
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. All three of them must **not** include any calls back into JAX.
|
||||
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
|
||||
|
||||
|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|
||||
|-------------------------------------|----|----|----|----|----|----|
|
||||
|
@ -349,8 +349,7 @@ def pure_callback(
|
||||
|
||||
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
|
||||
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
|
||||
it should also return JAX arrays on CPU. The ``callback`` function must not
|
||||
include any calls back into JAX.
|
||||
it should also return JAX arrays on CPU.
|
||||
|
||||
The callback is treated as functionally pure, meaning it has no side-effects
|
||||
and its output value depends only on its argument values. As a consequence, it
|
||||
@ -389,9 +388,8 @@ def pure_callback(
|
||||
Args:
|
||||
callback: function to execute on the host. The callback is assumed to be a pure
|
||||
function (i.e. one without side-effects): if an impure function is passed, it
|
||||
may behave in unexpected ways, particularly under transformation.
|
||||
Furthermore, the callback must not call into JAX. The callable will
|
||||
be passed PyTrees of arrays as arguments, and should return a PyTree of
|
||||
may behave in unexpected ways, particularly under transformation. The callable
|
||||
will be passed PyTrees of arrays as arguments, and should return a PyTree of
|
||||
arrays that matches ``result_shape_dtypes``.
|
||||
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
|
||||
whose structure matches the expected output of the callback function at runtime.
|
||||
@ -630,15 +628,14 @@ def io_callback(
|
||||
ordered: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Calls an impure Python callback. The callback function must not include any
|
||||
calls back into JAX.
|
||||
"""Calls an impure Python callback.
|
||||
|
||||
For more explanation, see `External Callbacks`_.
|
||||
|
||||
Args:
|
||||
callback: function to execute on the host. It is assumed to be an impure function.
|
||||
If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to
|
||||
more efficient execution. The ``callback`` must not call into JAX.
|
||||
more efficient execution.
|
||||
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
|
||||
whose structure matches the expected output of the callback function at runtime.
|
||||
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.
|
||||
|
@ -249,11 +249,8 @@ def debug_callback(callback: Callable[..., None], *args: Any,
|
||||
possible while revealing as much about them as possible, such as which parts
|
||||
of the computation are duplicated or dropped.
|
||||
|
||||
Inside of the ``callback`` function there should not be a call back into JAX.
|
||||
|
||||
Args:
|
||||
callback: A Python callable returning None. The ``callback`` must not call
|
||||
into JAX.
|
||||
callback: A Python callable returning None.
|
||||
*args: The positional arguments to the callback.
|
||||
ordered: A keyword only argument used to indicate whether or not the
|
||||
staged out computation will enforce ordering of this callback w.r.t.
|
||||
|
Loading…
x
Reference in New Issue
Block a user