mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[key reuse] define rule for lax.concatenate
This commit is contained in:
parent
b6e985ffe7
commit
3eff032aba
@ -357,6 +357,16 @@ def _slice_signature(eqn):
|
||||
|
||||
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
|
||||
|
||||
def _concatenate_signature(eqn):
|
||||
num_vals = len(eqn.invars)
|
||||
# TODO(jakevdp): should this signature be more granular?
|
||||
if num_vals == 1:
|
||||
return KeyReuseSignature(Forward(0, 0))
|
||||
else:
|
||||
return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0))
|
||||
|
||||
key_reuse_signatures_dynamic[lax.concatenate_p] = _concatenate_signature
|
||||
|
||||
def _pjit_key_type_signature(eqn):
|
||||
return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
|
||||
|
||||
|
@ -209,6 +209,18 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
assert_consumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
def test_concatenate(self):
|
||||
def f(key1, key2):
|
||||
assert_unconsumed(key1)
|
||||
assert_unconsumed(key2)
|
||||
keys = jax.lax.concatenate([key1, key2], dimension=0)
|
||||
assert_consumed(key1)
|
||||
assert_consumed(key2)
|
||||
assert_unconsumed(keys)
|
||||
key1 = jax.random.split(jax.random.key(0))
|
||||
key2 = jax.random.split(jax.random.key(1))
|
||||
self.check_key_reuse(f, key1, key2)
|
||||
|
||||
def test_slice(self):
|
||||
def f(keys):
|
||||
assert_unconsumed(keys)
|
||||
|
Loading…
x
Reference in New Issue
Block a user