mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[key reuse] don't consume key in fold_in
Why? We've found in practice that downstream projects use fold_in multiple times with the same key. This is safe so long as the folded-in value is different every time; in this sense fold_in() is similar to seed(), and for now we must trust the user to not repeat seeds.
This commit is contained in:
parent
45daced7c9
commit
b069c20e56
@ -53,7 +53,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {}
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_split_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
|
||||
|
@ -43,7 +43,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [])
|
||||
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], [])
|
||||
|
@ -118,7 +118,7 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
key2 = jax.random.fold_in(key, 2)
|
||||
assert_consumed(key)
|
||||
assert_unconsumed(key)
|
||||
assert_unconsumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
@ -355,7 +355,7 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
key2 = jax.random.fold_in(key, 2)
|
||||
assert_consumed(key)
|
||||
assert_unconsumed(key)
|
||||
assert_unconsumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
@ -603,14 +603,22 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_split_error):
|
||||
self.check_key_reuse(f_bad_2)
|
||||
|
||||
def test_repeated_fold_ins(self):
|
||||
# TODO(jakevdp): should we allow repeated fold-ins?
|
||||
def f():
|
||||
key = jax.random.key(0)
|
||||
keys = [jax.random.fold_in(key, i)
|
||||
for i in range(10)]
|
||||
return [jax.random.uniform(k) for k in keys]
|
||||
self.check_key_reuse(f)
|
||||
|
||||
def test_reuse_after_fold_in(self):
|
||||
def f():
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.fold_in(key, 1)
|
||||
return jax.random.uniform(key)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
self.check_key_reuse(f)
|
||||
self.check_key_reuse(f)
|
||||
|
||||
def test_reuse_after_broadcast(self):
|
||||
def f():
|
||||
|
Loading…
x
Reference in New Issue
Block a user