Expand the docstring for jax.checkpoint.

I found the old docstring to be a bit inaccurate, but mostly
not explanatory enough. I think that this feature deserves
a tutorial, in addition to an expanded docstring.
George Necula 2021-09-23 11:14:19 +03:00
parent 484b5af5c9
commit f03baaec51


@@ -2708,15 +2708,33 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
...   z = jnp.sin(y)
...   return z
...
->>> jax.grad(g)(2.0)
-DeviceArray(-0.25563914, dtype=float32)
+>>> jax.value_and_grad(g)(2.0)
+(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))
Here, the same value is produced whether or not the :func:`jax.checkpoint`
-decorator is present. But when using :func:`jax.checkpoint`, the value
-``jnp.sin(2.0)`` is computed twice: once on the forward pass, and once on the
-backward pass. The values ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are
-also computed twice. Without using the decorator, both ``jnp.cos(2.0)`` and
-``jnp.cos(jnp.sin(2.0))`` would be stored and reused.
+decorator is present. When the decorator is not present, the values
+``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
+pass and are stored for use in the backward pass, because they are needed
+on the backward pass and depend only on the primal inputs. When using
+:func:`jax.checkpoint`, the forward pass will compute only the primal outputs
+and only the primal inputs (``2.0``) will be stored for the backward pass.
+During the backward pass, the value ``jnp.sin(2.0)`` is recomputed, along with
+the values ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
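The claim that the decorator changes only what is stored, not what is computed, can be checked with a minimal sketch. The names ``g_plain`` and ``g_remat`` are illustrative and not part of the commit; the analytic gradient ``cos(sin(x)) * cos(x)`` is included only as a cross-check.

```python
import jax
import jax.numpy as jnp


def g_plain(x):
    # Same body as the docstring example, without the decorator.
    y = jnp.sin(x)
    z = jnp.sin(y)
    return z


# Wrapped version: intermediate values are recomputed on the backward pass
# instead of being stored during the forward pass.
g_remat = jax.checkpoint(g_plain)

value_plain, grad_plain = jax.value_and_grad(g_plain)(2.0)
value_remat, grad_remat = jax.value_and_grad(g_remat)(2.0)

# Analytic gradient of sin(sin(x)) for comparison.
expected_grad = jnp.cos(jnp.sin(2.0)) * jnp.cos(2.0)

print(value_plain, value_remat)
print(grad_plain, grad_remat, expected_grad)
```

Both variants should agree up to floating-point rounding; any difference between them is in memory use and recomputation cost, which this sketch does not measure.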
+Note that in some cases the XLA compiler may perform its own optimizations
+that affect memory usage, e.g., common-subexpression elimination, fusion, or
+even rematerialization. For example, if the code uses only element-wise
+operations, like in our example, XLA may fuse both the forward and backward
+pass, or may rematerialize aggressively, resulting in low memory usage even
+without ``jax.checkpoint``. In fact, in some such cases ``jax.checkpoint`` may
+hinder the compiler and you may see larger memory usage. For complex examples,
+the effect of ``jax.checkpoint`` is likely to be more significant than the
+XLA optimizations.
+The best way to use ``jax.checkpoint`` is to experiment with its placement
+on sub-computations. For example, ``lambda x: f(jax.checkpoint(g)(x))`` is
+likely to have lower memory usage under ``jax.grad`` than when not using
+``jax.checkpoint``, or than when using it on ``f`` or the whole function.
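A rough sketch of the placements mentioned above follows. The functions ``f`` and ``g`` here are hypothetical stand-ins chosen only for illustration, and the sketch checks that every placement yields the same gradient; it does not measure memory, which would have to be profiled on the real workload.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Hypothetical inner sub-computation.
    return jnp.sin(jnp.sin(x))


def f(y):
    # Hypothetical outer sub-computation reducing to a scalar.
    return jnp.sum(jnp.tanh(y) ** 2)


x = jnp.linspace(0.0, 1.0, 8)

# Four placements of jax.checkpoint over the same computation.
no_remat = lambda x: f(g(x))
remat_g = lambda x: f(jax.checkpoint(g)(x))
remat_f = lambda x: jax.checkpoint(f)(g(x))
remat_all = jax.checkpoint(no_remat)

grads = [jax.grad(fn)(x) for fn in (no_remat, remat_g, remat_f, remat_all)]
for grad in grads:
    print(grad)
```

Since the gradients agree, the choice among placements is purely a trade-off between stored residuals and recomputation, and is best settled by experiment.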
The :func:`jax.checkpoint` decorator can be applied recursively to express
sophisticated autodiff rematerialization strategies. For example: