mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
fab4dde12e
commit
4608582991
13
jax/lax.py
13
jax/lax.py
@ -1114,6 +1114,19 @@ def broadcasted_eye(dtype, shape, axes):
|
||||
|
||||
|
||||
def stop_gradient(x):
|
||||
"""Stops gradient computation.
|
||||
|
||||
Operationally `stop_gradient` is the identity function, that is, it returns
|
||||
argument `x` unchanged. However, `stop_gradient` prevents the flow of
|
||||
gradients during forward or reverse-mode automatic differentiation.
|
||||
|
||||
For example:
|
||||
|
||||
>>> jax.grad(lambda x: x**2)(3.)
|
||||
array(6., dtype=float32)
|
||||
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
|
||||
array(0., dtype=float32)
|
||||
"""
|
||||
return stop_gradient_p.bind(x)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user