mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
cc9528d97d
commit
e2424e3b24
@ -2534,13 +2534,17 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
def run():
|
||||
return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))
|
||||
|
||||
expected = run()
|
||||
|
||||
# we just don't want this to crash
|
||||
n_workers = 20
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
|
||||
futures = []
|
||||
for _ in range(n_workers):
|
||||
futures.append(e.submit(run))
|
||||
_ = [f.result() for f in futures]
|
||||
results = [f.result() for f in futures]
|
||||
for ans in results:
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user