[key reuse] don't consume on equality check

This commit is contained in:
Jake VanderPlas 2024-03-04 13:32:35 -08:00
parent 67b0eb3af4
commit 84d11d7b11
2 changed files with 15 additions and 0 deletions

View File

@ -83,6 +83,9 @@ key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)])
key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)])
# Equality checks don't consume
key_reuse_signatures[lax.eq_p] = KeyReuseSignature([], [], [])
key_reuse_signatures[lax.ne_p] = KeyReuseSignature([], [], [])
# Rules which require more dynamic logic.
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}

View File

@ -14,6 +14,7 @@
from absl.testing import absltest, parameterized
from functools import partial
import operator
import numpy as np
import jax
@ -216,6 +217,17 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
assert_consumed(keys, np.array([True, True]))
self.check_key_reuse(f, jax.random.split(jax.random.key(0)))
@parameterized.parameters(operator.eq, operator.ne)
def test_equality_checks(self, op):
def f(key1, key2):
assert_unconsumed(key1)
assert_unconsumed(key2)
result = op(key1, key2)
assert_unconsumed(key1)
assert_unconsumed(key2)
return result
self.check_key_reuse(f, jax.random.key(0), jax.random.key(1))
def test_jit_can_consume_input(self):
def f(key):
assert_unconsumed(key)