diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index ab96f48f3..1b5a431d6 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -357,6 +357,16 @@ def _slice_signature(eqn): key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature +def _concatenate_signature(eqn): + num_vals = len(eqn.invars) + # TODO(jakevdp): should this signature be more granular? + if num_vals == 1: + return KeyReuseSignature(Forward(0, 0)) + else: + return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0)) + +key_reuse_signatures_dynamic[lax.concatenate_p] = _concatenate_signature + def _pjit_key_type_signature(eqn): return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 5d5d6e12e..e3ed02e90 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -209,6 +209,18 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase): assert_consumed(key2) self.check_key_reuse(f, jax.random.key(0)) + def test_concatenate(self): + def f(key1, key2): + assert_unconsumed(key1) + assert_unconsumed(key2) + keys = jax.lax.concatenate([key1, key2], dimension=0) + assert_consumed(key1) + assert_consumed(key2) + assert_unconsumed(keys) + key1 = jax.random.split(jax.random.key(0)) + key2 = jax.random.split(jax.random.key(1)) + self.check_key_reuse(f, key1, key2) + def test_slice(self): def f(keys): assert_unconsumed(keys)