From b069c20e56f25fc52aeef71afb2269eec3f5bcb8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 25 Jan 2024 15:35:51 -0800 Subject: [PATCH] [key reuse] don't consume key in fold_in Why? We've found in practice that downstream projects use fold_in multiple times with the same key. This is safe so long as the folded-in value is different every time; in this sense fold_in() is similar to seed(), and for now we must trust the user to not repeat seeds. --- jax/experimental/key_reuse/_forwarding.py | 4 +++- jax/experimental/key_reuse/_simple.py | 4 +++- tests/key_reuse_test.py | 16 ++++++++++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py index 69b0a29b7..1c070bf10 100644 --- a/jax/experimental/key_reuse/_forwarding.py +++ b/jax/experimental/key_reuse/_forwarding.py @@ -53,7 +53,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {} key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)]) key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)]) key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], []) -key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)]) +# TODO(jakevdp): should fold_in sink its input key? +# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)]) +key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([], [Source(0)]) key_reuse_signatures[prng.random_seed_p] = KeyReuseSignatureWithForwards([], [Source(0)]) key_reuse_signatures[prng.random_split_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)]) key_reuse_signatures[random.random_gamma_p] = KeyReuseSignatureWithForwards([Sink(0)], []) diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py index 86b1a1984..a7fe4881a 100644 --- a/jax/experimental/key_reuse/_simple.py +++ b/jax/experimental/key_reuse/_simple.py @@ -43,7 +43,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {} key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], []) key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], []) -key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +# TODO(jakevdp): should fold_in sink its input key? +# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) +key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)]) key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], []) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index d42217a4c..1d141b343 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -118,7 +118,7 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase): def f(key): assert_unconsumed(key) key2 = jax.random.fold_in(key, 2) - assert_consumed(key) + assert_unconsumed(key) assert_unconsumed(key2) self.check_key_reuse(f, jax.random.key(0)) @@ -355,7 +355,7 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase): def f(key): assert_unconsumed(key) key2 = jax.random.fold_in(key, 2) - assert_consumed(key) + assert_unconsumed(key) assert_unconsumed(key2) self.check_key_reuse(f, jax.random.key(0)) @@ -603,14 +603,22 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase): with self.assertRaisesRegex(KeyReuseError, self.random_split_error): self.check_key_reuse(f_bad_2) + def test_repeated_fold_ins(self): + # TODO(jakevdp): should we allow repeated fold-ins? + def f(): + key = jax.random.key(0) + keys = [jax.random.fold_in(key, i) + for i in range(10)] + return [jax.random.uniform(k) for k in keys] + self.check_key_reuse(f) + def test_reuse_after_fold_in(self): def f(): key = jax.random.key(0) _ = jax.random.fold_in(key, 1) return jax.random.uniform(key) - with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): - self.check_key_reuse(f) + self.check_key_reuse(f) def test_reuse_after_broadcast(self): def f():