diff --git a/jax/_src/random.py b/jax/_src/random.py index 69bd8d4b9..cc1a4a38f 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -30,6 +30,7 @@ from jax.numpy.linalg import cholesky, svd, eigh from jax._src import config from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import prng from jax._src import xla_bridge @@ -2615,7 +2616,7 @@ def binomial( # Functions related to key reuse checking random_clone_p = core.Primitive("random_clone") -random_clone_p.def_impl(lambda x: x) +dispatch.simple_impl(random_clone_p) random_clone_p.def_abstract_eval(lambda x: x) batching.defvectorized(random_clone_p) mlir.register_lowering(random_clone_p, lambda _, k: [k]) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 697b1e318..5d5d6e12e 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -595,6 +595,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase): eager_bits_msg = "Previously-consumed key passed to random_bits at index 0" traced_bits_msg = "In random_bits, argument 0 is already consumed." + def test_clone_eager(self): + key = jax.random.key(0) + key2 = jax.random.clone(key) + self.assertIsNot(key, key2) + + _ = jax.random.uniform(key) + self.assertTrue(key._consumed) + self.assertFalse(key2._consumed) + def test_simple_reuse_nojit(self): key = jax.random.key(0) _ = jax.random.bits(key)