32 Commits

Author SHA1 Message Date
Jake VanderPlas
f83175fc94 [key reuse] fix signature for device_put 2025-01-17 09:47:50 -08:00
Jake VanderPlas
6541a62099 jax.core: deprecate a number of APIs 2024-12-10 11:11:32 -08:00
Jake VanderPlas
d5405bd92f [key reuse] handle reuse of closed-over constants 2024-04-11 15:39:45 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
Jake VanderPlas
332450bac7 [key reuse] add internal function_type_signature utility 2024-03-19 10:19:12 -07:00
Jake VanderPlas
ae4e273b74 Add key reuse config to trace context 2024-03-14 06:59:37 -07:00
Jake VanderPlas
2ba9b45277 [key-reuse] fix flaky test 2024-03-12 16:49:16 -07:00
jax authors
63538771b5 Merge pull request #20183 from jakevdp:key-reuse-concatenate
PiperOrigin-RevId: 614851106
2024-03-11 17:37:59 -07:00
Jake VanderPlas
6cf740ceb1 [key reuse] improve repr for signatures 2024-03-11 15:17:08 -07:00
Jake VanderPlas
3eff032aba [key reuse] define rule for lax.concatenate 2024-03-11 15:06:59 -07:00
Jake VanderPlas
d1e49f9c89 [key reuse] fix random_clone impl rule 2024-03-08 15:16:39 -08:00
Jake VanderPlas
0644f192f2 [key reuse] improve KeyReuseSignature semantics 2024-03-08 12:28:00 -08:00
Jake VanderPlas
7634708743 [key reuse] define KeyReuseError in jax.errors 2024-03-08 10:59:06 -08:00
Jake VanderPlas
6771a59181 [key reuse] add jax.random.clone 2024-03-08 09:06:00 -08:00
Yash Katariya
1cb8d31c66 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.

PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00
Jake VanderPlas
735ec63dd1 [key reuse] improve error message using source_info_util 2024-03-05 11:02:39 -08:00
Jake VanderPlas
bb91bf2e09 [key reuse] improve some key reuse errors. 2024-03-05 08:14:20 -08:00
Jake VanderPlas
84d11d7b11 [key reuse] don't consume on equality check 2024-03-04 13:32:35 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Jake VanderPlas
5ae4bffcb2 [key reuse] avoid inadvertently duplicated test cases 2024-02-29 11:44:46 -08:00
Jake VanderPlas
8eab599530 [key reuse] simplify key reuse logic through context-free jaxpr evaluation
The args_consumed and forwarded_inputs context is not actually needed, because it can be checked
afterward. The only reason for this was to have more granular errors, but arguably it's better
to error on jaxpr input.
2024-02-15 15:50:50 -08:00
Jake VanderPlas
8284c164a3 [key reuse] remove alternate simple implementation 2024-02-15 11:35:58 -08:00
Jake VanderPlas
49eb7008c0 Define reuse_key primitive in jax._src.prng 2024-02-14 14:01:08 -08:00
Jake VanderPlas
7360edd404 [key reuse]: handle remat 2024-02-05 16:02:12 -08:00
jax authors
c1c0c1cf33 Merge pull request #19634 from jakevdp:key-reuse-scan
PiperOrigin-RevId: 604418753
2024-02-05 13:39:07 -08:00
Jake VanderPlas
f4f8293c38 [key reuse] fix scan single-key consumption issue 2024-02-01 15:36:36 -08:00
Jake VanderPlas
f4534427ad [key reuse] fix pjit false positive 2024-02-01 14:28:42 -08:00
Jake VanderPlas
17935aff01 [key reuse] fix key reuse type for cond with sources 2024-01-26 14:42:55 -08:00
Jake VanderPlas
b069c20e56 [key reuse] don't consume key in fold_in
Why? We've found in practice that downstream projects use fold_in multiple
times with the same key. This is safe so long as the folded-in value is
different every time; in this sense fold_in() is similar to seed(), and
for now we must trust the user to not repeat seeds.
2024-01-25 15:35:51 -08:00
Jake VanderPlas
03ce8ca0ca jax.random: deprecate passing of batched keys to APIs 2024-01-17 12:53:24 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00