mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 12:56:08 +00:00
consolidate the code example
This commit is contained in:
parent
f124232f53
commit
522a8fd792
@ -360,6 +360,8 @@ You may consider offloading to CPU memory instead of recomputing when checkpoint
|
|||||||
|
|
||||||
```{code-cell}
|
```{code-cell}
|
||||||
from jax.ad_checkpoint import checkpoint
|
from jax.ad_checkpoint import checkpoint
|
||||||
|
|
||||||
|
def checkpoint_offload_dot_with_no_batch_dims(self):
|
||||||
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
|
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
|
||||||
"device", "pinned_host")
|
"device", "pinned_host")
|
||||||
|
|
||||||
@ -379,8 +381,14 @@ One of JAX's checkpoint policies allows specified checkpoint names to be offload
|
|||||||
|
|
||||||
```{code-cell}
|
```{code-cell}
|
||||||
from jax.ad_checkpoint import checkpoint, checkpoint_name
|
from jax.ad_checkpoint import checkpoint, checkpoint_name
|
||||||
|
from jax._src import test_util as jtu
|
||||||
|
|
||||||
def g(self):
|
def checkpoint_names_saved_offloaded_recomputed(self):
|
||||||
|
mesh = jtu.create_mesh((2,), ("x",))
|
||||||
|
shape = (256, 128)
|
||||||
|
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||||
|
s = NamedSharding(mesh, P("x"))
|
||||||
|
inp = jax.device_put(np_inp, s)
|
||||||
|
|
||||||
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
|
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
|
||||||
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
|
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
|
||||||
@ -388,12 +396,19 @@ def g(self):
|
|||||||
|
|
||||||
@functools.partial(checkpoint, policy=policy)
|
@functools.partial(checkpoint, policy=policy)
|
||||||
def f(x):
|
def f(x):
|
||||||
|
def g(ys, _):
|
||||||
|
y, _ = ys
|
||||||
y = checkpoint_name(jnp.sin(y), "y")
|
y = checkpoint_name(jnp.sin(y), "y")
|
||||||
z = checkpoint_name(jnp.sin(y), "z")
|
z = checkpoint_name(jnp.sin(y), "z")
|
||||||
|
z = z.T
|
||||||
w = checkpoint_name(jnp.sin(z), "w")
|
w = checkpoint_name(jnp.sin(z), "w")
|
||||||
return jnp.sum(w)
|
return (w.T, jnp.sum(w)), None
|
||||||
|
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
|
||||||
|
return scan_out
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The code defines a function `f` that which applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.
|
||||||
|
|
||||||
#### List of policies
|
#### List of policies
|
||||||
|
|
||||||
The policies are:
|
The policies are:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user