Merge pull request #26157 from jakevdp:callbacks-vectorized

PiperOrigin-RevId: 720701900
This commit is contained in:
jax authors 2025-01-28 14:06:44 -08:00
commit fa9bb231f1

View File

@ -102,28 +102,31 @@ def f_host(x):
def f(x):
result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jax.pure_callback(f_host, result_shape, x)
return jax.pure_callback(f_host, result_shape, x, vmap_method='sequential')
x = jnp.arange(5.0)
f(x)
```
Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:"
Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` as well as higher-order primitives like `scan` and `while_loop`:"
```{code-cell}
jax.jit(f)(x)
```
```{code-cell}
jax.vmap(f)(x)
```
```{code-cell}
def body_fun(_, x):
return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
```
Because we specified a `vmap_method` in the `pure_callback` function call, it will also
be compatible with `vmap`:
```{code-cell}
jax.vmap(f)(x)
```
However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:
```{code-cell}
@ -307,8 +310,8 @@ def jv(v, z):
shape=jnp.broadcast_shapes(v.shape, z.shape),
dtype=z.dtype)
# You use vectorize=True because scipy.special.jv handles broadcasted inputs.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
# Use vmap_method="broadcast_all" because scipy.special.jv handles broadcasted inputs.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vmap_method="broadcast_all")
```
This lets us call into {func}`scipy.special.jv` from transformed JAX code, including when transformed by {func}`~jax.jit` and {func}`~jax.vmap`: