mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
improve jax.nn.relu differentiation (#2342)
This commit is contained in:
parent
e7debd732f
commit
1e61ba429d
@ -18,14 +18,15 @@
|
||||
import numpy as onp
|
||||
|
||||
from jax import dtypes
|
||||
from jax import custom_transforms, defjvp
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax.scipy.special import expit
|
||||
import jax.numpy as np
|
||||
from jax import jarrett
|
||||
|
||||
# activations
|
||||
|
||||
@custom_transforms
|
||||
def relu(x):
|
||||
r"""Rectified linear unit activation function.
|
||||
|
||||
@ -35,6 +36,7 @@ def relu(x):
|
||||
\mathrm{relu}(x) = \max(x, 0)
|
||||
"""
|
||||
return np.maximum(x, 0)
|
||||
defjvp(relu, lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
|
||||
|
||||
def softplus(x):
|
||||
r"""Softplus activation function.
|
||||
|
@ -36,16 +36,23 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testSoftplusGrad(self):
|
||||
check_grads(nn.softplus, (1e-8,), 4,
|
||||
check_grads(nn.softplus, (1e-8,), order=4,
|
||||
rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
|
||||
|
||||
def testReluGrad(self):
|
||||
rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
|
||||
check_grads(nn.relu, (1.,), order=3, rtol=rtol)
|
||||
check_grads(nn.relu, (-1.,), order=3, rtol=rtol)
|
||||
jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
|
||||
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
|
||||
|
||||
def testSoftplusValue(self):
|
||||
val = nn.softplus(89.)
|
||||
self.assertAllClose(val, 89., check_dtypes=False)
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testEluGrad(self):
|
||||
check_grads(nn.elu, (1e4,), 4, eps=1.)
|
||||
check_grads(nn.elu, (1e4,), order=4, eps=1.)
|
||||
|
||||
def testEluValue(self):
|
||||
val = nn.elu(1e4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user