11 Commits

Author SHA1 Message Date
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Jake VanderPlas
6da4750c3b [random] remove internal uses of deprecated prng.seed_with_impl() 2023-10-17 13:18:08 -07:00
Jake VanderPlas
4ba7590d85 export jax.extend.source_info_util.current
PiperOrigin-RevId: 573290435
2023-10-13 12:31:11 -07:00
Jake VanderPlas
4e463c0aa2 JEX: add jax.extend.source_info_util 2023-10-13 09:36:00 -07:00
Roy Frostig
5158e251b6 identify PRNG schemes on key arrays, and recognize them in key constructors
Specifically:

* Introduce `jax.random.key_impl`, which accepts a key array and
  returns a hashable identifier of its PRNG implementation.

* Accept this identifier optionally as the `impl` argument to
  `jax.random.key` and `wrap_key_data`.

This now works:

```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```

This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-10-06 10:15:08 -07:00
Jake VanderPlas
48087cbe8d JEX: add jex.abstract_arrays.array_types 2023-09-19 11:37:05 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
Roy Frostig
a69f134cde add jax.extend.random.wrap_key_data 2023-08-26 11:39:25 -07:00
Roy Frostig
a71c0e6ecc create jax.extend.random as a copy of jax.prng
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00
Roy Frostig
ca008f37e3 initiate jax.extend via docs and top-level module set-up 2023-05-15 15:47:06 -07:00