[key-reuse] fix flaky test

This commit is contained in:
Jake VanderPlas 2024-03-12 16:49:16 -07:00
parent 75dbd30a93
commit 2ba9b45277

View File

@ -608,13 +608,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
traced_bits_msg = "In random_bits, argument 0 is already consumed."
def test_clone_eager(self):
key = jax.random.key(0)
key2 = jax.random.clone(key)
self.assertIsNot(key, key2)
# TODO(b/329326258): run this test under JIT
with jax.disable_jit():
key = jax.random.key(0)
key2 = jax.random.clone(key)
self.assertIsNot(key, key2)
_ = jax.random.uniform(key)
self.assertTrue(key._consumed)
self.assertFalse(key2._consumed)
_ = jax.random.uniform(key)
self.assertTrue(key._consumed)
self.assertFalse(key2._consumed)
def test_simple_reuse_nojit(self):
key = jax.random.key(0)