[key reuse] fix random_clone impl rule

This commit is contained in:
Jake VanderPlas 2024-03-08 15:16:39 -08:00
parent 5e039f7af5
commit d1e49f9c89
2 changed files with 11 additions and 1 deletions

View File

@ -30,6 +30,7 @@ from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
@ -2615,7 +2616,7 @@ def binomial(
# Functions related to key reuse checking
random_clone_p = core.Primitive("random_clone")
random_clone_p.def_impl(lambda x: x)
dispatch.simple_impl(random_clone_p)
random_clone_p.def_abstract_eval(lambda x: x)
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, lambda _, k: [k])

View File

@ -595,6 +595,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
traced_bits_msg = "In random_bits, argument 0 is already consumed."
def test_clone_eager(self):
key = jax.random.key(0)
key2 = jax.random.clone(key)
self.assertIsNot(key, key2)
_ = jax.random.uniform(key)
self.assertTrue(key._consumed)
self.assertFalse(key2._consumed)
def test_simple_reuse_nojit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)