From f83175fc944238fa2dcb26359500cc92e2883f9c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 17 Jan 2025 09:47:50 -0800 Subject: [PATCH] [key reuse] fix signature for device_put --- jax/experimental/key_reuse/_core.py | 9 ++++++++- tests/key_reuse_test.py | 24 ++++++++++++++++++------ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index ef19f94c4..87c41e13c 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -292,7 +292,6 @@ key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0)) key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0)) key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0)) key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0)) -key_reuse_signatures[lax.device_put_p] = KeyReuseSignature(Forward(0, 0)) key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0)) key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0)) key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0)) @@ -561,6 +560,14 @@ def _remat_key_type_signature(eqn): key_reuse_signatures[remat_p] = _remat_key_type_signature +@dynamic_key_reuse_signature +def _device_put_signature(eqn): + num_vals = len(eqn.invars) + return KeyReuseSignature(*(Forward(i, i) for i in range(num_vals))) + +key_reuse_signatures[lax.device_put_p] = _device_put_signature + + def call_impl_with_key_reuse_checks(prim: core.Primitive, raw_impl: Callable[..., Any], *args, **kwargs) -> Any: if prim not in key_reuse_signatures: # TODO(jakevdp): should we use an unknown signature here? diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 3364c9be9..0d7ac9c18 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -49,7 +49,6 @@ primitives_with_static_signatures = { jax.lax.broadcast_in_dim_p: (lambda key: key[None], key), jax.lax.copy_p: (jnp.array, key), jax.lax.convert_element_type_p: (lambda key: jnp.array(key, dtype=key.dtype), key), - jax.lax.device_put_p: (jax.device_put, key), jax.lax.reshape_p: (lambda key: key.reshape((1,)), key), jax.lax.squeeze_p: (jnp.squeeze, key1D), jax.lax.dynamic_slice_p: (partial(jax.lax.dynamic_slice, slice_sizes=(1,)), key1D, (0,)), @@ -178,14 +177,27 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase): def test_device_put(self): def f(key): assert_unconsumed(key) - key2 = jax.device_put(key) - assert_unconsumed(key) - assert_unconsumed(key2) + key_d = jax.device_put(key) + assert_unconsumed(key_d) consume(key) - assert_consumed(key) - assert_consumed(key2) + assert_consumed(key_d) self.check_key_reuse(f, jax.random.key(0)) + def test_device_put_multiple(self): + def f(key1, key2): + assert_unconsumed(key1) + assert_unconsumed(key2) + key1_d, key2_d = jax.device_put((key1, key2)) + + assert_unconsumed(key1_d) + consume(key1) + assert_consumed(key1_d) + + assert_unconsumed(key2_d) + consume(key2) + assert_consumed(key2_d) + self.check_key_reuse(f, jax.random.key(0), jax.random.key(1)) + def test_squeeze(self): def f(key): assert_unconsumed(key)