fix typo in test

This commit is contained in:
Matthew Johnson 2022-09-23 12:42:15 -07:00
parent e76aa77895
commit 03abcc7c5c

View File

@ -227,7 +227,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
a = jnp.array(1., 'float32')
def f(hx, _):
hx = jax.nn.relu(hx + a)
hx = sigmoid(hx + a)
return hx, None
hx = jnp.array(0., 'float32')