Add documentation for lax.stop_gradient.

Fixes #492.
This commit is contained in:
Peter Hawkins 2019-03-10 18:08:04 -04:00
parent fab4dde12e
commit 4608582991

View File

@ -1114,6 +1114,19 @@ def broadcasted_eye(dtype, shape, axes):
def stop_gradient(x):
"""Stops gradient computation.
Operationally `stop_gradient` is the identity function, that is, it returns
argument `x` unchanged. However, `stop_gradient` prevents the flow of
gradients during forward or reverse-mode automatic differentiation.
For example:
>>> jax.grad(lambda x: x**2)(3.)
array(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
array(0., dtype=float32)
"""
return stop_gradient_p.bind(x)