Merge pull request #23707 from jakevdp:stop-gradient-doc
PiperOrigin-RevId: 676876785
commit 886aa944fa
@@ -77,7 +77,7 @@ def meta_loss_fn(params, data):
 meta_grads = jax.grad(meta_loss_fn)(params, data)
 ```

-
+(stopping-gradients)=
 ### Stopping gradients

 Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph.
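The section added above is about keeping gradients from flowing through part of a computation. As a rough sketch of that idea (not part of this commit; the toy model, `loss_fn`, and the data below are invented for illustration), `jax.lax.stop_gradient` can hold one branch of a loss fixed during differentiation:

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
  # Toy linear model: params is a (weights, bias) pair.
  w, b = params
  return inputs @ w + b

def loss_fn(params, inputs, targets):
  preds = predict(params, inputs)
  # Normalize the squared error by a data-dependent scale, but treat that
  # scale as a constant during differentiation: stop_gradient blocks
  # backpropagation through this branch of the graph.
  scale = jax.lax.stop_gradient(jnp.mean(jnp.abs(preds)) + 1e-6)
  return jnp.mean((preds - targets) ** 2) / scale

params = (jnp.ones((3, 1)), jnp.zeros((1,)))
inputs = jnp.arange(6.0).reshape(2, 3)
targets = jnp.ones((2, 1))
grads = jax.grad(loss_fn)(params, inputs, targets)  # no contribution from d(scale)/d(params)
```

Without the `stop_gradient` call, `jax.grad` would also differentiate through `scale`, which depends on `params` via `preds`.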
@@ -1373,18 +1373,43 @@ def stop_gradient(x: T) -> T:
   argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
   gradients during forward or reverse-mode automatic differentiation. If there
   are multiple nested gradient computations, ``stop_gradient`` stops gradients
-  for all of them.
-
-  For example:
-
-  >>> jax.grad(lambda x: x**2)(3.)
-  Array(6., dtype=float32, weak_type=True)
-  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
-  Array(0., dtype=float32, weak_type=True)
-  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
-  Array(2., dtype=float32, weak_type=True)
-  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
-  Array(0., dtype=float32, weak_type=True)
+  for all of them. For some discussion of where this is useful, refer to
+  :ref:`stopping-gradients`.
+
+  Args:
+    x: array or pytree of arrays
+
+  Returns:
+    input value is returned unchanged, but within autodiff will be treated as
+    a constant.
+
+  Examples:
+    Consider a simple function that returns the square of the input value:
+
+    >>> def f1(x):
+    ...   return x ** 2
+    >>> x = jnp.float32(3.0)
+    >>> f1(x)
+    Array(9.0, dtype=float32)
+    >>> jax.grad(f1)(x)
+    Array(6.0, dtype=float32)
+
+    The same function with ``stop_gradient`` around ``x`` will be equivalent
+    under normal evaluation, but return a zero gradient because ``x`` is
+    effectively treated as a constant:
+
+    >>> def f2(x):
+    ...   return jax.lax.stop_gradient(x) ** 2
+    >>> f2(x)
+    Array(9.0, dtype=float32)
+    >>> jax.grad(f2)(x)
+    Array(0.0, dtype=float32)
+
+    This is used in a number of places within the JAX codebase; for example
+    :func:`jax.nn.softmax` internally normalizes the input by its maximum
+    value, and this maximum value is wrapped in ``stop_gradient`` for
+    efficiency. Refer to :ref:`stopping-gradients` for more discussion of
+    the applicability of ``stop_gradient``.
   """
   def stop(x):
     # only bind primitive on inexact dtypes, to avoid some staging
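The new docstring points out that :func:`jax.nn.softmax` wraps its internal maximum in ``stop_gradient``. A minimal sketch of that pattern, assuming nothing about JAX's actual implementation (`softmax_sketch` is an invented name), is shown below; because softmax is shift-invariant, subtracting the max changes neither the value nor the gradient, so autodiff can safely skip that branch:

```python
import jax
import jax.numpy as jnp

def softmax_sketch(x, axis=-1):
  # Subtract the max for numerical stability; wrapping it in stop_gradient
  # keeps autodiff from tracking a term that cancels out analytically.
  x_max = jax.lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True))
  unnormalized = jnp.exp(x - x_max)
  return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.allclose(softmax_sketch(x), jax.nn.softmax(x)))          # forward values agree
print(jnp.allclose(jax.jacobian(softmax_sketch)(x),
                   jax.jacobian(jax.nn.softmax)(x), atol=1e-6))    # gradients agree too
```

Both `allclose` checks should print `True`: the Jacobians match even though the sketch never differentiates through the max reduction, which is exactly why the docstring describes the ``stop_gradient`` wrapping as an efficiency measure.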