Merge pull request #14840 from mattjj:relu6-grad-at-0-and-6
PiperOrigin-RevId: 515106619
commit 6634600c46
@@ -431,6 +431,7 @@ def one_hot(x: Array, num_classes: int, *,
   return _one_hot(x, num_classes, dtype=dtype, axis=axis)
 
 
+@jax.custom_jvp
 @jax.jit
 def relu6(x: Array) -> Array:
   r"""Rectified Linear Unit 6 activation function.
@@ -440,10 +441,23 @@ def relu6(x: Array) -> Array:
   .. math::
     \mathrm{relu6}(x) = \min(\max(x, 0), 6)
 
+  except under differentiation, we take:
+
+  .. math::
+    \nabla \mathrm{relu6}(0) = 0
+
+  and
+
+  .. math::
+    \nabla \mathrm{relu6}(6) = 0
+
   Args:
     x : input array
   """
   return jnp.minimum(jnp.maximum(x, 0), 6.)
+relu6.defjvps(lambda g, ans, x:
+              lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))
+
 
 @jax.jit
 def hard_sigmoid(x: Array) -> Array:
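For readers skimming the diff, a minimal standalone sketch of the pattern the hunks above apply may help: jax.custom_jvp replaces the derivative JAX would derive automatically, and defjvps pins the tangent to zero at the kinks x = 0 and x = 6. This reproduction is illustrative only (the real definition lives in jax/_src/nn/functions.py) and assumes nothing beyond jax, jax.numpy, and jax.lax.

import jax
import jax.numpy as jnp
from jax import lax

@jax.custom_jvp
@jax.jit
def relu6(x):
  # Primal computation, identical to the library version.
  return jnp.minimum(jnp.maximum(x, 0), 6.)

# Tangent rule: pass the incoming tangent g through only on the open
# interval (0, 6); at the kinks and outside it, the derivative is 0.
relu6.defjvps(lambda g, ans, x:
              lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))

print(jax.grad(relu6)(0.))  # 0.0 at the left kink
print(jax.grad(relu6)(6.))  # 0.0 at the right kink
print(jax.grad(relu6)(3.))  # 1.0 in the interior, unchanged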
@@ -68,6 +68,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
     jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
     self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
 
+  def testRelu6Grad(self):
+    rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
+    check_grads(nn.relu6, (1.,), order=3, rtol=rtol)
+    check_grads(nn.relu6, (-1.,), order=3, rtol=rtol)
+    self.assertAllClose(jax.grad(nn.relu6)(0.), 0., check_dtypes=False)
+    self.assertAllClose(jax.grad(nn.relu6)(6.), 0., check_dtypes=False)
+
   def testSoftplusValue(self):
     val = nn.softplus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)