mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add key reuse config to trace context
This commit is contained in:
parent
64bd95ded5
commit
ae4e273b74
@ -213,6 +213,7 @@ def trace_context():
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
enable_key_reuse_checks.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value)
|
||||
|
@ -608,15 +608,13 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
|
||||
traced_bits_msg = "In random_bits, argument 0 is already consumed."
|
||||
|
||||
def test_clone_eager(self):
|
||||
# TODO(b/329326258): run this test under JIT
|
||||
with jax.disable_jit():
|
||||
key = jax.random.key(0)
|
||||
key2 = jax.random.clone(key)
|
||||
self.assertIsNot(key, key2)
|
||||
key = jax.random.key(0)
|
||||
key2 = jax.random.clone(key)
|
||||
self.assertIsNot(key, key2)
|
||||
|
||||
_ = jax.random.uniform(key)
|
||||
self.assertTrue(key._consumed)
|
||||
self.assertFalse(key2._consumed)
|
||||
_ = jax.random.uniform(key)
|
||||
self.assertTrue(key._consumed)
|
||||
self.assertFalse(key2._consumed)
|
||||
|
||||
def test_simple_reuse_nojit(self):
|
||||
key = jax.random.key(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user