mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #26154 from jakevdp:pure-callback-doc
PiperOrigin-RevId: 720763192
This commit is contained in:
commit
bf22b53cf4
@ -24,7 +24,7 @@ kernelspec:
|
||||
|
||||
<!--* freshness: { reviewed: '2024-05-16' } *-->
|
||||
|
||||
This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`.
|
||||
This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are `jax.pure_callback`, `jax.experimental.io_callback` and `jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`.
|
||||
|
||||
## Why callbacks?
|
||||
|
||||
@ -66,8 +66,11 @@ This works by passing the runtime value of `y` as a CPU {class}`jax.Array` back
|
||||
In earlier versions of JAX, there was only one kind of callback available, implemented in {func}`jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:
|
||||
|
||||
- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effects.
|
||||
See {ref}`external-callbacks-exploring-pure-callback`.
|
||||
- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.
|
||||
See {ref}`external-callbacks-exploring-io-callback`.
|
||||
- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.
|
||||
See {ref}`external-callbacks-exploring-debug-callback`.
|
||||
|
||||
(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).
|
||||
|
||||
@ -85,6 +88,7 @@ From the user perspective, these three flavors of callback are mainly distinguis
|
||||
|
||||
³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases.
|
||||
|
||||
(external-callbacks-exploring-pure-callback)=
|
||||
### Exploring `pure_callback`
|
||||
|
||||
{func}`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).
|
||||
@ -163,6 +167,41 @@ f2();
|
||||
In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.
|
||||
In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.
|
||||
|
||||
#### `pure_callback` and exceptions
|
||||
|
||||
In the context of JAX transformations, Python runtime exceptions should be considered side-effects:
|
||||
this means that intentionally raising an error within a `pure_callback` breaks the API contract,
|
||||
and the behavior of the resulting program is undefined. In particular, the manner in which
|
||||
such a program halts will generally depend on the backend, and the details of that behavior may
|
||||
change in future releases.
|
||||
|
||||
Additionally, passing impure functions to `pure_callback` may result in unexpected behavior during
|
||||
transformations like {func}`jax.jit` or {func}`jax.vmap`, because the transformation rules for
|
||||
`pure_callback` are defined under the assumption that the callback function is pure. Here's one
|
||||
simple example of an impure callback behaving unexpectedly under `vmap`:
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
def raise_via_callback(x):
|
||||
def _raise(x):
|
||||
raise ValueError(f"value of x is {x}")
|
||||
return jax.pure_callback(_raise, x, x)
|
||||
|
||||
def raise_if_negative(x):
|
||||
return jax.lax.cond(x < 0, raise_via_callback, lambda x: x, x)
|
||||
|
||||
x_batch = jnp.arange(4)
|
||||
|
||||
[raise_if_negative(x) for x in x_batch] # does not raise
|
||||
|
||||
jax.vmap(raise_if_negative)(x_batch) # ValueError: value of x is 0
|
||||
```
|
||||
To avoid this and similar unexpected behavior, we recommend not attempting to use
|
||||
`pure_callback` to raise runtime errors.
|
||||
|
||||
|
||||
(external-callbacks-exploring-io-callback)=
|
||||
### Exploring `io_callback`
|
||||
|
||||
In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.
|
||||
|
@ -359,6 +359,13 @@ def pure_callback(
|
||||
`jit`-decorated function has no data dependence on its value. Pure callbacks
|
||||
may also be reordered if data-dependence allows.
|
||||
|
||||
.. warning::
|
||||
|
||||
In the context of JAX transformations, Python exceptions should be
|
||||
considered side-effects: this means that intentionally raising an error
|
||||
within a `pure_callback` breaks the API contract, and the behavior of
|
||||
the resulting program is undefined.
|
||||
|
||||
When `vmap`-ed the behavior will depend on the value of the ``vmap_method``.
|
||||
|
||||
* Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method``
|
||||
@ -440,7 +447,7 @@ def pure_callback(
|
||||
(4,) (4,)
|
||||
Array([1., 2., 3., 4.], dtype=float32)
|
||||
|
||||
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
|
||||
.. _External Callbacks: https://jax.readthedocs.io/en/latest/external-callbacks.html
|
||||
"""
|
||||
if not isinstance(vectorized, DeprecatedArg) and not vectorized is None:
|
||||
deprecations.warn(
|
||||
|
Loading…
x
Reference in New Issue
Block a user