mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key-reuse] fix flaky test
This commit is contained in:
parent
75dbd30a93
commit
2ba9b45277
@ -608,13 +608,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
|
||||
traced_bits_msg = "In random_bits, argument 0 is already consumed."
|
||||
|
||||
def test_clone_eager(self):
|
||||
key = jax.random.key(0)
|
||||
key2 = jax.random.clone(key)
|
||||
self.assertIsNot(key, key2)
|
||||
# TODO(b/329326258): run this test under JIT
|
||||
with jax.disable_jit():
|
||||
key = jax.random.key(0)
|
||||
key2 = jax.random.clone(key)
|
||||
self.assertIsNot(key, key2)
|
||||
|
||||
_ = jax.random.uniform(key)
|
||||
self.assertTrue(key._consumed)
|
||||
self.assertFalse(key2._consumed)
|
||||
_ = jax.random.uniform(key)
|
||||
self.assertTrue(key._consumed)
|
||||
self.assertFalse(key2._consumed)
|
||||
|
||||
def test_simple_reuse_nojit(self):
|
||||
key = jax.random.key(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user