improve jax.nn.relu differentiation (#2342)

Matthew Johnson 2020-03-03 16:27:53 -08:00 committed by GitHub
parent e7debd732f
commit 1e61ba429d
2 changed files with 12 additions and 3 deletions

jax/nn/functions.py

@@ -18,14 +18,15 @@
 import numpy as onp

 from jax import dtypes
-from jax import custom_transforms
+from jax import custom_transforms, defjvp
+from jax import lax
 from jax import random
 from jax.scipy.special import expit
 import jax.numpy as np
 from jax import jarrett

 # activations

 @custom_transforms
 def relu(x):
   r"""Rectified linear unit activation function.
@@ -35,6 +36,7 @@ def relu(x):
     \mathrm{relu}(x) = \max(x, 0)
   """
   return np.maximum(x, 0)
+defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

 def softplus(x):
   r"""Softplus activation function.

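For context, a rough usage sketch of what the new JVP rule means for callers (illustrative only; the exact jaxpr printed, and its equation count, can differ between jax versions): the derivative follows the lax.select rule above, so it is 1 where x > 0 and exactly 0 elsewhere, including at x == 0.

import jax
from jax import nn

print(jax.grad(nn.relu)(2.0))    # expected 1.0: the tangent passes through where x > 0
print(jax.grad(nn.relu)(-2.0))   # expected 0.0: the tangent is zeroed where x <= 0
print(jax.grad(nn.relu)(0.0))    # expected 0.0: x > 0 is False at x == 0

# Inspect the jaxpr of the gradient; with the hand-written rule it stays short,
# which is what the new testReluGrad below asserts (len(eqns) == 2 on this version).
print(jax.make_jaxpr(jax.grad(nn.relu))(0.0))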
tests/nn_test.py

@@ -36,16 +36,23 @@ class NNFunctionsTest(jtu.JaxTestCase):
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSoftplusGrad(self):
-    check_grads(nn.softplus, (1e-8,), 4,
+    check_grads(nn.softplus, (1e-8,), order=4,
                 rtol=1e-2 if jtu.device_under_test() == "tpu" else None)

+  def testReluGrad(self):
+    rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
+    check_grads(nn.relu, (1.,), order=3, rtol=rtol)
+    check_grads(nn.relu, (-1.,), order=3, rtol=rtol)
+    jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
+
   def testSoftplusValue(self):
     val = nn.softplus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)

   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testEluGrad(self):
-    check_grads(nn.elu, (1e4,), 4, eps=1.)
+    check_grads(nn.elu, (1e4,), order=4, eps=1.)

   def testEluValue(self):
     val = nn.elu(1e4)
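
As an aside, custom_transforms/defjvp used above is jax's older custom-derivative mechanism; a minimal sketch of the same rule written against the newer jax.custom_jvp API (my_relu here is a hypothetical stand-in, not part of this change) would look like:

import jax
import jax.numpy as jnp
from jax import lax

@jax.custom_jvp
def my_relu(x):  # hypothetical stand-in for jax.nn.relu
  return jnp.maximum(x, 0)

@my_relu.defjvp
def my_relu_jvp(primals, tangents):
  x, = primals
  g, = tangents
  ans = my_relu(x)
  # Same rule as the defjvp call above: pass the tangent through where x > 0, else zero it.
  return ans, lax.select(x > 0, g, lax.full_like(g, 0))

print(jax.grad(my_relu)(0.0))  # 0.0 under this rule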