[key reuse] don't consume key in fold_in

Why? We've found in practice that downstream projects use fold_in multiple
times with the same key. This is safe so long as the folded-in value is
different every time; in this sense fold_in() is similar to seed(), and
for now we must trust the user to not repeat seeds.
This commit is contained in:
Jake VanderPlas 2024-01-25 15:35:51 -08:00
parent 45daced7c9
commit b069c20e56
3 changed files with 18 additions and 6 deletions

View File

@ -53,7 +53,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {}
key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)])
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
# TODO(jakevdp): should fold_in sink its input key?
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([], [Source(0)])
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignatureWithForwards([], [Source(0)])
key_reuse_signatures[prng.random_split_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignatureWithForwards([Sink(0)], [])

View File

@ -43,7 +43,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [])
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
# TODO(jakevdp): should fold_in sink its input key?
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)])
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], [])

View File

@ -118,7 +118,7 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase):
def f(key):
assert_unconsumed(key)
key2 = jax.random.fold_in(key, 2)
assert_consumed(key)
assert_unconsumed(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))
@ -355,7 +355,7 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
def f(key):
assert_unconsumed(key)
key2 = jax.random.fold_in(key, 2)
assert_consumed(key)
assert_unconsumed(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))
@ -603,14 +603,22 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
with self.assertRaisesRegex(KeyReuseError, self.random_split_error):
self.check_key_reuse(f_bad_2)
def test_repeated_fold_ins(self):
# TODO(jakevdp): should we allow repeated fold-ins?
def f():
key = jax.random.key(0)
keys = [jax.random.fold_in(key, i)
for i in range(10)]
return [jax.random.uniform(k) for k in keys]
self.check_key_reuse(f)
def test_reuse_after_fold_in(self):
def f():
key = jax.random.key(0)
_ = jax.random.fold_in(key, 1)
return jax.random.uniform(key)
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
self.check_key_reuse(f)
self.check_key_reuse(f)
def test_reuse_after_broadcast(self):
def f():