mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26157 from jakevdp:callbacks-vectorized
PiperOrigin-RevId: 720701900
This commit is contained in:
commit
fa9bb231f1
@ -102,28 +102,31 @@ def f_host(x):
|
||||
|
||||
def f(x):
|
||||
result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
|
||||
return jax.pure_callback(f_host, result_shape, x)
|
||||
return jax.pure_callback(f_host, result_shape, x, vmap_method='sequential')
|
||||
|
||||
x = jnp.arange(5.0)
|
||||
f(x)
|
||||
```
|
||||
|
||||
Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:"
|
||||
Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` as well as higher-order primitives like `scan` and `while_loop`:"
|
||||
|
||||
```{code-cell}
|
||||
jax.jit(f)(x)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
jax.vmap(f)(x)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
def body_fun(_, x):
|
||||
return _, f(x)
|
||||
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
|
||||
```
|
||||
|
||||
Because we specified a `vmap_method` in the `pure_callback` function call, it will also
|
||||
be compatible with `vmap`:
|
||||
|
||||
```{code-cell}
|
||||
jax.vmap(f)(x)
|
||||
```
|
||||
|
||||
However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:
|
||||
|
||||
```{code-cell}
|
||||
@ -307,8 +310,8 @@ def jv(v, z):
|
||||
shape=jnp.broadcast_shapes(v.shape, z.shape),
|
||||
dtype=z.dtype)
|
||||
|
||||
# You use vectorize=True because scipy.special.jv handles broadcasted inputs.
|
||||
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
|
||||
# Use vmap_method="broadcast_all" because scipy.special.jv handles broadcasted inputs.
|
||||
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vmap_method="broadcast_all")
|
||||
```
|
||||
|
||||
This lets us call into {func}`scipy.special.jv` from transformed JAX code, including when transformed by {func}`~jax.jit` and {func}`~jax.vmap`:
|
||||
|
Loading…
x
Reference in New Issue
Block a user