mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Expand the docstring for jax.checkpoint.
I found the old docstring to be a bit inaccurate, but mostly not explanatory enough. I think that this feature deserves a tutorial, in addition to an expanded docstring.
parent 484b5af5c9
commit f03baaec51
@ -2708,15 +2708,33 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
...   z = jnp.sin(y)
...   return z
...
>>> jax.grad(g)(2.0)
DeviceArray(-0.25563914, dtype=float32)
>>> jax.value_and_grad(g)(2.0)
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))

Here, the same value is produced whether or not the :func:`jax.checkpoint`
decorator is present. But when using :func:`jax.checkpoint`, the value
``jnp.sin(2.0)`` is computed twice: once on the forward pass, and once on the
backward pass. The values ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are
also computed twice. Without using the decorator, both ``jnp.cos(2.0)`` and
``jnp.cos(jnp.sin(2.0))`` would be stored and reused.
decorator is present. When the decorator is not present, the values
``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
pass and are stored for use in the backward pass, because they are needed
on the backward pass and depend only on the primal inputs. When using
:func:`jax.checkpoint`, the forward pass will compute only the primal outputs
and only the primal inputs (``2.0``) will be stored for the backward pass.
During the backward pass, the value ``jnp.sin(2.0)`` is recomputed, along with
the values ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.

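As a rough check of the claim that the result is the same with or without the decorator, the sketch below differentiates both variants of the two-``sin`` function and compares them against the hand-derived gradient ``cos(sin(x)) * cos(x)``. The names ``g_plain`` and ``g_remat`` are illustrative only and do not appear in the docstring; the snippet checks values, not memory usage.

```python
import jax
import jax.numpy as jnp


def g_plain(x):
    # Same computation as the docstring example: two nested sines.
    y = jnp.sin(x)
    z = jnp.sin(y)
    return z


# Wrapping with jax.checkpoint changes what is stored for the backward
# pass, not the mathematical result.
g_remat = jax.checkpoint(g_plain)

value_plain, grad_plain = jax.value_and_grad(g_plain)(2.0)
value_remat, grad_remat = jax.value_and_grad(g_remat)(2.0)

# Hand-derived gradient via the chain rule: d/dx sin(sin(x)).
expected_grad = jnp.cos(jnp.sin(2.0)) * jnp.cos(2.0)

print(value_plain, value_remat)
print(grad_plain, grad_remat, expected_grad)
```

Both variants should agree up to floating-point tolerance; any difference between them lies in which intermediates are saved versus recomputed.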
Note that in some cases the XLA compiler may apply its own optimizations that
affect memory usage, e.g., common-subexpression elimination, fusion, or
even rematerialization. For example, if the code uses only element-wise
operations, as in our example, XLA may fuse the forward and backward
passes, or may rematerialize aggressively, resulting in low memory usage even
without ``jax.checkpoint``. In fact, in some such cases ``jax.checkpoint`` may
hinder the compiler and you may see larger memory usage. For complex examples,
the effect of ``jax.checkpoint`` is likely to be more significant than the
XLA optimizations.

The best way to use ``jax.checkpoint`` is to experiment with its placement
on sub-computations. For example, ``lambda x: f(jax.checkpoint(g)(x))`` is
likely to have lower memory usage under ``jax.grad`` than when not using
``jax.checkpoint``, or than when using it on ``f`` or the whole function.

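The placements mentioned above can be written out side by side. This is a minimal sketch in which ``f_outer`` and ``g_inner`` are hypothetical stand-ins for ``f`` and ``g``; it only confirms that every placement yields the same gradient. Which one uses the least memory depends on the model and on the compiler, so that has to be measured separately.

```python
import jax
import jax.numpy as jnp


def g_inner(x):
    # Hypothetical inner stage standing in for ``g``.
    return jnp.sin(jnp.sin(x))


def f_outer(y):
    # Hypothetical outer stage standing in for ``f``.
    return jnp.cos(y) ** 2


def no_remat(x):
    return f_outer(g_inner(x))


def inner_remat(x):
    # The placement suggested in the text: checkpoint only ``g``.
    return f_outer(jax.checkpoint(g_inner)(x))


def outer_remat(x):
    # Checkpoint only ``f``.
    return jax.checkpoint(f_outer)(g_inner(x))


# Checkpoint the whole function.
whole_remat = jax.checkpoint(no_remat)

placements = [no_remat, inner_remat, outer_remat, whole_remat]
grads = [jax.grad(fn)(2.0) for fn in placements]
print(grads)
```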
The :func:`jax.checkpoint` decorator can be applied recursively to express
sophisticated autodiff rematerialization strategies. For example:

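The hunk shown here stops before any example body, so the following is only an illustrative sketch of one recursive strategy, not text taken from the docstring. It assumes a list of single-argument functions composed right to left, and checkpoints the second half at every level of the recursion; ``layer``, ``chain``, and ``plain_chain`` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def recursive_checkpoint(funs):
    """Compose ``funs`` right to left, checkpointing half the chain per level."""
    if len(funs) == 1:
        return funs[0]
    mid = len(funs) // 2
    first = recursive_checkpoint(funs[:mid])
    second = recursive_checkpoint(funs[mid:])
    # ``second`` runs first on the input; its intermediates are recomputed
    # on the backward pass instead of being stored.
    return lambda x: first(jax.checkpoint(second)(x))


def layer(x):
    # Hypothetical layer used only for this illustration.
    return jnp.sin(x)


chain = [layer] * 8
remat_chain = recursive_checkpoint(chain)


def plain_chain(x):
    # Reference composition without any checkpointing.
    for fn in reversed(chain):
        x = fn(x)
    return x


grad_remat = jax.grad(remat_chain)(2.0)
grad_plain = jax.grad(plain_chain)(2.0)
print(grad_remat, grad_plain)
```

Nesting the checkpoints this way trades extra recomputation for fewer stored intermediates; the gradients themselves are unchanged.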