mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
consolidate the code example
This commit is contained in:
parent
f124232f53
commit
522a8fd792
@ -360,40 +360,55 @@ You may consider offloading to CPU memory instead of recomputing when checkpoint
|
||||
|
||||
```{code-cell}
|
||||
from jax.ad_checkpoint import checkpoint
|
||||
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
|
||||
"device", "pinned_host")
|
||||
|
||||
@functools.partial(checkpoint, policy=policy)
|
||||
def f(x):
|
||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.sum(x)
|
||||
return x
|
||||
def checkpoint_offload_dot_with_no_batch_dims(self):
|
||||
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
|
||||
"device", "pinned_host")
|
||||
|
||||
@functools.partial(checkpoint, policy=policy)
|
||||
def f(x):
|
||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.sum(x)
|
||||
return x
|
||||
```
|
||||
|
||||
One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.
|
||||
|
||||
```{code-cell}
|
||||
from jax.ad_checkpoint import checkpoint, checkpoint_name
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
def g(self):
|
||||
def checkpoint_names_saved_offloaded_recomputed(self):
|
||||
mesh = jtu.create_mesh((2,), ("x",))
|
||||
shape = (256, 128)
|
||||
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
s = NamedSharding(mesh, P("x"))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
|
||||
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
|
||||
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
|
||||
offload_src='device', offload_dst='pinned_host')
|
||||
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
|
||||
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
|
||||
offload_src='device', offload_dst='pinned_host')
|
||||
|
||||
@functools.partial(checkpoint, policy=policy)
|
||||
def f(x):
|
||||
@functools.partial(checkpoint, policy=policy)
|
||||
def f(x):
|
||||
def g(ys, _):
|
||||
y, _ = ys
|
||||
y = checkpoint_name(jnp.sin(y), "y")
|
||||
z = checkpoint_name(jnp.sin(y), "z")
|
||||
z = z.T
|
||||
w = checkpoint_name(jnp.sin(z), "w")
|
||||
return jnp.sum(w)
|
||||
return (w.T, jnp.sum(w)), None
|
||||
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
|
||||
return scan_out
|
||||
```
|
||||
|
||||
The code defines a function `f` that which applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.
|
||||
|
||||
#### List of policies
|
||||
|
||||
The policies are:
|
||||
|
Loading…
x
Reference in New Issue
Block a user