Merge pull request #25594 from zhenying-liu:activation-offloading-doc

PiperOrigin-RevId: 713170813
This commit is contained in:
jax authors 2025-01-07 23:26:21 -08:00
commit 81db3219b7

View File

@ -354,6 +354,61 @@ print_saved_residuals(loss_checkpoint2, params, x, y)
Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.
#### Custom policies for offload
You may consider offloading to CPU memory instead of recomputing when checkpointing to save accelerator memory. `jax.checkpoint_policies.offload_dot_with_no_batch_dims` can offload the results of matrix multiplications with no batch dimensions to the CPU.
```{code-cell}
from jax.ad_checkpoint import checkpoint
def checkpoint_offload_dot_with_no_batch_dims(self):
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
"device", "pinned_host")
@functools.partial(checkpoint, policy=policy)
def f(x):
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.sum(x)
return x
```
One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.
```{code-cell}
from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu
def checkpoint_names_saved_offloaded_recomputed(self):
mesh = jtu.create_mesh((2,), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
s = NamedSharding(mesh, P("x"))
inp = jax.device_put(np_inp, s)
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
offload_src='device', offload_dst='pinned_host')
@functools.partial(checkpoint, policy=policy)
def f(x):
def g(ys, _):
y, _ = ys
y = checkpoint_name(jnp.sin(y), "y")
z = checkpoint_name(jnp.sin(y), "z")
z = z.T
w = checkpoint_name(jnp.sin(z), "w")
return (w.T, jnp.sum(w)), None
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
return scan_out
```
The code defines a function `f` that which applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.
#### List of policies
The policies are: