[key reuse] define rule for lax.concatenate

This commit is contained in:
Jake VanderPlas 2024-03-11 15:06:59 -07:00
parent b6e985ffe7
commit 3eff032aba
2 changed files with 22 additions and 0 deletions

View File

@ -357,6 +357,16 @@ def _slice_signature(eqn):
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
def _concatenate_signature(eqn):
num_vals = len(eqn.invars)
# TODO(jakevdp): should this signature be more granular?
if num_vals == 1:
return KeyReuseSignature(Forward(0, 0))
else:
return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0))
key_reuse_signatures_dynamic[lax.concatenate_p] = _concatenate_signature
def _pjit_key_type_signature(eqn):
return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)

View File

@ -209,6 +209,18 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_concatenate(self):
def f(key1, key2):
assert_unconsumed(key1)
assert_unconsumed(key2)
keys = jax.lax.concatenate([key1, key2], dimension=0)
assert_consumed(key1)
assert_consumed(key2)
assert_unconsumed(keys)
key1 = jax.random.split(jax.random.key(0))
key2 = jax.random.split(jax.random.key(1))
self.check_key_reuse(f, key1, key2)
def test_slice(self):
def f(keys):
assert_unconsumed(keys)