[key reuse] add jax.random.clone

This commit is contained in:
Jake VanderPlas 2024-03-08 09:06:00 -08:00
parent 4244b218ca
commit 6771a59181
12 changed files with 45 additions and 40 deletions

View File

@ -9,5 +9,4 @@ API
.. autosummary::
:toctree: _autosummary
reuse_key
KeyReuseError

View File

@ -18,6 +18,7 @@ Key Creation & Manipulation
wrap_key_data
fold_in
split
clone
Random Samplers
~~~~~~~~~~~~~~~

View File

@ -394,9 +394,7 @@ def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
i_ = length - i - 1 if reverse else i
# TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
# because the scan body may consume any keys within it.
# Import here to avoid circular imports
from jax.experimental import key_reuse
xs_unconsumed = _map(key_reuse.reuse_key, xs)
xs_unconsumed = _map(jax.random.clone, xs)
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
out_flat = f_impl(*consts, *carry, *x)
carry_out, y_updates = split_list(out_flat, [num_carry])

View File

@ -1377,26 +1377,3 @@ unsafe_rbg_prng_impl = PRNGImpl(
tag='urbg')
register_prng(unsafe_rbg_prng_impl)
# Primitives related to key reuse
reuse_key_p = core.Primitive("reuse_key")
reuse_key_p.def_impl(lambda x: x)
reuse_key_p.def_abstract_eval(lambda x: x)
batching.defvectorized(reuse_key_p)
mlir.register_lowering(reuse_key_p, lambda _, k: [k])
def reuse_key(key):
"""Explicitly mark a key as unconsumed.
Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`)
this function operates as an identity.
Example:
>>> import jax
>>> key = jax.random.key(0)
>>> data = jax.random.uniform(key)
>>> same_data = jax.random.uniform(reuse_key(key))
"""
return reuse_key_p.bind(key)

View File

@ -2611,3 +2611,28 @@ def binomial(
if shape is not None:
shape = core.canonicalize_shape(shape)
return _binomial(key, n, p, shape, dtype)
# Functions related to key reuse checking
random_clone_p = core.Primitive("random_clone")
random_clone_p.def_impl(lambda x: x)
random_clone_p.def_abstract_eval(lambda x: x)
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, lambda _, k: [k])
def clone(key):
"""Clone a key for reuse
Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`)
this function operates as an identity.
Example:
>>> import jax
>>> key = jax.random.key(0)
>>> data = jax.random.uniform(key)
>>> cloned_key = jax.random.clone(key)
>>> same_data = jax.random.uniform(cloned_key)
>>> assert data == same_data
"""
return random_clone_p.bind(key)

View File

@ -1528,7 +1528,7 @@ tf_not_yet_impl = [
"consume",
]
tf_impl[prng.reuse_key_p] = lambda x: x
tf_impl[random_internal.random_clone_p] = lambda x: x
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient

View File

@ -39,10 +39,6 @@ context manager::
... print(jax.random.normal(key))
-0.20584226
"""
from jax._src.prng import (
reuse_key as reuse_key,
)
from jax.experimental.key_reuse._core import (
KeyReuseError as KeyReuseError,
)

View File

@ -149,7 +149,7 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)])
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[random.random_clone_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
# TODO(jakevdp): should fold_in sink its input key?
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])

View File

@ -144,6 +144,7 @@ from jax._src.random import (
cauchy as cauchy,
chisquare as chisquare,
choice as choice,
clone as clone,
dirichlet as dirichlet,
double_sided_maxwell as double_sided_maxwell,
exponential as exponential,

View File

@ -21,6 +21,7 @@ import jax
from jax import core
import jax.numpy as jnp
from jax._src import prng
from jax._src import random
from jax._src import test_util as jtu
from jax.experimental.key_reuse._core import (
assert_consumed, assert_unconsumed, consume, consume_p)
@ -36,7 +37,7 @@ key1D = jax.eval_shape(lambda key: key[None], key)
primitives_with_static_signatures = {
consume_p: (consume, key),
prng.reuse_key_p: (prng.reuse_key, key),
random.random_clone_p: (random.clone, key),
prng.random_bits_p: (jax.random.bits, key),
# prng.random_fold_in_p: (jax.random.fold_in, key, 2),
prng.random_seed_p: (jax.random.key, 0),
@ -91,12 +92,12 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_reuse_key(self):
def test_random_clone(self):
def f(key):
assert_unconsumed(key)
consume(key)
assert_consumed(key)
key2 = prng.reuse_key(key)
key2 = jax.random.clone(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))

View File

@ -588,6 +588,14 @@ class KeyArrayTest(jtu.JaxTestCase):
key = random.key(42)
self.assertIsInstance(key, prng_internal.PRNGKeyArray)
def test_random_clone(self):
# Here we test value semantics and compatibility with jit/vmap
# key reuse semantics are tested in key_reuse_test.py
keys = jax.random.split(jax.random.key(0), 5)
self.assertKeysEqual(keys, jax.random.clone(keys))
self.assertKeysEqual(keys, jax.jit(jax.random.clone)(keys))
self.assertKeysEqual(keys, jax.vmap(jax.random.clone)(keys))
def test_issubdtype(self):
key = random.key(42)

View File

@ -29,7 +29,6 @@ from jax._src import core
from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.util import tuple_insert
import jax.numpy as jnp
@ -1735,8 +1734,8 @@ if CAN_USE_HYPOTHESIS:
y, impl_vjp = jax.vjp(impl, x)
y_ref, ref_vjp = jax.vjp(ref, x)
self.assertAllClose(y, y_ref)
t = random.normal(prng.reuse_key(k2), x.shape)
y2 = random.normal(prng.reuse_key(k1), y.shape)
t = random.normal(jax.random.clone(k2), x.shape)
y2 = random.normal(jax.random.clone(k1), y.shape)
self.assertAllClose(impl_vjp(t), ref_vjp(t))
# Second order
@ -1752,7 +1751,7 @@ if CAN_USE_HYPOTHESIS:
(x,), impl_vjp2 = jax.vjp(impl_vjp, t2)
(x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2)
self.assertAllClose(x, x_ref)
y2 = random.normal(prng.reuse_key(k1), y.shape)
y2 = random.normal(jax.random.clone(k1), y.shape)
self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,)))
if __name__ == '__main__':