mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add identity activation
Fix typo
This commit is contained in:
parent
540541a3d3
commit
5d79df7e67
@ -40,6 +40,7 @@ Activation functions
|
||||
glu
|
||||
squareplus
|
||||
mish
|
||||
identity
|
||||
|
||||
Other functions
|
||||
---------------
|
||||
|
@ -54,6 +54,25 @@ _UNSPECIFIED = Unspecified()
|
||||
|
||||
|
||||
# activations
|
||||
@jax.jit
|
||||
def identity(x: ArrayLike) -> Array:
|
||||
r"""Identity activation function.
|
||||
|
||||
Returns the argument unmodified.
|
||||
|
||||
Args:
|
||||
x : input array
|
||||
|
||||
Returns:
|
||||
The argument `x` unmodified.
|
||||
|
||||
Examples:
|
||||
>>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
|
||||
Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32)
|
||||
|
||||
"""
|
||||
numpy_util.check_arraylike("identity", x)
|
||||
return jnp.asarray(x)
|
||||
|
||||
@custom_jvp
|
||||
@jax.jit
|
||||
|
@ -35,6 +35,7 @@ from jax._src.nn.functions import (
|
||||
standardize as standardize,
|
||||
one_hot as one_hot,
|
||||
relu as relu,
|
||||
identity as identity,
|
||||
relu6 as relu6,
|
||||
dot_product_attention as dot_product_attention,
|
||||
scaled_dot_general as scaled_dot_general,
|
||||
|
@ -543,7 +543,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
(jnp.float32, jnp.bfloat16, jnp.float16),
|
||||
(partial(nn.gelu, approximate=False),
|
||||
partial(nn.gelu, approximate=True),
|
||||
nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
|
||||
nn.relu, nn.identity, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
|
||||
def testDtypeMatchesInput(self, dtype, fn):
|
||||
x = jnp.zeros((), dtype=dtype)
|
||||
out = fn(x)
|
||||
@ -831,6 +831,12 @@ class NNInitializersTest(jtu.JaxTestCase):
|
||||
):
|
||||
initializer(rng, shape)
|
||||
|
||||
def testIdentity(self):
|
||||
x = jnp.array([1., 2., 3.])
|
||||
self.assertAllClose(nn.identity(x), x, check_dtypes=False)
|
||||
grad = jax.grad(nn.identity)(6.0)
|
||||
self.assertEqual(grad, 1.)
|
||||
|
||||
def testAccidentalUpcasting(self):
|
||||
rng = random.PRNGKey(0)
|
||||
shape = (4, 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user