mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[key reuse] fix random_clone impl rule
This commit is contained in:
parent
5e039f7af5
commit
d1e49f9c89
@ -30,6 +30,7 @@ from jax.numpy.linalg import cholesky, svd, eigh
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import prng
|
||||
from jax._src import xla_bridge
|
||||
@ -2615,7 +2616,7 @@ def binomial(
|
||||
|
||||
# Functions related to key reuse checking
|
||||
random_clone_p = core.Primitive("random_clone")
|
||||
random_clone_p.def_impl(lambda x: x)
|
||||
dispatch.simple_impl(random_clone_p)
|
||||
random_clone_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(random_clone_p)
|
||||
mlir.register_lowering(random_clone_p, lambda _, k: [k])
|
||||
|
@ -595,6 +595,15 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
|
||||
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
|
||||
traced_bits_msg = "In random_bits, argument 0 is already consumed."
|
||||
|
||||
def test_clone_eager(self):
|
||||
key = jax.random.key(0)
|
||||
key2 = jax.random.clone(key)
|
||||
self.assertIsNot(key, key2)
|
||||
|
||||
_ = jax.random.uniform(key)
|
||||
self.assertTrue(key._consumed)
|
||||
self.assertFalse(key2._consumed)
|
||||
|
||||
def test_simple_reuse_nojit(self):
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.bits(key)
|
||||
|
Loading…
x
Reference in New Issue
Block a user