From 2ba9b4527735b7a17edb7196607b530fc9827424 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 12 Mar 2024 16:49:16 -0700 Subject: [PATCH] [key-reuse] fix flaky test --- tests/key_reuse_test.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 8bc612163..5935fb9ab 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -608,13 +608,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase): traced_bits_msg = "In random_bits, argument 0 is already consumed." def test_clone_eager(self): - key = jax.random.key(0) - key2 = jax.random.clone(key) - self.assertIsNot(key, key2) + # TODO(b/329326258): run this test under JIT + with jax.disable_jit(): + key = jax.random.key(0) + key2 = jax.random.clone(key) + self.assertIsNot(key, key2) - _ = jax.random.uniform(key) - self.assertTrue(key._consumed) - self.assertFalse(key2._consumed) + _ = jax.random.uniform(key) + self.assertTrue(key._consumed) + self.assertFalse(key2._consumed) def test_simple_reuse_nojit(self): key = jax.random.key(0)