Merge pull request #12496 from mattjj:improve-leak-checker-2

PiperOrigin-RevId: 476907407
This commit is contained in:
jax authors 2022-09-26 08:50:13 -07:00
commit 28672cca0e

View File

@ -227,7 +227,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
a = jnp.array(1., 'float32')
def f(hx, _):
hx = jax.nn.relu(hx + a)
hx = sigmoid(hx + a)
return hx, None
hx = jnp.array(0., 'float32')