5 Commits

Author SHA1 Message Date
Matthew Johnson
634d252bb3 improvements to RBG PRNG
1. factor out rbg_prng_impl and unsafe_rbg_prng_impl. the former uses
   threefry2x32 for split and fold_in, while the latter uses untested
   heuristics based on calling rng_bit_generator itself as a kind of
   hash function
2. for unsafe_rbg_prng_impl's split and fold_in, generate longer
   sequences from rng_bit_generator (10x iterations) which may be useful on
   some backends
3. for unsafe_rbg_prng_impl, actually apply rng_bit_generator as our
   'hash function' in fold_in

Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
2021-10-07 18:59:13 -07:00
Matthew Johnson
980dfcfd2c add experimental RngBitGenerator ("RBG") PRNG
Not only is this an experimental API, but also because the
RngBitGenerator is not guaranteed to be stable across compiler versions
(or backends or shardings), let's assert that this JAX PRNG
implementation may not be stable across JAX versions.

Even without that kind of stability, this PRNG is still useful because
compared to effectful RNG primitives, like lax.rng_uniform, this RBG
PRNG will still work correctly with lax.scan and jax.checkpoint (while
still potentially being more performant on some platforms than JAX's
standard PRNG).
2021-10-01 19:04:14 -07:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Roy Frostig
60e0e9f929 implement backwards-compatible behavior and enable custom PRNGs only conditionally
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
2021-08-19 20:43:11 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
2021-08-19 20:43:11 -07:00