Merge pull request #23707 from jakevdp:stop-gradient-doc

PiperOrigin-RevId: 676876785
jax authors 2024-09-20 09:48:08 -07:00
commit 886aa944fa
2 changed files with 36 additions and 11 deletions


@@ -77,7 +77,7 @@ def meta_loss_fn(params, data):
meta_grads = jax.grad(meta_loss_fn)(params, data)
```

(stopping-gradients)=
### Stopping gradients

Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph.
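A minimal sketch of what that kind of control looks like in practice, assuming the usual ``jax`` / ``jax.numpy`` imports; the ``loss`` function and its names below are purely illustrative, not code from this repository:

```python
import jax
import jax.numpy as jnp

def loss(params, targets):
  # Hypothetical toy model, for illustration only.
  preds = params["w"] * targets
  # Normalize by the largest prediction, but treat that scale as a constant:
  # stop_gradient blocks backpropagation through this branch of the graph.
  scale = jax.lax.stop_gradient(jnp.max(jnp.abs(preds)) + 1e-6)
  return jnp.sum((preds / scale - targets) ** 2)

params = {"w": jnp.float32(2.0)}
targets = jnp.arange(1.0, 4.0)
grads = jax.grad(loss)(params, targets)  # `scale` contributes no gradient
```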


@@ -1373,18 +1373,43 @@ def stop_gradient(x: T) -> T:
  argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
  gradients during forward or reverse-mode automatic differentiation. If there
  are multiple nested gradient computations, ``stop_gradient`` stops gradients
  for all of them.
  for all of them. For some discussion of where this is useful, refer to
  :ref:`stopping-gradients`.

  For example:

  >>> jax.grad(lambda x: x**2)(3.)
  Array(6., dtype=float32, weak_type=True)
  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
  Array(0., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
  Array(2., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
  Array(0., dtype=float32, weak_type=True)

  Args:
    x: array or pytree of arrays

  Returns:
    input value is returned unchanged, but within autodiff will be treated as
    a constant.

  Examples:
    Consider a simple function that returns the square of the input value:

    >>> def f1(x):
    ...   return x ** 2
    >>> x = jnp.float32(3.0)
    >>> f1(x)
    Array(9.0, dtype=float32)
    >>> jax.grad(f1)(x)
    Array(6.0, dtype=float32)

    The same function with ``stop_gradient`` around ``x`` will be equivalent
    under normal evaluation, but return a zero gradient because ``x`` is
    effectively treated as a constant:

    >>> def f2(x):
    ...   return jax.lax.stop_gradient(x) ** 2
    >>> f2(x)
    Array(9.0, dtype=float32)
    >>> jax.grad(f2)(x)
    Array(0.0, dtype=float32)

    This is used in a number of places within the JAX codebase; for example
    :func:`jax.nn.softmax` internally normalizes the input by its maximum
    value, and this maximum value is wrapped in ``stop_gradient`` for
    efficiency. Refer to :ref:`stopping-gradients` for more discussion of
    the applicability of ``stop_gradient``.
  """
  def stop(x):
    # only bind primitive on inexact dtypes, to avoid some staging
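For readers curious about the softmax trick mentioned in the new docstring, here is a hedged sketch of that stabilization pattern; it illustrates the idea of wrapping the maximum in ``stop_gradient`` and is not the actual ``jax.nn.softmax`` implementation:

```python
import jax
import jax.numpy as jnp

def softmax_sketch(x, axis=-1):
  # Subtract the max for numerical stability. The max is piecewise constant
  # in x, so stopping its gradient leaves the mathematical result unchanged
  # while keeping the backward pass simpler.
  x_max = jax.lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True))
  unnormalized = jnp.exp(x - x_max)
  return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.allclose(softmax_sketch(x), jax.nn.softmax(x)))  # expected: True
```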