add regression test for #7613

This commit is contained in:
Matthew Johnson 2021-08-12 21:49:17 -07:00
parent 563e08fa31
commit a0b9946a30

View File

@ -2931,6 +2931,24 @@ class APITest(jtu.JaxTestCase):
expected = (6.,)
self.assertEqual(actual, expected)
def test_leaked_tracer_issue_7613(self):
# from https://github.com/google/jax/issues/7613
import numpy.random as npr
def sigmoid(x):
return 1. / (1. + jnp.exp(-x))
x = jnp.ones((50,))
A = jnp.array(npr.randn(50, 50))
@jax.jit
def loss(A, x):
h = jax.nn.sigmoid(A * x)
return jnp.sum((h - x)**2)
with jax.checking_leaks():
_ = jax.grad(loss)(A, x) # doesn't crash
class RematTest(jtu.JaxTestCase):