mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] don't consume on equality check
This commit is contained in:
parent
67b0eb3af4
commit
84d11d7b11
@ -83,6 +83,9 @@ key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0
|
||||
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)])
|
||||
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)])
|
||||
# Equality checks don't consume
|
||||
key_reuse_signatures[lax.eq_p] = KeyReuseSignature([], [], [])
|
||||
key_reuse_signatures[lax.ne_p] = KeyReuseSignature([], [], [])
|
||||
|
||||
# Rules which require more dynamic logic.
|
||||
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
from functools import partial
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
@ -216,6 +217,17 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
assert_consumed(keys, np.array([True, True]))
|
||||
self.check_key_reuse(f, jax.random.split(jax.random.key(0)))
|
||||
|
||||
@parameterized.parameters(operator.eq, operator.ne)
|
||||
def test_equality_checks(self, op):
|
||||
def f(key1, key2):
|
||||
assert_unconsumed(key1)
|
||||
assert_unconsumed(key2)
|
||||
result = op(key1, key2)
|
||||
assert_unconsumed(key1)
|
||||
assert_unconsumed(key2)
|
||||
return result
|
||||
self.check_key_reuse(f, jax.random.key(0), jax.random.key(1))
|
||||
|
||||
def test_jit_can_consume_input(self):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
|
Loading…
x
Reference in New Issue
Block a user