Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699
Specifically:
* Introduce `jax.random.key_impl`, which accepts a key array and
returns a hashable identifier of its PRNG implementation.
* Accept this identifier optionally as the `impl` argument to
`jax.random.key` and `wrap_key_data`.
This now works:
```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```
This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).
Co-authored-by: Jake Vanderplas <jakevdp@google.com>