Add key reuse config to trace context

This commit is contained in:
Jake VanderPlas 2024-03-14 06:59:37 -07:00
parent 64bd95ded5
commit ae4e273b74
2 changed files with 7 additions and 8 deletions

View File

@ -213,6 +213,7 @@ def trace_context():
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
enable_key_reuse_checks.value,
jax_xla_profile_version.value,
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value)

View File

@ -608,15 +608,13 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
traced_bits_msg = "In random_bits, argument 0 is already consumed."
def test_clone_eager(self):
# TODO(b/329326258): run this test under JIT
with jax.disable_jit():
key = jax.random.key(0)
key2 = jax.random.clone(key)
self.assertIsNot(key, key2)
key = jax.random.key(0)
key2 = jax.random.clone(key)
self.assertIsNot(key, key2)
_ = jax.random.uniform(key)
self.assertTrue(key._consumed)
self.assertFalse(key2._consumed)
_ = jax.random.uniform(key)
self.assertTrue(key._consumed)
self.assertFalse(key2._consumed)
def test_simple_reuse_nojit(self):
key = jax.random.key(0)