Merge pull request #14840 from mattjj:relu6-grad-at-0-and-6

PiperOrigin-RevId: 515106619
This commit is contained in:
jax authors 2023-03-08 12:15:38 -08:00
commit 6634600c46
2 changed files with 21 additions and 0 deletions

View File

@ -431,6 +431,7 @@ def one_hot(x: Array, num_classes: int, *,
return _one_hot(x, num_classes, dtype=dtype, axis=axis)
@jax.custom_jvp
@jax.jit
def relu6(x: Array) -> Array:
r"""Rectified Linear Unit 6 activation function.
@ -440,10 +441,23 @@ def relu6(x: Array) -> Array:
.. math::
\mathrm{relu6}(x) = \min(\max(x, 0), 6)
except under differentiation, we take:
.. math::
\nabla \mathrm{relu}(0) = 0
and
.. math::
\nabla \mathrm{relu}(6) = 0
Args:
x : input array
"""
return jnp.minimum(jnp.maximum(x, 0), 6.)
relu6.defjvps(lambda g, ans, x:
lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))
@jax.jit
def hard_sigmoid(x: Array) -> Array:

View File

@ -68,6 +68,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
def testRelu6Grad(self):
rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
check_grads(nn.relu6, (1.,), order=3, rtol=rtol)
check_grads(nn.relu6, (-1.,), order=3, rtol=rtol)
self.assertAllClose(jax.grad(nn.relu6)(0.), 0., check_dtypes=False)
self.assertAllClose(jax.grad(nn.relu6)(6.), 0., check_dtypes=False)
def testSoftplusValue(self):
val = nn.softplus(89.)
self.assertAllClose(val, 89., check_dtypes=False)