add basic lax.stop_gradient primitive

This commit is contained in:
Matthew Johnson 2019-01-30 10:39:35 -08:00
parent cc3e8df47a
commit c293b3775b
2 changed files with 38 additions and 0 deletions

View File

@ -485,6 +485,10 @@ def broadcasted_eye(dtype, shape, axes):
return EyeConstant(shape, axes, dtype)
def stop_gradient(x):
  """Identity operation that blocks gradients from flowing through ``x``.

  Binds the ``stop_gradient`` primitive, whose JVP rule produces a zero
  tangent, so autodiff treats the result as a constant.
  """
  return stop_gradient_p.bind(x)
### convenience wrappers around traceables
@ -2827,6 +2831,27 @@ for t in [FilledConstant, IotaConstant, EyeConstant]:
ad_util.jaxval_zeros_likers[t] = zeros_like_array
### stop_gradient
def stop_gradient_jvp_rule(primals, tangents):
  """JVP rule: pass the primal through stop_gradient, return a zero tangent."""
  # Re-applying stop_gradient to the primal output (rather than returning x
  # unchanged) keeps gradients blocked under nested autodiff tracers; only a
  # single tracer level would be peeled off otherwise.
  x, = primals
  primal_out = stop_gradient(x)
  return primal_out, ad_util.zero
def stop_gradient_batch_rule(batched_args, batch_dims):
  """Batching rule: apply stop_gradient to the operand, keep its batch dim."""
  operand, = batched_args
  bdim, = batch_dims
  # stop_gradient is elementwise-identity, so the batch dimension is unmoved.
  return stop_gradient(operand), bdim
# Register the stop_gradient primitive and its transformation rules.
stop_gradient_p = Primitive('stop_gradient')
# Concrete and abstract evaluation are both the identity function.
stop_gradient_p.def_impl(identity)
stop_gradient_p.def_abstract_eval(identity)
# XLA lowering is a no-op: the operand is emitted unchanged.
xla.translations[stop_gradient_p] = lambda c, x: x
# Differentiation (zero tangent) and vmap behavior use the rules above.
ad.primitive_jvps[stop_gradient_p] = stop_gradient_jvp_rule
batching.primitive_batchers[stop_gradient_p] = stop_gradient_batch_rule
### util
def _ndim(x):

View File

@ -2186,6 +2186,19 @@ class LaxAutodiffTest(jtu.JaxTestCase):
y = rng(update_shape, dtype)
check_grads(scatter_add, (x, y), 2, 1e-2, 1e-2, 1e-2)
def testStopGradient(self):
  """grad through stop_gradient(x) matches treating x as an independent input."""

  def with_stop(x):
    return lax.sin(x) * lax.cos(lax.stop_gradient(x))

  def reference(x, y):
    # Same computation, but the stopped argument is a separate variable
    # that grad (w.r.t. the first argument only) never differentiates.
    return lax.sin(x) * lax.cos(y)

  point = 3.14
  actual = api.grad(with_stop)(point)
  expected = api.grad(reference)(point, point)
  self.assertAllClose(actual, expected, check_dtypes=True)
if __name__ == '__main__':
  # Run the absltest suite when this file is executed directly.
  absltest.main()