mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] add jax.random.clone
This commit is contained in:
parent
4244b218ca
commit
6771a59181
@ -9,5 +9,4 @@ API
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
reuse_key
|
||||
KeyReuseError
|
||||
|
@ -18,6 +18,7 @@ Key Creation & Manipulation
|
||||
wrap_key_data
|
||||
fold_in
|
||||
split
|
||||
clone
|
||||
|
||||
Random Samplers
|
||||
~~~~~~~~~~~~~~~
|
||||
|
@ -394,9 +394,7 @@ def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
|
||||
i_ = length - i - 1 if reverse else i
|
||||
# TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right,
|
||||
# because the scan body may consume any keys within it.
|
||||
# Import here to avoid circular imports
|
||||
from jax.experimental import key_reuse
|
||||
xs_unconsumed = _map(key_reuse.reuse_key, xs)
|
||||
xs_unconsumed = _map(jax.random.clone, xs)
|
||||
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
|
||||
out_flat = f_impl(*consts, *carry, *x)
|
||||
carry_out, y_updates = split_list(out_flat, [num_carry])
|
||||
|
@ -1377,26 +1377,3 @@ unsafe_rbg_prng_impl = PRNGImpl(
|
||||
tag='urbg')
|
||||
|
||||
register_prng(unsafe_rbg_prng_impl)
|
||||
|
||||
|
||||
# Primitives related to key reuse
|
||||
reuse_key_p = core.Primitive("reuse_key")
|
||||
reuse_key_p.def_impl(lambda x: x)
|
||||
reuse_key_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(reuse_key_p)
|
||||
mlir.register_lowering(reuse_key_p, lambda _, k: [k])
|
||||
|
||||
def reuse_key(key):
|
||||
"""Explicitly mark a key as unconsumed.
|
||||
|
||||
Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`)
|
||||
this function operates as an identity.
|
||||
|
||||
Example:
|
||||
|
||||
>>> import jax
|
||||
>>> key = jax.random.key(0)
|
||||
>>> data = jax.random.uniform(key)
|
||||
>>> same_data = jax.random.uniform(reuse_key(key))
|
||||
"""
|
||||
return reuse_key_p.bind(key)
|
||||
|
@ -2611,3 +2611,28 @@ def binomial(
|
||||
if shape is not None:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
return _binomial(key, n, p, shape, dtype)
|
||||
|
||||
|
||||
# Functions related to key reuse checking
|
||||
random_clone_p = core.Primitive("random_clone")
|
||||
random_clone_p.def_impl(lambda x: x)
|
||||
random_clone_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(random_clone_p)
|
||||
mlir.register_lowering(random_clone_p, lambda _, k: [k])
|
||||
|
||||
def clone(key):
|
||||
"""Clone a key for reuse
|
||||
|
||||
Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`)
|
||||
this function operates as an identity.
|
||||
|
||||
Example:
|
||||
|
||||
>>> import jax
|
||||
>>> key = jax.random.key(0)
|
||||
>>> data = jax.random.uniform(key)
|
||||
>>> cloned_key = jax.random.clone(key)
|
||||
>>> same_data = jax.random.uniform(cloned_key)
|
||||
>>> assert data == same_data
|
||||
"""
|
||||
return random_clone_p.bind(key)
|
||||
|
@ -1528,7 +1528,7 @@ tf_not_yet_impl = [
|
||||
"consume",
|
||||
]
|
||||
|
||||
tf_impl[prng.reuse_key_p] = lambda x: x
|
||||
tf_impl[random_internal.random_clone_p] = lambda x: x
|
||||
|
||||
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
|
||||
|
||||
|
@ -39,10 +39,6 @@ context manager::
|
||||
... print(jax.random.normal(key))
|
||||
-0.20584226
|
||||
"""
|
||||
from jax._src.prng import (
|
||||
reuse_key as reuse_key,
|
||||
)
|
||||
|
||||
from jax.experimental.key_reuse._core import (
|
||||
KeyReuseError as KeyReuseError,
|
||||
)
|
||||
|
@ -149,7 +149,7 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
|
||||
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[random.random_clone_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
|
@ -144,6 +144,7 @@ from jax._src.random import (
|
||||
cauchy as cauchy,
|
||||
chisquare as chisquare,
|
||||
choice as choice,
|
||||
clone as clone,
|
||||
dirichlet as dirichlet,
|
||||
double_sided_maxwell as double_sided_maxwell,
|
||||
exponential as exponential,
|
||||
|
@ -21,6 +21,7 @@ import jax
|
||||
from jax import core
|
||||
import jax.numpy as jnp
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.key_reuse._core import (
|
||||
assert_consumed, assert_unconsumed, consume, consume_p)
|
||||
@ -36,7 +37,7 @@ key1D = jax.eval_shape(lambda key: key[None], key)
|
||||
|
||||
primitives_with_static_signatures = {
|
||||
consume_p: (consume, key),
|
||||
prng.reuse_key_p: (prng.reuse_key, key),
|
||||
random.random_clone_p: (random.clone, key),
|
||||
prng.random_bits_p: (jax.random.bits, key),
|
||||
# prng.random_fold_in_p: (jax.random.fold_in, key, 2),
|
||||
prng.random_seed_p: (jax.random.key, 0),
|
||||
@ -91,12 +92,12 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
assert_consumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
def test_reuse_key(self):
|
||||
def test_random_clone(self):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
consume(key)
|
||||
assert_consumed(key)
|
||||
key2 = prng.reuse_key(key)
|
||||
key2 = jax.random.clone(key)
|
||||
assert_unconsumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
|
@ -588,6 +588,14 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
key = random.key(42)
|
||||
self.assertIsInstance(key, prng_internal.PRNGKeyArray)
|
||||
|
||||
def test_random_clone(self):
|
||||
# Here we test value semantics and compatibility with jit/vmap
|
||||
# key reuse semantics are tested in key_reuse_test.py
|
||||
keys = jax.random.split(jax.random.key(0), 5)
|
||||
self.assertKeysEqual(keys, jax.random.clone(keys))
|
||||
self.assertKeysEqual(keys, jax.jit(jax.random.clone)(keys))
|
||||
self.assertKeysEqual(keys, jax.vmap(jax.random.clone)(keys))
|
||||
|
||||
def test_issubdtype(self):
|
||||
key = random.key(42)
|
||||
|
||||
|
@ -29,7 +29,6 @@ from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import tuple_insert
|
||||
import jax.numpy as jnp
|
||||
@ -1735,8 +1734,8 @@ if CAN_USE_HYPOTHESIS:
|
||||
y, impl_vjp = jax.vjp(impl, x)
|
||||
y_ref, ref_vjp = jax.vjp(ref, x)
|
||||
self.assertAllClose(y, y_ref)
|
||||
t = random.normal(prng.reuse_key(k2), x.shape)
|
||||
y2 = random.normal(prng.reuse_key(k1), y.shape)
|
||||
t = random.normal(jax.random.clone(k2), x.shape)
|
||||
y2 = random.normal(jax.random.clone(k1), y.shape)
|
||||
self.assertAllClose(impl_vjp(t), ref_vjp(t))
|
||||
|
||||
# Second order
|
||||
@ -1752,7 +1751,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
(x,), impl_vjp2 = jax.vjp(impl_vjp, t2)
|
||||
(x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2)
|
||||
self.assertAllClose(x, x_ref)
|
||||
y2 = random.normal(prng.reuse_key(k1), y.shape)
|
||||
y2 = random.normal(jax.random.clone(k1), y.shape)
|
||||
self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user