Merge pull request #22684 from froystig:rngdoc

PiperOrigin-RevId: 656600958
jax authors 2024-07-26 19:12:36 -07:00
commit dab15d6fdd

@ -117,31 +117,70 @@ Advanced RNG configuration
==========================
JAX provides several PRNG implementations. A specific one can be
selected with the optional `impl` keyword argument to
`jax.random.key`. When no `impl` option is passed to the `key`
selected with the optional ``impl`` keyword argument to
``jax.random.key``. When no ``impl`` option is passed to the ``key``
constructor, the implementation is determined by the global
`jax_default_prng_impl` configuration flag.
``jax_default_prng_impl`` configuration flag. The string names of
available implementations are:
- **default**, `"threefry2x32"`:
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
- ``"threefry2x32"`` (**default**):
A counter-based PRNG based on a variant of the Threefry hash function,
as described in `this paper by Salmon et al., 2011
<http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
splitting (using an untested made up algorithm) and generating.
- ``"rbg"`` and ``"unsafe_rbg"`` (**experimental**): PRNGs built atop
`XLA's Random Bit Generator (RBG) algorithm
<https://openxla.org/xla/operation_semantics#rngbitgenerator>`_.
The random streams generated by these experimental implementations haven't
been subject to any empirical randomness testing (e.g. Big Crush). The
random bits generated may change between JAX versions.
- ``"rbg"`` uses XLA RBG for random number generation, whereas for
key derivation (as in ``jax.random.split`` and
``jax.random.fold_in``) it uses the same method as
``"threefry2x32"``.
The possible reasons not to use the default RNG are:
- ``"unsafe_rbg"`` uses XLA RBG for both generation as well as key
derivation.
1. it may be slow to compile (specifically for Google Cloud TPUs)
2. it's slower to execute on TPUs
3. it doesn't support efficient automatic sharding / partitioning
Random numbers generated by these experimental schemes have not
been subject to empirical randomness testing (e.g. BigCrush).
Here is a short summary:
Key derivation in ``"unsafe_rbg"`` has also not been empirically
tested. The name emphasizes "unsafe" because key derivation
quality and generation quality are not well understood.
Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually
under ``jax.vmap``. When vmapping a random function over a batch
of keys, its output values can differ from its true map over the
same keys. Instead, under ``vmap``, the entire batch of output
random numbers is generated from only the first key in the input
key batch. For example, if ``keys`` is a vector of 8 keys, then
``jax.vmap(jax.random.normal)(keys)`` equals
``jax.random.normal(keys[0], shape=(8,))``. This peculiarity
reflects a workaround to XLA RBG's limited batching support.
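As a rough illustration of the points above, the following sketch selects implementations by string name and compares ``jax.vmap`` over a batch of keys with an explicit per-key loop. The seed ``0`` and the batch size of 8 are arbitrary choices for the example. The equality check is made only for ``"threefry2x32"``; the ``"rbg"`` arrays are computed so they can be inspected, without assuming a particular result, since those bits may change between JAX versions.

```python
import jax
import jax.numpy as jnp

# Select an implementation explicitly by its string name.
key_default = jax.random.key(0, impl="threefry2x32")
key_rbg = jax.random.key(0, impl="rbg")

# Derive a batch of 8 keys from the default-implementation key.
keys = jax.random.split(key_default, 8)

# With "threefry2x32", vmapping over keys matches a per-key loop.
vmapped = jax.vmap(jax.random.normal)(keys)
looped = jnp.stack([jax.random.normal(keys[i]) for i in range(8)])
matches_loop = bool(jnp.allclose(vmapped, looped))

# With "rbg", the text above says the vmapped batch is generated from the
# first key only. Compute both sides so they can be compared by eye.
rbg_keys = jax.random.split(key_rbg, 8)
rbg_vmapped = jax.vmap(jax.random.normal)(rbg_keys)
rbg_from_first = jax.random.normal(rbg_keys[0], shape=(8,))

print("threefry vmap matches loop:", matches_loop)
print("rbg vmapped:       ", rbg_vmapped)
print("rbg from first key:", rbg_from_first)
```

Code that relies on per-key independence under ``vmap`` is therefore a reason to prefer the default implementation.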
Reasons to use an alternative to the default RNG include that:
1. It may be slow to compile for TPUs.
2. It is relatively slower to execute on TPUs.
**Automatic partitioning:**
In order for ``jax.jit`` to efficiently auto-partition functions that
generate sharded random number arrays (or key arrays), all PRNG
implementations require extra flags:
- For ``"threefry2x32"``, and ``"rbg"`` key derivation, set
``jax_threefry_partitionable=True``.
- For ``"unsafe_rbg"``, and ``"rbg"`` random generation, set the XLA
flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
The XLA flag can be set using the ``XLA_FLAGS`` environment
variable, e.g. as
``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
For more about ``jax_threefry_partitionable``, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
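A minimal sketch of enabling the JAX-level flag from Python, assuming it is set before any random values are generated. The XLA flag is shown only as a comment: it is passed through the environment rather than through ``jax.config``, and it is left out of the executed code because it is TPU-specific. The seed and shape are arbitrary.

```python
import jax

# Enable the partitionable Threefry code path for this process.
jax.config.update("jax_threefry_partitionable", True)
partitionable = jax.config.jax_threefry_partitionable

# The XLA flag would instead be set in the environment before launching,
# for example (shell):
#   XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 python script.py

# Random generation works as usual once the flag is set.
key = jax.random.key(0)
x = jax.random.normal(key, (4,))
print(partitionable, x.shape)
```

Changing ``jax_threefry_partitionable`` can change the random values produced for a given key, so it is best set once at program start.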
**Summary:**
.. table::
:widths: auto
@ -153,22 +192,12 @@ Here is a short summary:
efficiently shardable (w/ pjit)
identical across shardings
identical across CPU/GPU/TPU
identical across JAX/XLA versions
exact ``jax.vmap`` over keys
================================= ======== ========= === ========== ===== ============
(*): with ``jax_threefry_partitionable=1`` set
(**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set
The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less
robust/studied hash function for random value generation (but not for
`jax.random.split` or `jax.random.fold_in`), "unsafe_rbg" additionally uses less
robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
less safe in the sense that the quality of random streams it generates from
different keys is less well understood.
For more about `jax_threefry_partitionable`, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
"""
# Note: import <name> as <name> is required for names to be exported.