mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add basic lax.stop_gradient primitive
This commit is contained in:
parent
cc3e8df47a
commit
c293b3775b
25
jax/lax.py
25
jax/lax.py
@ -485,6 +485,10 @@ def broadcasted_eye(dtype, shape, axes):
|
||||
return EyeConstant(shape, axes, dtype)
|
||||
|
||||
|
||||
def stop_gradient(x):
|
||||
return stop_gradient_p.bind(x)
|
||||
|
||||
|
||||
### convenience wrappers around traceables
|
||||
|
||||
|
||||
@ -2827,6 +2831,27 @@ for t in [FilledConstant, IotaConstant, EyeConstant]:
|
||||
ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
||||
|
||||
|
||||
### stop_gradient
|
||||
|
||||
|
||||
def stop_gradient_jvp_rule(primals, tangents):
|
||||
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
|
||||
x, = primals
|
||||
return stop_gradient(x), ad_util.zero
|
||||
|
||||
def stop_gradient_batch_rule(batched_args, batch_dims):
|
||||
x, = batched_args
|
||||
dim, = batch_dims
|
||||
return stop_gradient(x), dim
|
||||
|
||||
stop_gradient_p = Primitive('stop_gradient')
|
||||
stop_gradient_p.def_impl(identity)
|
||||
stop_gradient_p.def_abstract_eval(identity)
|
||||
xla.translations[stop_gradient_p] = lambda c, x: x
|
||||
ad.primitive_jvps[stop_gradient_p] = stop_gradient_jvp_rule
|
||||
batching.primitive_batchers[stop_gradient_p] = stop_gradient_batch_rule
|
||||
|
||||
|
||||
### util
|
||||
|
||||
def _ndim(x):
|
||||
|
@ -2186,6 +2186,19 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
y = rng(update_shape, dtype)
|
||||
check_grads(scatter_add, (x, y), 2, 1e-2, 1e-2, 1e-2)
|
||||
|
||||
def testStopGradient(self):
|
||||
def f(x):
|
||||
return lax.sin(x) * lax.cos(lax.stop_gradient(x))
|
||||
|
||||
def f2(x, y):
|
||||
return lax.sin(x) * lax.cos(y)
|
||||
|
||||
x = 3.14
|
||||
ans = api.grad(f)(x)
|
||||
expected = api.grad(f2)(x, x)
|
||||
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user