rocm_jax/tests/key_reuse_test.py
2025-01-17 09:47:50 -08:00

773 lines
24 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest, parameterized
from functools import partial
import operator
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import prng
from jax._src import random
from jax._src import test_util as jtu
from jax.errors import KeyReuseError
from jax.experimental.key_reuse._core import (
assert_consumed, assert_unconsumed, consume, consume_p,
Source, Sink, Forward, KeyReuseSignature)
from jax.experimental.key_reuse import _core
jax.config.parse_flags_with_absl()
key = jax.eval_shape(jax.random.key, 0)
key1D = jax.eval_shape(lambda key: key[None], key)
primitives_with_static_signatures = {
consume_p: (consume, key),
random.random_clone_p: (random.clone, key),
prng.random_bits_p: (jax.random.bits, key),
# prng.random_fold_in_p: (jax.random.fold_in, key, 2),
prng.random_seed_p: (jax.random.key, 0),
prng.random_split_p: (jax.random.split, key),
prng.random_wrap_p: (jax.random.wrap_key_data, np.uint32([0, 0])),
# prng.random_unwrap_p: (jax.random.key_data, key),
jax.random.random_gamma_p: (jax.random.gamma, key, 1.0),
jax.lax.broadcast_in_dim_p: (lambda key: key[None], key),
jax.lax.copy_p: (jnp.array, key),
jax.lax.convert_element_type_p: (lambda key: jnp.array(key, dtype=key.dtype), key),
jax.lax.reshape_p: (lambda key: key.reshape((1,)), key),
jax.lax.squeeze_p: (jnp.squeeze, key1D),
jax.lax.dynamic_slice_p: (partial(jax.lax.dynamic_slice, slice_sizes=(1,)), key1D, (0,)),
jax.lax.dynamic_update_slice_p: (jax.lax.dynamic_update_slice, key1D, key1D, (0,)),
}
# Primitive that is unknown to the key reuse machinery
unknown_p = core.Primitive("unknown")
unknown_p.def_abstract_eval(lambda x: x)
unknown_p.def_impl(lambda x: x)
def apply_unknown_primitive(key):
return unknown_p.bind(key)
@jtu.with_config(
jax_enable_custom_prng=False,
jax_debug_key_reuse=False)
class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
def check_key_reuse(self, *args):
return _core.check_key_reuse(*args)
def test_assertions(self):
key = jax.random.key(0)
self.check_key_reuse(assert_unconsumed, key)
with self.assertRaises(AssertionError):
self.check_key_reuse(assert_consumed, key)
def test_unknown(self):
def f(key):
assert_unconsumed(key)
key2 = apply_unknown_primitive(key)
assert_consumed(key)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_consume(self):
def f(key):
assert_unconsumed(key)
key2 = consume(key)
assert_consumed(key)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_random_clone(self):
def f(key):
assert_unconsumed(key)
consume(key)
assert_consumed(key)
key2 = jax.random.clone(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_seed(self):
def f():
key = jax.random.key(0)
assert_unconsumed(key)
self.check_key_reuse(f)
def test_split(self):
def f(key):
assert_unconsumed(key)
key2 = jax.random.split(key)
assert_unconsumed(key2)
assert_consumed(key)
self.check_key_reuse(f, jax.random.key(0))
def test_fold_in(self):
def f(key):
assert_unconsumed(key)
key2 = jax.random.fold_in(key, 2)
assert_unconsumed(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_bits(self):
def f(key):
assert_unconsumed(key)
bits = jax.random.bits(key, (), 'uint32')
assert_consumed(key)
return bits
self.check_key_reuse(f, jax.random.key(0))
def test_wrap(self):
def f(key_data):
key = jax.random.wrap_key_data(key_data)
assert_unconsumed(key)
self.check_key_reuse(f, jax.random.PRNGKey(0))
def test_unwrap(self):
def f(key):
assert_unconsumed(key)
key_data = jax.random.key_data(key)
assert_unconsumed(key)
self.check_key_reuse(f, jax.random.key(0))
def test_gamma(self):
def f(key):
assert_unconsumed(key)
values = jax.random.gamma(key, 1.0)
assert_consumed(key)
return values
self.check_key_reuse(f, jax.random.key(0))
def test_broadcast_in_dim(self):
def f(key):
assert_unconsumed(key)
key2 = key[None]
assert_unconsumed(key)
assert_unconsumed(key2)
consume(key)
assert_consumed(key)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_copy(self):
def f(key):
assert_unconsumed(key)
key2 = jnp.array(key, copy=True)
assert_unconsumed(key)
assert_unconsumed(key2)
consume(key)
assert_consumed(key)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_device_put(self):
def f(key):
assert_unconsumed(key)
key_d = jax.device_put(key)
assert_unconsumed(key_d)
consume(key)
assert_consumed(key_d)
self.check_key_reuse(f, jax.random.key(0))
def test_device_put_multiple(self):
def f(key1, key2):
assert_unconsumed(key1)
assert_unconsumed(key2)
key1_d, key2_d = jax.device_put((key1, key2))
assert_unconsumed(key1_d)
consume(key1)
assert_consumed(key1_d)
assert_unconsumed(key2_d)
consume(key2)
assert_consumed(key2_d)
self.check_key_reuse(f, jax.random.key(0), jax.random.key(1))
def test_squeeze(self):
def f(key):
assert_unconsumed(key)
key2 = jax.lax.squeeze(key, (0,))
assert_unconsumed(key)
assert_unconsumed(key2)
consume(key)
assert_consumed(key)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0)[None])
def test_reshape(self):
def f(key):
assert_unconsumed(key)
key2 = key.reshape(1, *key.shape)
assert_unconsumed(key)
assert_unconsumed(key2)
consume(key)
assert_consumed(key)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_concatenate(self):
def f(key1, key2):
assert_unconsumed(key1)
assert_unconsumed(key2)
keys = jax.lax.concatenate([key1, key2], dimension=0)
assert_consumed(key1)
assert_consumed(key2)
assert_unconsumed(keys)
key1 = jax.random.split(jax.random.key(0))
key2 = jax.random.split(jax.random.key(1))
self.check_key_reuse(f, key1, key2)
def test_slice(self):
def f(keys):
assert_unconsumed(keys)
assert_unconsumed(keys[0])
assert_consumed(keys, np.array([True, False]))
assert_unconsumed(keys[1])
assert_consumed(keys, np.array([True, True]))
self.check_key_reuse(f, jax.random.split(jax.random.key(0)))
@parameterized.parameters(operator.eq, operator.ne)
def test_equality_checks(self, op):
def f(key1, key2):
assert_unconsumed(key1)
assert_unconsumed(key2)
result = op(key1, key2)
assert_unconsumed(key1)
assert_unconsumed(key2)
return result
self.check_key_reuse(f, jax.random.key(0), jax.random.key(1))
def test_jit_can_consume_input(self):
def f(key):
assert_unconsumed(key)
jax.jit(jax.random.bits)(key)
assert_consumed(key)
self.check_key_reuse(f, jax.random.key(0))
def test_jit_can_return_consumed_output(self):
def f():
def g():
key = jax.random.key(0)
assert_unconsumed(key)
bits = jax.random.bits(key)
assert_consumed(key)
return bits, key
_, key = jax.jit(g)()
assert_consumed(key)
self.check_key_reuse(f)
def test_jit_duplicate_inputs(self):
def f(key):
assert_unconsumed(key)
def g(key1, key2):
assert_unconsumed(key1)
assert_unconsumed(key2)
return jax.random.bits(key1)
_ = jax.jit(g)(key, key)
assert_consumed(key)
self.check_key_reuse(f, jax.random.key(0))
def test_jit_propagates_consumption_bit(self):
def f(key):
assert_unconsumed(key)
g = jax.jit(lambda: key)
key2 = g()
assert_unconsumed(key)
assert_unconsumed(key2)
consume(key)
assert_consumed(key)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_jit_duplicate_outputs(self):
# TODO(jakevdp): implement this case
def f(key):
assert_unconsumed(key)
def g(key):
return key, key
key1, key2 = jax.jit(g)(key)
assert_unconsumed(key)
assert_unconsumed(key1)
assert_unconsumed(key2)
_ = jax.random.bits(key1)
assert_consumed(key)
assert_consumed(key1)
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_cond_both_consumed(self):
@jax.jit
def f(flag, key):
assert_unconsumed(key)
_ = jax.lax.cond(
flag, jax.random.uniform, jax.random.normal, key)
assert_consumed(key)
self.check_key_reuse(f, True, jax.random.key(0))
def test_cond_one_consumed(self):
@jax.jit
def f(flag, key):
assert_unconsumed(key)
_ = jax.lax.cond(
flag, jax.random.uniform, lambda k: 1.0, key)
assert_consumed(key)
self.check_key_reuse(f, True, jax.random.key(0))
def test_cond_neither_consumed(self):
@jax.jit
def f(flag, key):
assert_unconsumed(key)
_ = jax.lax.cond(
flag, lambda k: 0.0, lambda k: 1.0, key)
assert_unconsumed(key)
self.check_key_reuse(f, True, jax.random.key(0))
def test_simple_vmap(self):
@jax.jit
def f(seed):
key = jax.random.key(seed)
assert_unconsumed(key)
result = jax.random.uniform(key)
assert_consumed(key)
return result
self.check_key_reuse(f, 0)
self.check_key_reuse(jax.vmap(f), jnp.arange(4))
@parameterized.parameters(*primitives_with_static_signatures)
def test_jaxpr_type_signature(self, primitive):
func, *args = primitives_with_static_signatures[primitive]
signature = _core.key_reuse_signatures[primitive]
jaxpr = jax.make_jaxpr(func)(*args)
self.assertEqual(signature, _core.jaxpr_type_signature(jaxpr.jaxpr))
@parameterized.parameters(*primitives_with_static_signatures)
def test_function_type_signature(self, primitive):
func, *args = primitives_with_static_signatures[primitive]
signature = _core.key_reuse_signatures[primitive]
self.assertEqual(signature, _core.function_type_signature(func, *args))
@jtu.with_config(jax_debug_key_reuse=False)
class KeyReuseIntegrationTest(jtu.JaxTestCase):
random_bits_error = "In random_bits, argument [0-9]+ is already consumed.*"
random_split_error = "In random_split, argument [0-9]+ is already consumed.*"
generic_error = ".*argument [0-9]+ is already consumed.*"
pjit_error = "In pjit, argument 0 is already consumed."
def check_key_reuse(self, f, *args):
return _core.check_key_reuse(f, *args)
def test_reuse(self):
def f():
key = jax.random.key(0)
return jax.random.uniform(key) + jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f)
def test_reuse_after_split(self):
def f_good():
key = jax.random.key(0)
key1, key2 = jax.random.split(key)
return jax.random.uniform(key1) + jax.random.uniform(key2)
self.check_key_reuse(f_good)
def f_bad():
key = jax.random.key(0)
_ = jax.random.split(key)
return jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad)
def f_bad_2():
key = jax.random.key(0)
_ = jax.random.split(key)
key1, _ = jax.random.split(key)
return jax.random.uniform(key1)
with self.assertRaisesRegex(KeyReuseError, self.random_split_error):
self.check_key_reuse(f_bad_2)
def test_repeated_fold_ins(self):
# TODO(jakevdp): should we allow repeated fold-ins?
def f():
key = jax.random.key(0)
keys = [jax.random.fold_in(key, i)
for i in range(10)]
return [jax.random.uniform(k) for k in keys]
self.check_key_reuse(f)
def test_reuse_after_fold_in(self):
def f():
key = jax.random.key(0)
_ = jax.random.fold_in(key, 1)
return jax.random.uniform(key)
self.check_key_reuse(f)
def test_reuse_after_broadcast(self):
def f():
key = jax.random.key(0)
key2 = key[None]
return jax.random.bits(key) + jax.vmap(jax.random.bits)(key2)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
self.check_key_reuse(f)
def test_reuse_after_reshape(self):
def f():
key = jax.random.key(0)
key2 = key.reshape((1,))
return jax.random.bits(key) + jax.random.bits(key2.squeeze())
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
self.check_key_reuse(f)
def test_reuse_after_squeeze(self):
def f():
key = jax.random.split(jax.random.key(0), 1)
key2 = jax.lax.squeeze(key, (0,))
return jax.random.bits(key.squeeze()) + jax.random.bits(key2)
with self.assertRaisesRegex(KeyReuseError, self.generic_error):
self.check_key_reuse(f)
def test_reuse_after_cond(self):
def f_good(key, condition):
return jax.lax.cond(condition, jax.random.uniform, jax.random.normal, key)
key = jax.random.key(0)
self.check_key_reuse(f_good, key, True)
self.check_key_reuse(f_good, key, False)
# Check where both branches consume the key
def f_bad(key, condition):
r1 = jax.lax.cond(condition, jax.random.uniform, jax.random.normal, key)
return r1 + jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad, key, True)
# Check where only one branch consumes the key
def f_bad_2(key, condition):
r1 = jax.lax.cond(condition, jax.random.uniform, lambda key: 1.0, key)
return r1 + jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad_2, key, True)
def test_simple_scan(self):
def f_good(key):
def body_fun(key, _):
key, subkey = jax.random.split(key)
return key, jax.random.bits(subkey)
return jax.lax.scan(body_fun, key, xs=jnp.arange(10))
self.check_key_reuse(f_good, jax.random.key(0))
def test_scan_sink_on_consts(self):
def f(key):
def body_fun(carry, _):
return carry, jax.random.uniform(key)
return jax.lax.scan(body_fun, None, xs=jnp.arange(10))
with self.assertRaisesRegex(KeyReuseError, "scan body function leads to key reuse"):
self.check_key_reuse(f, jax.random.key(0))
def test_scan_reuse_in_body(self):
def f_bad(key):
def body_fun(key, _):
return key, jax.random.bits(key)
return jax.lax.scan(body_fun, key, xs=jnp.arange(10))
with self.assertRaisesRegex(KeyReuseError, "scan body function leads to key reuse"):
self.check_key_reuse(f_bad, jax.random.key(0))
def test_scan_good_over_keys(self):
def f_scan_over_keys(key):
keys = jax.random.split(key, 5)
return jax.lax.map(jax.random.bits, keys)
self.check_key_reuse(f_scan_over_keys, jax.random.key(0))
def test_scan_consume_one(self):
def f_scan_over_keys(*keys):
def body_func(keys, x):
return tuple(jax.random.split(keys[0])), x
return jax.lax.scan(body_func, keys, xs=jnp.arange(10))
self.check_key_reuse(f_scan_over_keys, jax.random.key(0), jax.random.key(1))
def test_vmap(self):
@jax.vmap
def f_good(seed):
key = jax.random.key(seed)
return jax.random.bits(key)
self.check_key_reuse(f_good, jnp.arange(4))
@jax.vmap
def f_bad(seed):
key = jax.random.key(0)
return jax.random.bits(key) + jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
self.check_key_reuse(f_bad, jnp.arange(4))
def test_while_simple(self):
def f(seed):
key = jax.random.key(seed)
def cond_fun(carry):
return carry[1] < 10
def body_fun(carry):
key, subkey = jax.random.split(carry[0])
return key, carry[1] + jax.random.uniform(subkey)
return jax.lax.while_loop(cond_fun, body_fun, (key, 0))
self.check_key_reuse(f, 0)
def test_while_bad_cond(self):
def f(seed):
key = jax.random.key(seed)
def cond_fun(carry):
i, key = carry
return i < jax.random.uniform(key)
def body_fun(carry):
i, key = carry
return i + 1, key
return jax.lax.while_loop(cond_fun, body_fun, (0, key))
with self.assertRaisesRegex(KeyReuseError, "while_loop cond"):
self.check_key_reuse(f, 0)
def test_while_bad_body(self):
def f(seed):
key = jax.random.key(seed)
def cond_fun(carry):
key, i = carry
return i < 5
def body_fun(carry):
key, i = carry
return key, i + jax.random.randint(key, (), 1, 3)
return jax.lax.while_loop(cond_fun, body_fun, (key, 0))
with self.assertRaisesRegex(KeyReuseError, "while_loop body function leads to key reuse"):
self.check_key_reuse(f, 0)
def test_while_sink_on_body_consts(self):
def f(seed):
key = jax.random.key(seed)
def cond_fun(i):
return i < 5
def body_fun(i):
return i + jax.random.randint(key, (), 1, 3)
return jax.lax.while_loop(cond_fun, body_fun, 0)
with self.assertRaisesRegex(KeyReuseError, "while_loop body function leads to key reuse"):
self.check_key_reuse(f, 0)
def test_while_sink_on_cond_consts(self):
def f(seed):
key = jax.random.key(seed)
def cond_fun(i):
return i < jax.random.uniform(key)
def body_fun(i):
return i + 1
return jax.lax.while_loop(cond_fun, body_fun, 0)
with self.assertRaisesRegex(KeyReuseError, "while_loop cond function leads to key reuse"):
self.check_key_reuse(f, 0)
def test_pjit_consumed_input(self):
@jax.jit
def g(key, x): # doesn't consume key
return x
def f(seed):
key = jax.random.key(seed)
x = jax.random.bits(key)
return g(key, x)
self.check_key_reuse(f, 0)
@jax.numpy_dtype_promotion('standard')
def test_remat(self):
@jax.checkpoint
def f_bad(x, key):
return x * jax.random.bits(key) + jax.random.bits(key)
@jax.checkpoint
def f_good(x, key):
return x * jax.random.bits(key)
x = jnp.float32(1.0)
key = jax.random.key(0)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
self.check_key_reuse(f_bad, x, key)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
self.check_key_reuse(jax.grad(f_bad), x, key)
self.check_key_reuse(f_good, x, key)
self.check_key_reuse(jax.grad(f_good), x, key)
@jtu.with_config(jax_debug_key_reuse=True)
class KeyReuseEagerTest(jtu.JaxTestCase):
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
traced_bits_msg = "In random_bits, argument 0 is already consumed."
def test_clone_eager(self):
key = jax.random.key(0)
key2 = jax.random.clone(key)
self.assertIsNot(key, key2)
_ = jax.random.uniform(key)
self.assertTrue(key._consumed)
self.assertFalse(key2._consumed)
def test_simple_reuse_nojit(self):
key = jax.random.key(0)
with jax.disable_jit():
_ = jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg):
_ = jax.random.bits(key)
def test_simple_key_reuse_jit(self):
key = jax.random.key(0)
_ = jax.jit(jax.random.bits)(key)
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = jax.jit(jax.random.bits)(key)
def test_closed_over_key_reuse_jit(self):
key = jax.random.key(0)
@jax.jit
def f():
return jax.random.uniform(key)
_ = f()
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = f()
def test_key_reuse_within_jit(self):
@jax.jit
def f():
key = jax.random.key(0)
return jax.random.bits(key) + jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.traced_bits_msg):
f()
class KeyReuseImplementationTest(jtu.JaxTestCase):
def assertEquivalent(self, a, b):
self.assertEqual(a, b)
self.assertEqual(hash(a), hash(b))
def assertNotEquivalent(self, a, b):
self.assertNotEqual(a, b)
self.assertNotEqual(hash(a), hash(b))
def test_source_sink_immutability(self):
mask = np.array([True, False])
orig_mask_writeable = mask.flags.writeable
sink = Sink(0, mask)
source = Source(0, mask)
self.assertFalse(sink.mask.flags.writeable)
self.assertFalse(source.mask.flags.writeable)
self.assertEqual(mask.flags.writeable, orig_mask_writeable)
with self.assertRaises(ValueError):
sink.idx = 1
with self.assertRaises(ValueError):
sink.mask = True
with self.assertRaises(ValueError):
source.idx = 1
with self.assertRaises(ValueError):
source.mask = True
def test_source_sink_forward_equivalence_semantics(self):
true_mask = np.array([True, True])
false_mask = np.array([False, False])
mixed_mask = np.array([True, False])
self.assertEquivalent(Source(0), Source(0, True))
self.assertEquivalent(Source(0, True), Source(0, true_mask))
self.assertEquivalent(Source(0, False), Source(0, false_mask))
self.assertEquivalent(Source(0, mixed_mask), Source(0, mixed_mask))
self.assertNotEquivalent(Source(0), Source(1))
self.assertNotEquivalent(Source(0), Source(0, False))
self.assertNotEquivalent(Source(0), Source(0, mixed_mask))
self.assertEquivalent(Sink(0), Sink(0, True))
self.assertEquivalent(Sink(0, True), Sink(0, true_mask))
self.assertEquivalent(Sink(0, False), Sink(0, false_mask))
self.assertEquivalent(Sink(0, mixed_mask), Sink(0, mixed_mask))
self.assertNotEquivalent(Sink(0), Sink(1))
self.assertNotEquivalent(Sink(0), Sink(0, False))
self.assertNotEquivalent(Sink(0), Sink(0, mixed_mask))
self.assertNotEquivalent(Source(0), Sink(0))
self.assertEquivalent(Forward(0, 1), Forward(0, 1))
self.assertNotEquivalent(Forward(0, 1), Forward(1, 0))
def test_signature_equality_semantics(self):
self.assertEquivalent(
KeyReuseSignature(Sink(0), Source(1), Forward(1, 0)),
KeyReuseSignature(Forward(1, 0), Source(1), Sink(0)))
self.assertEquivalent(
KeyReuseSignature(), KeyReuseSignature())
self.assertNotEquivalent(
KeyReuseSignature(Source(0)), KeyReuseSignature(Sink(0)))
def test_reprs(self):
self.assertEqual(repr(Sink(0)), "Sink(0)")
self.assertEqual(repr(Source(0)), "Source(0)")
self.assertEqual(repr(Forward(0, 1)), "Forward(0, 1)")
self.assertEqual(repr(KeyReuseSignature(Sink(1), Source(0))),
"KeyReuseSignature(Sink(1), Source(0))")
self.assertEqual(repr(KeyReuseSignature(Sink(1), Sink(0))),
"KeyReuseSignature(Sink(0), Sink(1))")
@jtu.with_config(jax_enable_checks=False)
class KeyReuseGlobalFlagsTest(jtu.JaxTestCase):
def test_key_reuse_flag(self):
@jax.jit
def f_bad(key):
return jax.random.bits(key) + jax.random.bits(key)
@jax.jit
def f_good(key):
return jax.random.bits(key)
key = jax.random.key(0)
with jax.debug_key_reuse(False):
f_good(key)
f_bad(key) # No failure
f_bad.clear_cache()
f_good.clear_cache()
with jax.debug_key_reuse(True):
f_good(key)
with self.assertRaisesRegex(KeyReuseError, "In random_bits.*"):
f_bad(key)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())