mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Define reuse_key primitive in jax._src.prng
This commit is contained in:
parent
b9824d7de3
commit
49eb7008c0
docs
jax
_src
experimental
tests
@ -9,5 +9,5 @@ API
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
unconsumed_copy
|
||||
reuse_key
|
||||
KeyReuseError
|
||||
|
@ -396,7 +396,7 @@ def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
|
||||
# because the scan body may consume any keys within it.
|
||||
# Import here to avoid circular imports
|
||||
from jax.experimental import key_reuse
|
||||
xs_unconsumed = _map(key_reuse.unconsumed_copy, xs)
|
||||
xs_unconsumed = _map(key_reuse.reuse_key, xs)
|
||||
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
|
||||
out_flat = f_impl(*consts, *carry, *x)
|
||||
carry_out, y_updates = split_list(out_flat, [num_carry])
|
||||
|
@ -1338,3 +1338,26 @@ unsafe_rbg_prng_impl = PRNGImpl(
|
||||
tag='urbg')
|
||||
|
||||
register_prng(unsafe_rbg_prng_impl)
|
||||
|
||||
|
||||
# Primitives related to key reuse
|
||||
reuse_key_p = core.Primitive("reuse_key")
|
||||
reuse_key_p.def_impl(lambda x: x)
|
||||
reuse_key_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(reuse_key_p)
|
||||
mlir.register_lowering(reuse_key_p, lambda _, k: [k])
|
||||
|
||||
def reuse_key(key):
|
||||
"""Explicitly mark a key as unconsumed.
|
||||
|
||||
Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`)
|
||||
this function operates as an identity.
|
||||
|
||||
Example:
|
||||
|
||||
>>> import jax
|
||||
>>> key = jax.random.key(0)
|
||||
>>> data = jax.random.uniform(key)
|
||||
>>> same_data = jax.random.uniform(reuse_key(key))
|
||||
"""
|
||||
return reuse_key_p.bind(key)
|
||||
|
@ -69,7 +69,6 @@ from jax._src.lax import slicing as lax_slicing
|
||||
from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.numpy.ufuncs import logaddexp
|
||||
from jax.experimental.key_reuse._common import unconsumed_copy_p
|
||||
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
||||
@ -1529,7 +1528,7 @@ tf_not_yet_impl = [
|
||||
"consume",
|
||||
]
|
||||
|
||||
tf_impl[unconsumed_copy_p] = lambda x: x
|
||||
tf_impl[prng.reuse_key_p] = lambda x: x
|
||||
|
||||
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
|
||||
|
||||
|
@ -38,8 +38,10 @@ Key reuse checking can be enabled on `jit`-compiled functions using the
|
||||
This flag can also be set globally if you wish to enagle key reuse checks in
|
||||
every JIT-compiled function.
|
||||
"""
|
||||
from jax._src.prng import (
|
||||
reuse_key as reuse_key,
|
||||
)
|
||||
|
||||
from jax.experimental.key_reuse._common import (
|
||||
unconsumed_copy as unconsumed_copy,
|
||||
KeyReuseError as KeyReuseError,
|
||||
)
|
||||
|
@ -62,17 +62,6 @@ def consume(key):
|
||||
"""Consume the key and return a consumed copy."""
|
||||
return consume_p.bind(key)
|
||||
|
||||
unconsumed_copy_p = core.Primitive("unconsumed_copy")
|
||||
unconsumed_copy_p.def_impl(lambda x: x)
|
||||
unconsumed_copy_p.def_abstract_eval(lambda x: x)
|
||||
batching.defvectorized(unconsumed_copy_p)
|
||||
mlir.register_lowering(
|
||||
unconsumed_copy_p,
|
||||
mlir.lower_fun(lambda x: x, multiple_results=False))
|
||||
|
||||
def unconsumed_copy(key):
|
||||
"""Return a copy of key marked as unconsumed."""
|
||||
return unconsumed_copy_p.bind(key)
|
||||
|
||||
assert_consumed_value_p = core.Primitive("assert_consumed_value")
|
||||
assert_consumed_value_p.def_impl(lambda x, *, value: x)
|
||||
|
@ -33,7 +33,7 @@ from jax._src.debugging import debug_callback_p
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
from jax.experimental.key_reuse._common import (
|
||||
consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError,
|
||||
consume_p, assert_consumed_value_p, KeyReuseError,
|
||||
Sink, Source, KeyReuseSignature
|
||||
)
|
||||
import numpy as np
|
||||
@ -52,7 +52,7 @@ class KeyReuseSignatureWithForwards(NamedTuple):
|
||||
key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {}
|
||||
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
|
||||
|
@ -33,7 +33,7 @@ from jax._src.debugging import debug_callback_p
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
from jax.experimental.key_reuse._common import (
|
||||
consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError,
|
||||
consume_p, assert_consumed_value_p, KeyReuseError,
|
||||
Sink, Source, KeyReuseSignature
|
||||
)
|
||||
import numpy as np
|
||||
@ -42,7 +42,7 @@ import numpy as np
|
||||
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
|
||||
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [])
|
||||
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
|
@ -22,9 +22,8 @@ import jax.numpy as jnp
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.key_reuse._common import (
|
||||
assert_consumed, assert_unconsumed, consume, consume_p, unconsumed_copy_p)
|
||||
from jax.experimental.key_reuse import (
|
||||
_forwarding, _simple, KeyReuseError, unconsumed_copy)
|
||||
assert_consumed, assert_unconsumed, consume, consume_p)
|
||||
from jax.experimental.key_reuse import _forwarding, _simple, KeyReuseError
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -36,7 +35,7 @@ key1D = jax.eval_shape(lambda key: key[None], key)
|
||||
|
||||
primitives_with_static_signatures = {
|
||||
consume_p: (consume, key),
|
||||
unconsumed_copy_p: (unconsumed_copy, key),
|
||||
prng.reuse_key_p: (prng.reuse_key, key),
|
||||
prng.random_bits_p: (jax.random.bits, key),
|
||||
prng.random_fold_in_p: (jax.random.fold_in, key, 2),
|
||||
prng.random_seed_p: (jax.random.key, 0),
|
||||
@ -91,12 +90,12 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase):
|
||||
assert_consumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
def test_unconsumed_copy(self):
|
||||
def test_reuse_key(self):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
consume(key)
|
||||
assert_consumed(key)
|
||||
key2 = unconsumed_copy(key)
|
||||
key2 = prng.reuse_key(key)
|
||||
assert_unconsumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
@ -337,12 +336,12 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
assert_consumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
def test_unconsumed_copy(self):
|
||||
def test_reuse_key(self):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
consume(key)
|
||||
assert_consumed(key)
|
||||
key2 = unconsumed_copy(key)
|
||||
key2 = prng.reuse_key(key)
|
||||
assert_unconsumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user