mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add nn.squareplus.
This commit is contained in:
parent
2356d7afd0
commit
9f8e1bc34a
@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.21
|
||||
|
||||
* New Features
|
||||
* Added {obj}`jax.nn.squareplus`.
|
||||
|
||||
* Deprecations
|
||||
* The previously-deprecated `sym_pos` argument has been removed from
|
||||
{func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
|
||||
|
@ -770,8 +770,8 @@ well-defined gradients::
|
||||
The :mod:`jax.nn` submodule also has smooth versions of other common rank-based
|
||||
functions, for example :func:`jax.nn.softmax` can replace uses of
|
||||
:func:`jax.numpy.argmax`, :func:`jax.nn.soft_sign` can replace uses of
|
||||
:func:`jax.numpy.sign`, :func:`jax.nn.softplus` can replace uses of
|
||||
:func:`jax.nn.relu`, etc.
|
||||
:func:`jax.numpy.sign`, :func:`jax.nn.softplus` or :func:`jax.nn.squareplus`
|
||||
can replace uses of :func:`jax.nn.relu`, etc.
|
||||
|
||||
How can I convert a JAX Tracer to a NumPy array?
|
||||
------------------------------------------------
|
||||
|
@ -36,6 +36,7 @@ Activation functions
|
||||
selu
|
||||
gelu
|
||||
glu
|
||||
squareplus
|
||||
|
||||
Other functions
|
||||
---------------
|
||||
|
@ -72,6 +72,28 @@ def relu(x: ArrayLike) -> Array:
|
||||
# For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
|
||||
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
|
||||
|
||||
@jax.jit
def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
  r"""Squareplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}

  as described in https://arxiv.org/abs/2112.11687.

  Args:
    x : input array
    b : smoothness parameter; the term added to ``x**2`` under the square
      root. Defaults to 4.

  Returns:
    An array holding ``(x + sqrt(x**2 + b)) / 2`` evaluated element-wise.
  """
  # Validate both operands in a single call so that a TypeError reports the
  # offending argument's real position (0 for ``x``, 1 for ``b``) instead of
  # always reporting position 0.
  numpy_util.check_arraylike("squareplus", x, b)
  x = jnp.asarray(x)
  b = jnp.asarray(b)
  y = x + jnp.sqrt(jnp.square(x) + b)
  return y / 2
|
||||
|
||||
@jax.jit
|
||||
def softplus(x: ArrayLike) -> Array:
|
||||
r"""Softplus activation function.
|
||||
|
@ -43,6 +43,7 @@ from jax._src.nn.functions import (
|
||||
softplus as softplus,
|
||||
silu as silu,
|
||||
swish as swish,
|
||||
squareplus as squareplus,
|
||||
)
|
||||
|
||||
# Deprecations
|
||||
|
@ -63,6 +63,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
def testSoftplusZero(self, dtype):
|
||||
self.assertEqual(jnp.log(dtype(2)), nn.softplus(dtype(0)))
|
||||
|
||||
def testSquareplusGrad(self):
  # Check derivatives of squareplus up to 4th order at a tiny positive input.
  # A looser relative tolerance is used when running on TPU.
  rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.squareplus, (1e-8,), order=4, rtol=rtol)
|
||||
|
||||
def testSquareplusGradZero(self):
  # Check the first-order derivative of squareplus exactly at zero.
  # A looser relative tolerance is used when running on TPU.
  rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.squareplus, (0.,), order=1, rtol=rtol)
|
||||
|
||||
def testSquareplusGradNegInf(self):
  # Check the first-order derivative of squareplus at negative infinity.
  # A looser relative tolerance is used when running on TPU.
  rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  neg_inf = -float('inf')
  check_grads(nn.squareplus, (neg_inf,), order=1, rtol=rtol)
|
||||
|
||||
def testSquareplusGradNan(self):
  # Check the first-order derivative of squareplus for a NaN input.
  # A looser relative tolerance is used when running on TPU.
  rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  nan = float('nan')
  check_grads(nn.squareplus, (nan,), order=1, rtol=rtol)
|
||||
|
||||
@parameterized.parameters([float] + jtu.dtypes.floating)
def testSquareplusZero(self, dtype):
  # With x = 0 and b = 4 the formula gives (0 + sqrt(4)) / 2 == 1.
  expected = dtype(1)
  actual = nn.squareplus(dtype(0), dtype(4))
  self.assertEqual(expected, actual)
|
||||
|
||||
def testReluGrad(self):
|
||||
rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
|
||||
check_grads(nn.relu, (1.,), order=3, rtol=rtol)
|
||||
@ -81,6 +101,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
val = nn.softplus(89.)
|
||||
self.assertAllClose(val, 89., check_dtypes=False)
|
||||
|
||||
def testSquareplusValue(self):
  # For a large positive input the output should be close to the input.
  large_input = 1e3
  result = nn.squareplus(large_input)
  self.assertAllClose(result, large_input, check_dtypes=False, atol=1e-3)
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def testEluGrad(self):
|
||||
check_grads(nn.elu, (1e4,), order=4, eps=1.)
|
||||
@ -113,7 +137,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
(jnp.float32, jnp.bfloat16, jnp.float16),
|
||||
(partial(nn.gelu, approximate=False),
|
||||
partial(nn.gelu, approximate=True),
|
||||
nn.relu, nn.softplus, nn.sigmoid)))
|
||||
nn.relu, nn.softplus, nn.sigmoid, nn.squareplus)))
|
||||
def testDtypeMatchesInput(self, dtype, fn):
|
||||
x = jnp.zeros((), dtype=dtype)
|
||||
out = fn(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user