1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Define reuse_key primitive in jax._src.prng

This commit is contained in:
Jake VanderPlas 2024-02-14 14:01:08 -08:00
parent b9824d7de3
commit 49eb7008c0
9 changed files with 40 additions and 28 deletions

@ -9,5 +9,5 @@ API
.. autosummary::
:toctree: _autosummary
unconsumed_copy
reuse_key
KeyReuseError

@ -396,7 +396,7 @@ def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
# because the scan body may consume any keys within it.
# Import here to avoid circular imports
from jax.experimental import key_reuse
xs_unconsumed = _map(key_reuse.unconsumed_copy, xs)
xs_unconsumed = _map(key_reuse.reuse_key, xs)
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
out_flat = f_impl(*consts, *carry, *x)
carry_out, y_updates = split_list(out_flat, [num_carry])

@ -1338,3 +1338,26 @@ unsafe_rbg_prng_impl = PRNGImpl(
tag='urbg')
register_prng(unsafe_rbg_prng_impl)
# Primitives related to key reuse
reuse_key_p = core.Primitive("reuse_key")
reuse_key_p.def_impl(lambda x: x)
reuse_key_p.def_abstract_eval(lambda x: x)
batching.defvectorized(reuse_key_p)
mlir.register_lowering(reuse_key_p, lambda _, k: [k])
def reuse_key(key):
"""Explicitly mark a key as unconsumed.
Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`)
this function operates as an identity.
Example:
>>> import jax
>>> key = jax.random.key(0)
>>> data = jax.random.uniform(key)
>>> same_data = jax.random.uniform(reuse_key(key))
"""
return reuse_key_p.bind(key)

@ -69,7 +69,6 @@ from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp
from jax.experimental.key_reuse._common import unconsumed_copy_p
import tensorflow as tf # type: ignore[import]
@ -1529,7 +1528,7 @@ tf_not_yet_impl = [
"consume",
]
tf_impl[unconsumed_copy_p] = lambda x: x
tf_impl[prng.reuse_key_p] = lambda x: x
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient

@ -38,8 +38,10 @@ Key reuse checking can be enabled on `jit`-compiled functions using the
This flag can also be set globally if you wish to enagle key reuse checks in
every JIT-compiled function.
"""
from jax._src.prng import (
reuse_key as reuse_key,
)
from jax.experimental.key_reuse._common import (
unconsumed_copy as unconsumed_copy,
KeyReuseError as KeyReuseError,
)

@ -62,17 +62,6 @@ def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)
unconsumed_copy_p = core.Primitive("unconsumed_copy")
unconsumed_copy_p.def_impl(lambda x: x)
unconsumed_copy_p.def_abstract_eval(lambda x: x)
batching.defvectorized(unconsumed_copy_p)
mlir.register_lowering(
unconsumed_copy_p,
mlir.lower_fun(lambda x: x, multiple_results=False))
def unconsumed_copy(key):
"""Return a copy of key marked as unconsumed."""
return unconsumed_copy_p.bind(key)
assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)

@ -33,7 +33,7 @@ from jax._src.debugging import debug_callback_p
from jax._src.interpreters import partial_eval as pe
from jax.experimental.key_reuse._common import (
consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError,
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, KeyReuseSignature
)
import numpy as np
@ -52,7 +52,7 @@ class KeyReuseSignatureWithForwards(NamedTuple):
key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {}
key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)])
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)])
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignatureWithForwards([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
# TODO(jakevdp): should fold_in sink its input key?
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])

@ -33,7 +33,7 @@ from jax._src.debugging import debug_callback_p
from jax._src.interpreters import partial_eval as pe
from jax.experimental.key_reuse._common import (
consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError,
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, KeyReuseSignature
)
import numpy as np
@ -42,7 +42,7 @@ import numpy as np
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [])
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
# TODO(jakevdp): should fold_in sink its input key?
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])

@ -22,9 +22,8 @@ import jax.numpy as jnp
from jax._src import prng
from jax._src import test_util as jtu
from jax.experimental.key_reuse._common import (
assert_consumed, assert_unconsumed, consume, consume_p, unconsumed_copy_p)
from jax.experimental.key_reuse import (
_forwarding, _simple, KeyReuseError, unconsumed_copy)
assert_consumed, assert_unconsumed, consume, consume_p)
from jax.experimental.key_reuse import _forwarding, _simple, KeyReuseError
from jax import config
config.parse_flags_with_absl()
@ -36,7 +35,7 @@ key1D = jax.eval_shape(lambda key: key[None], key)
primitives_with_static_signatures = {
consume_p: (consume, key),
unconsumed_copy_p: (unconsumed_copy, key),
prng.reuse_key_p: (prng.reuse_key, key),
prng.random_bits_p: (jax.random.bits, key),
prng.random_fold_in_p: (jax.random.fold_in, key, 2),
prng.random_seed_p: (jax.random.key, 0),
@ -91,12 +90,12 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase):
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_unconsumed_copy(self):
def test_reuse_key(self):
def f(key):
assert_unconsumed(key)
consume(key)
assert_consumed(key)
key2 = unconsumed_copy(key)
key2 = prng.reuse_key(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))
@ -337,12 +336,12 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_unconsumed_copy(self):
def test_reuse_key(self):
def f(key):
assert_unconsumed(key)
consume(key)
assert_consumed(key)
key2 = unconsumed_copy(key)
key2 = prng.reuse_key(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))