mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19520 from jakevdp:fold-in-consume
PiperOrigin-RevId: 601609582
This commit is contained in:
commit
1264700e73
@ -53,7 +53,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {}
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignatureWithForwards([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_split_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
|
||||
|
@ -43,7 +43,9 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [])
|
||||
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)])
|
||||
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], [])
|
||||
|
@ -118,7 +118,7 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
key2 = jax.random.fold_in(key, 2)
|
||||
assert_consumed(key)
|
||||
assert_unconsumed(key)
|
||||
assert_unconsumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
@ -355,7 +355,7 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
key2 = jax.random.fold_in(key, 2)
|
||||
assert_consumed(key)
|
||||
assert_unconsumed(key)
|
||||
assert_unconsumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
@ -603,14 +603,22 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_split_error):
|
||||
self.check_key_reuse(f_bad_2)
|
||||
|
||||
def test_repeated_fold_ins(self):
|
||||
# TODO(jakevdp): should we allow repeated fold-ins?
|
||||
def f():
|
||||
key = jax.random.key(0)
|
||||
keys = [jax.random.fold_in(key, i)
|
||||
for i in range(10)]
|
||||
return [jax.random.uniform(k) for k in keys]
|
||||
self.check_key_reuse(f)
|
||||
|
||||
def test_reuse_after_fold_in(self):
|
||||
def f():
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.fold_in(key, 1)
|
||||
return jax.random.uniform(key)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
self.check_key_reuse(f)
|
||||
self.check_key_reuse(f)
|
||||
|
||||
def test_reuse_after_broadcast(self):
|
||||
def f():
|
||||
|
Loading…
x
Reference in New Issue
Block a user