1. factor out rbg_prng_impl and unsafe_rbg_prng_impl. the former uses
threefry2x32 for split and fold_in, while the latter uses untested
heuristics based on calling rng_bit_generator itself as a kind of
hash function
2. for unsafe_rbg_prng_impl's split and fold_in, generate longer
sequences from rng_bit_generator (10x iterations) which may be useful on
some backends
3. for unsafe_rbg_prng_impl, actually apply rng_bit_generator as our
'hash function' in fold_in
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
Not only is this an experimental API, but also because the
RngBitGenerator is not guaranteed to be stable across compiler versions
(or backends or shardings), let's assert that this JAX PRNG
implementation may not be stable across JAX versions.
Even without that kind of stability, this PRNG is still useful because
compared to effectful RNG primitives, like lax.rng_uniform, this RBG
PRNG will still work correctly with lax.scan and jax.checkpoint (while
still potentially being more performant on some platforms than JAX's
standard PRNG).
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.
Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.
A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:
1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
back to it from the functions in random.