Add identity activation

Fix typo
This commit is contained in:
Jesse Perla 2025-03-23 15:03:49 -07:00
parent 540541a3d3
commit 5d79df7e67
4 changed files with 28 additions and 1 deletions

View File

@ -40,6 +40,7 @@ Activation functions
glu
squareplus
mish
identity
Other functions
---------------

View File

@ -54,6 +54,25 @@ _UNSPECIFIED = Unspecified()
# activations
@jax.jit
def identity(x: ArrayLike) -> Array:
r"""Identity activation function.
Returns the argument unmodified.
Args:
x : input array
Returns:
The argument `x` unmodified.
Examples:
>>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32)
"""
numpy_util.check_arraylike("identity", x)
return jnp.asarray(x)
@custom_jvp
@jax.jit

View File

@ -35,6 +35,7 @@ from jax._src.nn.functions import (
standardize as standardize,
one_hot as one_hot,
relu as relu,
identity as identity,
relu6 as relu6,
dot_product_attention as dot_product_attention,
scaled_dot_general as scaled_dot_general,

View File

@ -543,7 +543,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
(jnp.float32, jnp.bfloat16, jnp.float16),
(partial(nn.gelu, approximate=False),
partial(nn.gelu, approximate=True),
nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
nn.relu, nn.identity, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
def testDtypeMatchesInput(self, dtype, fn):
x = jnp.zeros((), dtype=dtype)
out = fn(x)
@ -831,6 +831,12 @@ class NNInitializersTest(jtu.JaxTestCase):
):
initializer(rng, shape)
def testIdentity(self):
x = jnp.array([1., 2., 3.])
self.assertAllClose(nn.identity(x), x, check_dtypes=False)
grad = jax.grad(nn.identity)(6.0)
self.assertEqual(grad, 1.)
def testAccidentalUpcasting(self):
rng = random.PRNGKey(0)
shape = (4, 4)