attempt to fix CI failure (from #3845 test?) (#3846)

This commit is contained in:
Matthew Johnson 2020-07-23 20:59:12 -07:00 committed by GitHub
parent cc9528d97d
commit e2424e3b24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2534,13 +2534,17 @@ class CustomJVPTest(jtu.JaxTestCase):
def run():
return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))
expected = run()
# we just don't want this to crash
n_workers = 20
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
futures = []
for _ in range(n_workers):
futures.append(e.submit(run))
_ = [f.result() for f in futures]
results = [f.result() for f in futures]
for ans in results:
self.assertAllClose(ans, expected)
class CustomVJPTest(jtu.JaxTestCase):