mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 06:36:07 +00:00
Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions.
This commit is contained in:
parent
8720e95570
commit
e2eff1f8d5
@ -74,7 +74,7 @@ In earlier versions of JAX, there was only one kind of callback available, imple
|
|||||||
|
|
||||||
(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).
|
(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).
|
||||||
|
|
||||||
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. All three of them must **not** include any calls back into JAX.
|
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
|
||||||
|
|
||||||
|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|
|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|
||||||
|-------------------------------------|----|----|----|----|----|----|
|
|-------------------------------------|----|----|----|----|----|----|
|
||||||
|
@ -349,8 +349,7 @@ def pure_callback(
|
|||||||
|
|
||||||
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
|
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
|
||||||
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
|
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
|
||||||
it should also return JAX arrays on CPU. The ``callback`` function must not
|
it should also return JAX arrays on CPU.
|
||||||
include any calls back into JAX.
|
|
||||||
|
|
||||||
The callback is treated as functionally pure, meaning it has no side-effects
|
The callback is treated as functionally pure, meaning it has no side-effects
|
||||||
and its output value depends only on its argument values. As a consequence, it
|
and its output value depends only on its argument values. As a consequence, it
|
||||||
@ -389,9 +388,8 @@ def pure_callback(
|
|||||||
Args:
|
Args:
|
||||||
callback: function to execute on the host. The callback is assumed to be a pure
|
callback: function to execute on the host. The callback is assumed to be a pure
|
||||||
function (i.e. one without side-effects): if an impure function is passed, it
|
function (i.e. one without side-effects): if an impure function is passed, it
|
||||||
may behave in unexpected ways, particularly under transformation.
|
may behave in unexpected ways, particularly under transformation. The callable
|
||||||
Furthermore, the callback must not call into JAX. The callable will
|
will be passed PyTrees of arrays as arguments, and should return a PyTree of
|
||||||
be passed PyTrees of arrays as arguments, and should return a PyTree of
|
|
||||||
arrays that matches ``result_shape_dtypes``.
|
arrays that matches ``result_shape_dtypes``.
|
||||||
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
|
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
|
||||||
whose structure matches the expected output of the callback function at runtime.
|
whose structure matches the expected output of the callback function at runtime.
|
||||||
@ -630,15 +628,14 @@ def io_callback(
|
|||||||
ordered: bool = False,
|
ordered: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Calls an impure Python callback. The callback function must not include any
|
"""Calls an impure Python callback.
|
||||||
calls back into JAX.
|
|
||||||
|
|
||||||
For more explanation, see `External Callbacks`_.
|
For more explanation, see `External Callbacks`_.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback: function to execute on the host. It is assumed to be an impure function.
|
callback: function to execute on the host. It is assumed to be an impure function.
|
||||||
If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to
|
If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to
|
||||||
more efficient execution. The ``callback`` must not call into JAX.
|
more efficient execution.
|
||||||
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
|
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
|
||||||
whose structure matches the expected output of the callback function at runtime.
|
whose structure matches the expected output of the callback function at runtime.
|
||||||
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.
|
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.
|
||||||
|
@ -249,11 +249,8 @@ def debug_callback(callback: Callable[..., None], *args: Any,
|
|||||||
possible while revealing as much about them as possible, such as which parts
|
possible while revealing as much about them as possible, such as which parts
|
||||||
of the computation are duplicated or dropped.
|
of the computation are duplicated or dropped.
|
||||||
|
|
||||||
Inside of the ``callback`` function there should not be a call back into JAX.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback: A Python callable returning None. The ``callback`` must not call
|
callback: A Python callable returning None.
|
||||||
into JAX.
|
|
||||||
*args: The positional arguments to the callback.
|
*args: The positional arguments to the callback.
|
||||||
ordered: A keyword only argument used to indicate whether or not the
|
ordered: A keyword only argument used to indicate whether or not the
|
||||||
staged out computation will enforce ordering of this callback w.r.t.
|
staged out computation will enforce ordering of this callback w.r.t.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user