consolidate the code example

This commit is contained in:
Jane Liu 2024-12-19 11:35:24 -08:00
parent f124232f53
commit 522a8fd792

View File

@ -360,6 +360,8 @@ You may consider offloading to CPU memory instead of recomputing when checkpoint
```{code-cell} ```{code-cell}
from jax.ad_checkpoint import checkpoint from jax.ad_checkpoint import checkpoint
def checkpoint_offload_dot_with_no_batch_dims(self):
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
"device", "pinned_host") "device", "pinned_host")
@ -379,8 +381,14 @@ One of JAX's checkpoint policies allows specified checkpoint names to be offload
```{code-cell} ```{code-cell}
from jax.ad_checkpoint import checkpoint, checkpoint_name from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu
def g(self): def checkpoint_names_saved_offloaded_recomputed(self):
mesh = jtu.create_mesh((2,), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
s = NamedSharding(mesh, P("x"))
inp = jax.device_put(np_inp, s)
policy = jax.checkpoint_policies.save_and_offload_only_these_names( policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"], names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
@ -388,12 +396,19 @@ def g(self):
@functools.partial(checkpoint, policy=policy) @functools.partial(checkpoint, policy=policy)
def f(x): def f(x):
def g(ys, _):
y, _ = ys
y = checkpoint_name(jnp.sin(y), "y") y = checkpoint_name(jnp.sin(y), "y")
z = checkpoint_name(jnp.sin(y), "z") z = checkpoint_name(jnp.sin(y), "z")
z = z.T
w = checkpoint_name(jnp.sin(z), "w") w = checkpoint_name(jnp.sin(z), "w")
return jnp.sum(w) return (w.T, jnp.sum(w)), None
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
return scan_out
``` ```
The code defines a function `f` that which applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.
#### List of policies #### List of policies
The policies are: The policies are: