mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Applied review suggestsions
This commit is contained in:
parent
f03baaec51
commit
3938018228
@ -2721,20 +2721,12 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
|
||||
At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
|
||||
``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
|
||||
|
||||
Note that in some cases the XLA compiler may do its optimization that affect
|
||||
the meory usage, e.g., common-subexpression optimization, fusion, or
|
||||
even rematerialization. For example, if the code uses only element-wise
|
||||
operations, like in our example, XLA may fuse both the forward and backward
|
||||
pass, or may rematerialize aggressively, resulting in low memory usage even
|
||||
without ``jax.checkpoint``. In fact, in some such cases ``jax.checkpoint`` may
|
||||
hinder the compiler and you may see larger memory usage. For complex examples,
|
||||
the effect of ``jax.checkpoint`` is likely to be more significant than the
|
||||
XLA optimizations.
|
||||
|
||||
The best way to use ``jax.checkpoint`` is to experiment with its placement
|
||||
on sub-computations. For example, ``lambda x: f(jax.checkpoint(g)(x))`` is
|
||||
likely to have lower memory usage under ``jax.grad`` than when not using
|
||||
``jax.checkpoint``, or than when using it on ``f`` or the whole function.
|
||||
While ``jax.checkpoint`` controls what values are stored from the forward-pass
|
||||
to be used on the backward pass, the total amount of memory required to
|
||||
evaluate a function or its VJP depends on many additional internal details of
|
||||
that function. Those details include which numerical primitives are used,
|
||||
how they're composed, where jit and control flow primitives like scan
|
||||
are used, and other factors.
|
||||
|
||||
The :func:`jax.checkpoint` decorator can be applied recursively to express
|
||||
sophisticated autodiff rematerialization strategies. For example:
|
||||
|
Loading…
x
Reference in New Issue
Block a user