mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 23:26:05 +00:00
consolidate the code example
This commit is contained in:
parent
f124232f53
commit
522a8fd792
@ -360,6 +360,8 @@ You may consider offloading to CPU memory instead of recomputing when checkpoint
|
||||
|
||||
```{code-cell}
|
||||
from jax.ad_checkpoint import checkpoint
|
||||
|
||||
def checkpoint_offload_dot_with_no_batch_dims(self):
|
||||
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
|
||||
"device", "pinned_host")
|
||||
|
||||
@ -379,8 +381,14 @@ One of JAX's checkpoint policies allows specified checkpoint names to be offload
|
||||
|
||||
```{code-cell}
|
||||
from jax.ad_checkpoint import checkpoint, checkpoint_name
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
def g(self):
|
||||
def checkpoint_names_saved_offloaded_recomputed(self):
|
||||
mesh = jtu.create_mesh((2,), ("x",))
|
||||
shape = (256, 128)
|
||||
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
s = NamedSharding(mesh, P("x"))
|
||||
inp = jax.device_put(np_inp, s)
|
||||
|
||||
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
|
||||
names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
|
||||
@ -388,12 +396,19 @@ def g(self):
|
||||
|
||||
@functools.partial(checkpoint, policy=policy)
|
||||
def f(x):
|
||||
def g(ys, _):
|
||||
y, _ = ys
|
||||
y = checkpoint_name(jnp.sin(y), "y")
|
||||
z = checkpoint_name(jnp.sin(y), "z")
|
||||
z = z.T
|
||||
w = checkpoint_name(jnp.sin(z), "w")
|
||||
return jnp.sum(w)
|
||||
return (w.T, jnp.sum(w)), None
|
||||
_, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
|
||||
return scan_out
|
||||
```
|
||||
|
||||
The code defines a function `f` that which applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.
|
||||
|
||||
#### List of policies
|
||||
|
||||
The policies are:
|
||||
|
Loading…
x
Reference in New Issue
Block a user