mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add regression test for #7613
This commit is contained in:
parent
563e08fa31
commit
a0b9946a30
@ -2931,6 +2931,24 @@ class APITest(jtu.JaxTestCase):
|
||||
expected = (6.,)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_leaked_tracer_issue_7613(self):
|
||||
# from https://github.com/google/jax/issues/7613
|
||||
import numpy.random as npr
|
||||
|
||||
def sigmoid(x):
|
||||
return 1. / (1. + jnp.exp(-x))
|
||||
|
||||
x = jnp.ones((50,))
|
||||
A = jnp.array(npr.randn(50, 50))
|
||||
|
||||
@jax.jit
|
||||
def loss(A, x):
|
||||
h = jax.nn.sigmoid(A * x)
|
||||
return jnp.sum((h - x)**2)
|
||||
|
||||
with jax.checking_leaks():
|
||||
_ = jax.grad(loss)(A, x) # doesn't crash
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user