mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #12496 from mattjj:improve-leak-checker-2
PiperOrigin-RevId: 476907407
This commit is contained in:
commit
28672cca0e
@ -227,7 +227,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
a = jnp.array(1., 'float32')
|
||||
|
||||
def f(hx, _):
|
||||
hx = jax.nn.relu(hx + a)
|
||||
hx = sigmoid(hx + a)
|
||||
return hx, None
|
||||
|
||||
hx = jnp.array(0., 'float32')
|
||||
|
Loading…
x
Reference in New Issue
Block a user