Jake VanderPlas
f83175fc94
[key reuse] fix signature for device_put
2025-01-17 09:47:50 -08:00
Jake VanderPlas
6541a62099
jax.core: deprecate a number of APIs
2024-12-10 11:11:32 -08:00
Jake VanderPlas
d5405bd92f
[key reuse] handle reuse of closed-over constants
2024-04-11 15:39:45 -07:00
Jake VanderPlas
f090074d86
Avoid 'from jax import config' imports
...
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
8949a63ce1
[key reuse] rename flag to jax_debug_key_reuse
2024-03-22 05:37:30 -07:00
Jake VanderPlas
332450bac7
[key reuse] add internal function_type_signature utility
2024-03-19 10:19:12 -07:00
Jake VanderPlas
ae4e273b74
Add key reuse config to trace context
2024-03-14 06:59:37 -07:00
Jake VanderPlas
2ba9b45277
[key-reuse] fix flaky test
2024-03-12 16:49:16 -07:00
jax authors
63538771b5
Merge pull request #20183 from jakevdp:key-reuse-concatenate
...
PiperOrigin-RevId: 614851106
2024-03-11 17:37:59 -07:00
Jake VanderPlas
6cf740ceb1
[key reuse] improve repr for signatures
2024-03-11 15:17:08 -07:00
Jake VanderPlas
3eff032aba
[key reuse] define rule for lax.concatenate
2024-03-11 15:06:59 -07:00
Jake VanderPlas
d1e49f9c89
[key reuse] fix random_clone impl rule
2024-03-08 15:16:39 -08:00
Jake VanderPlas
0644f192f2
[key reuse] improve KeyReuseSignature semantics
2024-03-08 12:28:00 -08:00
Jake VanderPlas
7634708743
[key reuse] define KeyReuseError in jax.errors
2024-03-08 10:59:06 -08:00
Jake VanderPlas
6771a59181
[key reuse] add jax.random.clone
2024-03-08 09:06:00 -08:00
Yash Katariya
1cb8d31c66
Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.
...
Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch.
PiperOrigin-RevId: 613289252
2024-03-06 11:42:40 -08:00
Jake VanderPlas
735ec63dd1
[key reuse] improve error message using source_info_util
2024-03-05 11:02:39 -08:00
Jake VanderPlas
bb91bf2e09
[key reuse] improve some key reuse errors.
2024-03-05 08:14:20 -08:00
Jake VanderPlas
84d11d7b11
[key reuse] don't consume on equality check
2024-03-04 13:32:35 -08:00
Jake VanderPlas
d08e9a03d8
[key reuse] add eager checks
2024-02-29 15:30:19 -08:00
Jake VanderPlas
5ae4bffcb2
[key reuse] avoid inadvertently duplicated test cases
2024-02-29 11:44:46 -08:00
Jake VanderPlas
8eab599530
[key reuse] simplify key reuse logic through context-free jaxpr evaluation
...
The args_consumed and forwarded_inputs context is not actually needed, because it can be checked
afterward. The only reason for this was to have more granular errors, but arguably it's better
to error on jaxpr input.
2024-02-15 15:50:50 -08:00
Jake VanderPlas
8284c164a3
[key reuse] remove alternate simple implementation
2024-02-15 11:35:58 -08:00
Jake VanderPlas
49eb7008c0
Define reuse_key primitive in jax._src.prng
2024-02-14 14:01:08 -08:00
Jake VanderPlas
7360edd404
[key reuse]: handle remat
2024-02-05 16:02:12 -08:00
jax authors
c1c0c1cf33
Merge pull request #19634 from jakevdp:key-reuse-scan
...
PiperOrigin-RevId: 604418753
2024-02-05 13:39:07 -08:00
Jake VanderPlas
f4f8293c38
[key reuse] fix scan single-key consumption issue
2024-02-01 15:36:36 -08:00
Jake VanderPlas
f4534427ad
[key reuse] fix pjit false positive
2024-02-01 14:28:42 -08:00
Jake VanderPlas
17935aff01
[key reuse] fix key reuse type for cond with sources
2024-01-26 14:42:55 -08:00
Jake VanderPlas
b069c20e56
[key reuse] don't consume key in fold_in
...
Why? We've found in practice that downstream projects use fold_in multiple
times with the same key. This is safe so long as the folded-in value is
different every time; in this sense fold_in() is similar to seed(), and
for now we must trust the user to not repeat seeds.
2024-01-25 15:35:51 -08:00
Jake VanderPlas
03ce8ca0ca
jax.random: deprecate passing of batched keys to APIs
2024-01-17 12:53:24 -08:00
Jake VanderPlas
a52d18781e
Add experimental static key reuse checking
2023-12-11 12:03:48 -08:00