Add jax.nn.mish.

carlosgmartin 2024-04-03 16:37:07 -04:00
parent dcd45c8d20
commit f0314c70e8
4 changed files with 52 additions and 3 deletions

docs/jax.nn.rst

@@ -38,6 +38,7 @@ Activation functions
gelu
glu
squareplus
mish
Other functions
---------------

jax/_src/nn/functions.py

@@ -199,6 +199,29 @@ def silu(x: ArrayLike) -> Array:
swish = silu

@jax.jit
def mish(x: ArrayLike) -> Array:
  r"""Mish activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))

  For more information, see
  `Mish: A Self Regularized Non-Monotonic Activation Function
  <https://arxiv.org/abs/1908.08681>`_.

  Args:
    x : input array

  Returns:
    An array.
  """
  numpy_util.check_arraylike("mish", x)
  x_arr = jnp.asarray(x)
  return x_arr * jnp.tanh(softplus(x_arr))

@jax.jit
def log_sigmoid(x: ArrayLike) -> Array:
  r"""Log-sigmoid activation function.
@@ -314,7 +337,7 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
  For more information, see
  `Continuously Differentiable Exponential Linear Units
-  <https://arxiv.org/pdf/1704.07483.pdf>`_.
+  <https://arxiv.org/abs/1704.07483>`_.

  Args:
    x : input array
@@ -342,7 +365,7 @@ def selu(x: ArrayLike) -> Array:
  For more information, see
  `Self-Normalizing Neural Networks
-  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
+  <https://arxiv.org/abs/1706.02515>`_.

  Args:
    x : input array

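As a reference point (not part of the diff), the new function can be cross-checked against its defining formula using only public jax.nn building blocks. A minimal sketch, assuming the jax.nn.mish export from this commit is available:

```python
import jax
import jax.numpy as jnp

# mish(x) = x * tanh(softplus(x)), per the docstring above.
def mish_reference(x):
  return x * jnp.tanh(jax.nn.softplus(x))

x = jnp.linspace(-5.0, 5.0, 11)
print(jax.nn.mish(x))     # the function added in this commit
print(mish_reference(x))  # should agree elementwise, up to float rounding
```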
jax/nn/__init__.py

@@ -45,6 +45,7 @@ from jax._src.nn.functions import (
  silu as silu,
  swish as swish,
  squareplus as squareplus,
  mish as mish,
)
# Deprecations

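With the re-export above, mish is callable as jax.nn.mish alongside the other activations. A short usage sketch; the weights, shapes, and names here are illustrative only:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (4, 3))  # illustrative dense-layer weights
x = jnp.ones((2, 4))                # illustrative batch of inputs
h = jax.nn.mish(x @ w)              # element-wise mish on the layer output
print(h.shape)                      # (2, 3)
```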
tests/nn_test.py

@@ -91,6 +91,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
  def testSquareplusZero(self, dtype):
    self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4)))

  def testMishGrad(self):
    check_grads(nn.mish, (1e-8,), order=4,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testMishGradZero(self):
    check_grads(nn.mish, (0.,), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testMishGradNegInf(self):
    check_grads(nn.mish, (-float('inf'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testMishGradNan(self):
    check_grads(nn.mish, (float('nan'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  @parameterized.parameters([float] + jtu.dtypes.floating)
  def testMishZero(self, dtype):
    self.assertEqual(dtype(0), nn.mish(dtype(0)))

  def testReluGrad(self):
    rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
    check_grads(nn.relu, (1.,), order=3, rtol=rtol)
@@ -117,6 +137,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
    val = nn.squareplus(1e3)
    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

  def testMishValue(self):
    val = nn.mish(1e3)
    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testEluGrad(self):
    check_grads(nn.elu, (1e4,), order=4, eps=1.)
@@ -149,7 +173,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
      (jnp.float32, jnp.bfloat16, jnp.float16),
      (partial(nn.gelu, approximate=False),
       partial(nn.gelu, approximate=True),
-      nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus)))
+      nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
  def testDtypeMatchesInput(self, dtype, fn):
    x = jnp.zeros((), dtype=dtype)
    out = fn(x)
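The same kinds of checks the new tests perform can be reproduced standalone with the public jax.test_util.check_grads helper; a sketch, assuming this commit is applied:

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Compare autodiff against numerical gradients at an ordinary point.
check_grads(jax.nn.mish, (1.0,), order=2)

print(jax.nn.mish(jnp.float32(0.0)))  # 0.0, matching testMishZero
print(jax.nn.mish(jnp.float32(1e3)))  # ~1000.0, matching testMishValue
```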