Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Merge pull request #20558 from carlosgmartin:mish
PiperOrigin-RevId: 621708823
This commit is contained in: commit 29a2762b64
@@ -38,6 +38,7 @@ Activation functions
     gelu
     glu
     squareplus
+    mish
 
 Other functions
 ---------------
@@ -199,6 +199,29 @@ def silu(x: ArrayLike) -> Array:
 
 swish = silu
 
+@jax.jit
+def mish(x: ArrayLike) -> Array:
+  r"""Mish activation function.
+
+  Computes the element-wise function:
+
+  .. math::
+    \mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))
+
+  For more information, see
+  `Mish: A Self Regularized Non-Monotonic Activation Function
+  <https://arxiv.org/abs/1908.08681>`_.
+
+  Args:
+    x : input array
+
+  Returns:
+    An array.
+  """
+  numpy_util.check_arraylike("mish", x)
+  x_arr = jnp.asarray(x)
+  return x_arr * jnp.tanh(softplus(x_arr))
+
 @jax.jit
 def log_sigmoid(x: ArrayLike) -> Array:
   r"""Log-sigmoid activation function.
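Note (not part of the diff): the docstring above defines mish by the formula x * tanh(softplus(x)). A quick, informal check of the new function against that formula, assuming a JAX build that already includes this change so the function is exposed as jax.nn.mish:

import jax
import jax.numpy as jnp

x = jnp.linspace(-5.0, 5.0, 11)
reference = x * jnp.tanh(jax.nn.softplus(x))    # the docstring formula, spelled out
print(jnp.allclose(jax.nn.mish(x), reference))  # expected: True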
@@ -314,7 +337,7 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
 
   For more information, see
   `Continuously Differentiable Exponential Linear Units
-  <https://arxiv.org/pdf/1704.07483.pdf>`_.
+  <https://arxiv.org/abs/1704.07483>`_.
 
   Args:
     x : input array
@@ -342,7 +365,7 @@ def selu(x: ArrayLike) -> Array:
 
   For more information, see
   `Self-Normalizing Neural Networks
-  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.
+  <https://arxiv.org/abs/1706.02515>`_.
 
   Args:
     x : input array
@@ -45,6 +45,7 @@ from jax._src.nn.functions import (
     silu as silu,
     swish as swish,
     squareplus as squareplus,
+    mish as mish,
 )
 
 # Deprecations
@@ -91,6 +91,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
   def testSquareplusZero(self, dtype):
     self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4)))
 
+  def testMishGrad(self):
+    check_grads(nn.mish, (1e-8,), order=4,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testMishGradZero(self):
+    check_grads(nn.mish, (0.,), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testMishGradNegInf(self):
+    check_grads(nn.mish, (-float('inf'),), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testMishGradNan(self):
+    check_grads(nn.mish, (float('nan'),), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  @parameterized.parameters([float] + jtu.dtypes.floating)
+  def testMishZero(self, dtype):
+    self.assertEqual(dtype(0), nn.mish(dtype(0)))
+
   def testReluGrad(self):
     rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
     check_grads(nn.relu, (1.,), order=3, rtol=rtol)
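Note (not part of the diff): these tests probe the gradient of mish at awkward points (tiny, zero, -inf, nan). The zero case can be checked by hand: differentiating mish(x) = x * tanh(softplus(x)) and evaluating at x = 0 leaves only tanh(softplus(0)) = tanh(ln 2) = 0.6 exactly. A sketch, assuming jax.nn.mish from this change is available:

import jax

# d/dx [x * tanh(softplus(x))] at x = 0 reduces to tanh(ln 2) = 0.6
print(jax.grad(jax.nn.mish)(0.0))  # ~0.6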
@@ -117,6 +137,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
     val = nn.squareplus(1e3)
     self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
 
+  def testMishValue(self):
+    val = nn.mish(1e3)
+    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
+
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testEluGrad(self):
     check_grads(nn.elu, (1e4,), order=4, eps=1.)
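Note (not part of the diff): testMishValue relies on mish behaving like the identity for large inputs: softplus(x) is approximately x when x is large and tanh saturates at 1, so mish(1e3) is approximately 1e3 with no overflow. A short sketch, again assuming jax.nn.mish from this change:

import jax

print(jax.nn.softplus(1e3))  # ~1000.0, no overflow
print(jax.nn.mish(1e3))      # ~1000.0, matching the test's expectation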
@@ -149,7 +173,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
       (jnp.float32, jnp.bfloat16, jnp.float16),
       (partial(nn.gelu, approximate=False),
        partial(nn.gelu, approximate=True),
-       nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus)))
+       nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
   def testDtypeMatchesInput(self, dtype, fn):
     x = jnp.zeros((), dtype=dtype)
     out = fn(x)
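Note (not part of the diff): adding nn.mish to this parameterized test asserts that its output dtype matches the input dtype for float32, bfloat16, and float16. A minimal illustration of that property, assuming jax.nn.mish from this change:

import jax
import jax.numpy as jnp

x16 = jnp.zeros((), dtype=jnp.bfloat16)
print(jax.nn.mish(x16).dtype)  # bfloat16, same as the input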