Applied review suggestions

George Necula 2021-10-08 10:11:31 +02:00
parent f03baaec51
commit 3938018228

@@ -2721,20 +2721,12 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
Note that in some cases the XLA compiler may perform its own optimizations that
affect memory usage, e.g., common-subexpression elimination, fusion, or
even rematerialization. For example, if the code uses only element-wise
operations, like in our example, XLA may fuse both the forward and backward
pass, or may rematerialize aggressively, resulting in low memory usage even
without ``jax.checkpoint``. In fact, in some such cases ``jax.checkpoint`` may
hinder the compiler and you may see larger memory usage. For complex examples,
the effect of ``jax.checkpoint`` is likely to be more significant than the
XLA optimizations.
The best way to use ``jax.checkpoint`` is to experiment with its placement
on sub-computations. For example, ``lambda x: f(jax.checkpoint(g)(x))`` is
likely to have lower memory usage under ``jax.grad`` than when not using
``jax.checkpoint``, or than when using it on ``f`` or the whole function.
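A minimal sketch of such a placement experiment follows. The functions ``f`` and ``g`` are hypothetical stand-ins chosen only for illustration, and the snippet checks only that every placement yields the same gradient. It does not measure memory, which would have to be profiled separately on the target backend.

```python
import jax
import jax.numpy as jnp

# Hypothetical sub-computations, used only to illustrate placement.
def g(x):
  return jnp.sin(jnp.sin(x))

def f(y):
  return jnp.cos(y) * y

# Three candidate placements of jax.checkpoint.
plain = lambda x: f(g(x))                    # no rematerialization
remat_g = lambda x: f(jax.checkpoint(g)(x))  # recompute g's intermediates
remat_all = jax.checkpoint(plain)            # recompute the whole function

x = 2.0
grad_plain = jax.grad(plain)(x)
grad_remat_g = jax.grad(remat_g)(x)
grad_remat_all = jax.grad(remat_all)(x)
```

All three gradients should agree numerically. The placements differ only in which intermediate values are kept for the backward pass, so any comparison between them should rest on measured memory usage rather than on the results.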
While ``jax.checkpoint`` controls what values are stored from the forward pass
to be used on the backward pass, the total amount of memory required to
evaluate a function or its VJP depends on many additional internal details of
that function. Those details include which numerical primitives are used,
how they're composed, where jit and control flow primitives like scan
are used, and other factors.
The :func:`jax.checkpoint` decorator can be applied recursively to express
sophisticated autodiff rematerialization strategies. For example: