mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 05:26:07 +00:00
consolidate the code example
This commit is contained in:
parent
f124232f53
commit
522a8fd792
@ -360,40 +360,55 @@ You may consider offloading to CPU memory instead of recomputing when checkpoint
|
|||||||
|
|
||||||
```{code-cell}
|
```{code-cell}
|
||||||
from jax.ad_checkpoint import checkpoint
|
from jax.ad_checkpoint import checkpoint
|
||||||
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
|
|
||||||
"device", "pinned_host")
|
|
||||||
|
|
||||||
@functools.partial(checkpoint, policy=policy)
|
def checkpoint_offload_dot_with_no_batch_dims(self):
|
||||||
def f(x):
|
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
|
||||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
"device", "pinned_host")
|
||||||
x = jnp.sin(x)
|
|
||||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
@functools.partial(checkpoint, policy=policy)
|
||||||
x = jnp.sin(x)
|
def f(x):
|
||||||
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||||
x = jnp.sin(x)
|
x = jnp.sin(x)
|
||||||
x = jnp.sum(x)
|
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||||
return x
|
x = jnp.sin(x)
|
||||||
|
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
|
||||||
|
x = jnp.sin(x)
|
||||||
|
x = jnp.sum(x)
|
||||||
|
return x
|
||||||
```
|
```
|
||||||
|
|
||||||
One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.
|
One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.
|
||||||
|
|
||||||
```{code-cell}
|
```{code-cell}
|
||||||
from jax.ad_checkpoint import checkpoint, checkpoint_name
|
from jax.ad_checkpoint import checkpoint, checkpoint_name
|
||||||
|
from jax._src import test_util as jtu
|
||||||
|
|
||||||
def g(self):
|
def checkpoint_names_saved_offloaded_recomputed(self):
|
||||||
|
mesh = jtu.create_mesh((2,), ("x",))
|
||||||
|
shape = (256, 128)
|
||||||
|
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||||
|
s = NamedSharding(mesh, P("x"))
|
||||||
|
inp = jax.device_put(np_inp, s)
|
||||||
|
|
||||||
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
|
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
|
||||||
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
|
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
|
||||||
offload_src='device', offload_dst='pinned_host')
|
offload_src='device', offload_dst='pinned_host')
|
||||||
|
|
||||||
@functools.partial(checkpoint, policy=policy)
|
@functools.partial(checkpoint, policy=policy)
|
||||||
def f(x):
|
def f(x):
|
||||||
|
def g(ys, _):
|
||||||
|
y, _ = ys
|
||||||
y = checkpoint_name(jnp.sin(y), "y")
|
y = checkpoint_name(jnp.sin(y), "y")
|
||||||
z = checkpoint_name(jnp.sin(y), "z")
|
z = checkpoint_name(jnp.sin(y), "z")
|
||||||
|
z = z.T
|
||||||
w = checkpoint_name(jnp.sin(z), "w")
|
w = checkpoint_name(jnp.sin(z), "w")
|
||||||
return jnp.sum(w)
|
return (w.T, jnp.sum(w)), None
|
||||||
|
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
|
||||||
|
return scan_out
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The code defines a function `f` that which applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.
|
||||||
|
|
||||||
#### List of policies
|
#### List of policies
|
||||||
|
|
||||||
The policies are:
|
The policies are:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user