[key reuse] fix signature for device_put

This commit is contained in:
Jake VanderPlas 2025-01-17 09:47:50 -08:00
parent a4a657bc43
commit f83175fc94
2 changed files with 26 additions and 7 deletions

View File

@ -292,7 +292,6 @@ key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0))
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.device_put_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0))
@ -561,6 +560,14 @@ def _remat_key_type_signature(eqn):
key_reuse_signatures[remat_p] = _remat_key_type_signature
@dynamic_key_reuse_signature
def _device_put_signature(eqn):
num_vals = len(eqn.invars)
return KeyReuseSignature(*(Forward(i, i) for i in range(num_vals)))
key_reuse_signatures[lax.device_put_p] = _device_put_signature
def call_impl_with_key_reuse_checks(prim: core.Primitive, raw_impl: Callable[..., Any], *args, **kwargs) -> Any:
if prim not in key_reuse_signatures:
# TODO(jakevdp): should we use an unknown signature here?

View File

@ -49,7 +49,6 @@ primitives_with_static_signatures = {
jax.lax.broadcast_in_dim_p: (lambda key: key[None], key),
jax.lax.copy_p: (jnp.array, key),
jax.lax.convert_element_type_p: (lambda key: jnp.array(key, dtype=key.dtype), key),
jax.lax.device_put_p: (jax.device_put, key),
jax.lax.reshape_p: (lambda key: key.reshape((1,)), key),
jax.lax.squeeze_p: (jnp.squeeze, key1D),
jax.lax.dynamic_slice_p: (partial(jax.lax.dynamic_slice, slice_sizes=(1,)), key1D, (0,)),
@ -178,14 +177,27 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
def test_device_put(self):
def f(key):
assert_unconsumed(key)
key2 = jax.device_put(key)
assert_unconsumed(key)
assert_unconsumed(key2)
key_d = jax.device_put(key)
assert_unconsumed(key_d)
consume(key)
assert_consumed(key)
assert_consumed(key2)
assert_consumed(key_d)
self.check_key_reuse(f, jax.random.key(0))
def test_device_put_multiple(self):
def f(key1, key2):
assert_unconsumed(key1)
assert_unconsumed(key2)
key1_d, key2_d = jax.device_put((key1, key2))
assert_unconsumed(key1_d)
consume(key1)
assert_consumed(key1_d)
assert_unconsumed(key2_d)
consume(key2)
assert_consumed(key2_d)
self.check_key_reuse(f, jax.random.key(0), jax.random.key(1))
def test_squeeze(self):
def f(key):
assert_unconsumed(key)