From 9f8e1bc34a9398e214b25d34471018fa4c21b371 Mon Sep 17 00:00:00 2001
From: carlosgmartin
Date: Tue, 14 Nov 2023 23:52:41 -0500
Subject: [PATCH] Add nn.squareplus.

---
 CHANGELOG.md             |  3 +++
 docs/faq.rst             |  4 ++--
 docs/jax.nn.rst          |  1 +
 jax/_src/nn/functions.py | 22 ++++++++++++++++++++++
 jax/nn/__init__.py       |  1 +
 tests/nn_test.py         | 26 +++++++++++++++++++++++++-
 6 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d72975b5..becdd9b0c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.21
 
+* New Features
+  * Added {obj}`jax.nn.squareplus`.
+
 * Deprecations
   * The previously-deprecated `sym_pos` argument has been removed from
     {func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.

diff --git a/docs/faq.rst b/docs/faq.rst
index 0f8709b4f..441af2436 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -770,8 +770,8 @@ well-defined gradients::
 The :mod:`jax.nn` submodule also has smooth versions of other common
 rank-based functions, for example :func:`jax.nn.softmax` can replace uses of
 :func:`jax.numpy.argmax`, :func:`jax.nn.soft_sign` can replace uses of
-:func:`jax.numpy.sign`, :func:`jax.nn.softplus` can replace uses of
-:func:`jax.nn.relu`, etc.
+:func:`jax.numpy.sign`, :func:`jax.nn.softplus` or :func:`jax.nn.squareplus`
+can replace uses of :func:`jax.nn.relu`, etc.
 
 How can I convert a JAX Tracer to a NumPy array?
 ------------------------------------------------

diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst
index 31ad050df..ae2c6f24a 100644
--- a/docs/jax.nn.rst
+++ b/docs/jax.nn.rst
@@ -36,6 +36,7 @@ Activation functions
     selu
     gelu
     glu
+    squareplus
 
 Other functions
 ---------------

diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index 2d9b7c431..896717796 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -72,6 +72,28 @@ def relu(x: ArrayLike) -> Array:
 # For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
 relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
 
+@jax.jit
+def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
+  r"""Squareplus activation function.
+
+  Computes the element-wise function
+
+  .. math::
+    \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}
+
+  as described in https://arxiv.org/abs/2112.11687.
+
+  Args:
+    x : input array
+    b : smoothness parameter
+  """
+  numpy_util.check_arraylike("squareplus", x)
+  numpy_util.check_arraylike("squareplus", b)
+  x = jnp.asarray(x)
+  b = jnp.asarray(b)
+  y = x + jnp.sqrt(jnp.square(x) + b)
+  return y / 2
+
 @jax.jit
 def softplus(x: ArrayLike) -> Array:
   r"""Softplus activation function.
diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py
index 07ed7716b..0f008b30e 100644
--- a/jax/nn/__init__.py
+++ b/jax/nn/__init__.py
@@ -43,6 +43,7 @@ from jax._src.nn.functions import (
   softplus as softplus,
   silu as silu,
   swish as swish,
+  squareplus as squareplus,
 )
 
 # Deprecations

diff --git a/tests/nn_test.py b/tests/nn_test.py
index 3012fb981..3c361ef7d 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -63,6 +63,26 @@ class NNFunctionsTest(jtu.JaxTestCase):
   def testSoftplusZero(self, dtype):
     self.assertEqual(jnp.log(dtype(2)), nn.softplus(dtype(0)))
 
+  def testSquareplusGrad(self):
+    check_grads(nn.squareplus, (1e-8,), order=4,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testSquareplusGradZero(self):
+    check_grads(nn.squareplus, (0.,), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testSquareplusGradNegInf(self):
+    check_grads(nn.squareplus, (-float('inf'),), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  def testSquareplusGradNan(self):
+    check_grads(nn.squareplus, (float('nan'),), order=1,
+                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)
+
+  @parameterized.parameters([float] + jtu.dtypes.floating)
+  def testSquareplusZero(self, dtype):
+    self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4)))
+
   def testReluGrad(self):
     rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
     check_grads(nn.relu, (1.,), order=3, rtol=rtol)
@@ -81,6 +101,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
     val = nn.softplus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)
 
+  def testSquareplusValue(self):
+    val = nn.squareplus(1e3)
+    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
+
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testEluGrad(self):
     check_grads(nn.elu, (1e4,), order=4, eps=1.)
@@ -113,7 +137,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
       (jnp.float32, jnp.bfloat16, jnp.float16),
       (partial(nn.gelu, approximate=False),
        partial(nn.gelu, approximate=True),
-       nn.relu, nn.softplus, nn.sigmoid)))
+       nn.relu, nn.softplus, nn.sigmoid, nn.squareplus)))
   def testDtypeMatchesInput(self, dtype, fn):
     x = jnp.zeros((), dtype=dtype)
     out = fn(x)
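
Note (not part of the patch): a minimal usage sketch of the squareplus function defined above, assuming this change is applied on top of a recent jax release; the array values below are only illustrative.

    import jax
    import jax.numpy as jnp
    from jax import nn

    x = jnp.array([-2.0, 0.0, 2.0])

    # With the default smoothness b=4, squareplus(0) == 1; as b -> 0 the
    # function approaches relu, since (x + sqrt(x**2)) / 2 == max(x, 0).
    print(nn.squareplus(x))          # default b=4
    print(nn.squareplus(x, b=0.25))  # sharper approximation of relu

    # Unlike relu, squareplus has a well-defined, smooth gradient everywhere:
    # d/dx [(x + sqrt(x**2 + b)) / 2] = (1 + x / sqrt(x**2 + b)) / 2,
    # which is 0.5 at x = 0 for any b > 0.
    print(jax.grad(nn.squareplus)(0.0))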