mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #22684 from froystig:rngdoc
PiperOrigin-RevId: 656600958
commit
dab15d6fdd
@@ -117,31 +117,70 @@ Advanced RNG configuration
 ==========================

 JAX provides several PRNG implementations. A specific one can be
-selected with the optional `impl` keyword argument to
-`jax.random.key`. When no `impl` option is passed to the `key`
+selected with the optional ``impl`` keyword argument to
+``jax.random.key``. When no ``impl`` option is passed to the ``key``
 constructor, the implementation is determined by the global
-`jax_default_prng_impl` configuration flag.
+``jax_default_prng_impl`` configuration flag. The string names of
+available implementations are:

-- **default**, `"threefry2x32"`:
-`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
-- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
-`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
+- ``"threefry2x32"`` (**default**):
+A counter-based PRNG based on a variant of the Threefry hash function,
+as described in `this paper by Salmon et al., 2011
+<http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.

-- `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
-- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
-splitting (using an untested made up algorithm) and generating.
+- ``"rbg"`` and ``"unsafe_rbg"`` (**experimental**): PRNGs built atop
+`XLA's Random Bit Generator (RBG) algorithm
+<https://openxla.org/xla/operation_semantics#rngbitgenerator>`_.

-The random streams generated by these experimental implementations haven't
-been subject to any empirical randomness testing (e.g. Big Crush). The
-random bits generated may change between JAX versions.
+- ``"rbg"`` uses XLA RBG for random number generation, whereas for
+key derivation (as in ``jax.random.split`` and
+``jax.random.fold_in``) it uses the same method as
+``"threefry2x32"``.

-The possible reasons not use the default RNG are:
+- ``"unsafe_rbg"`` uses XLA RBG for both generation as well as key
+derivation.

-1. it may be slow to compile (specifically for Google Cloud TPUs)
-2. it's slower to execute on TPUs
-3. it doesn't support efficient automatic sharding / partitioning
+Random numbers generated by these experimental schemes have not
+been subject to empirical randomness testing (e.g. BigCrush).

-Here is a short summary:
+Key derivation in ``"unsafe_rbg"`` has also not been empirically
+tested. The name emphasizes "unsafe" because key derivation
+quality and generation quality are not well understood.
+
+Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually
+under ``jax.vmap``. When vmapping a random function over a batch
+of keys, its output values can differ from its true map over the
+same keys. Instead, under ``vmap``, the entire batch of output
+random numbers is generated from only the first key in the input
+key batch. For example, if ``keys`` is a vector of 8 keys, then
+``jax.vmap(jax.random.normal)(keys)`` equals
+``jax.random.normal(keys[0], shape=(8,))``. This peculiarity
+reflects a workaround to XLA RBG's limited batching support.
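
The ``impl`` selection and the ``jax.vmap`` behaviour described in the new text can be tried out with a short sketch. It is not part of the commit: the helper name ``vmap_vs_manual`` and the batch size are assumptions made for illustration. The helper compares a vmapped draw against a per-key "true map". Only the default ``"threefry2x32"`` implementation is exercised at module level.

```python
import jax
import jax.numpy as jnp


def vmap_vs_manual(impl, n=8):
    """Return (vmapped, per_key) standard-normal draws for a batch of n keys.

    Hypothetical helper for illustration; `impl` is one of the string
    names listed in the docstring, e.g. "threefry2x32" or "rbg".
    """
    # Build a key with an explicit implementation, then derive n keys from it.
    keys = jax.random.split(jax.random.key(0, impl=impl), n)
    # Batched draw: one scalar sample per key.
    vmapped = jax.vmap(jax.random.normal)(keys)
    # Reference "true map": draw from each key separately, then stack.
    per_key = jnp.stack([jax.random.normal(keys[i]) for i in range(n)])
    return vmapped, per_key


# With the default implementation the two results are expected to agree.
vmapped, per_key = vmap_vs_manual("threefry2x32")
print(bool(jnp.allclose(vmapped, per_key)))
```

Calling ``vmap_vs_manual("rbg")`` or ``vmap_vs_manual("unsafe_rbg")`` is where, according to the text above, the two results can differ, since the whole batch is then generated from the first key.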
+
+Reasons to use an alternative to the default RNG include that:
+
+1. It may be slow to compile for TPUs.
+2. It is relatively slower to execute on TPUs.
+
+**Automatic partitioning:**
+
+In order for ``jax.jit`` to efficiently auto-partition functions that
+generate sharded random number arrays (or key arrays), all PRNG
+implementations require extra flags:
+
+- For ``"threefry2x32"``, and ``"rbg"`` key derivation, set
+``jax_threefry_partitionable=True``.
+- For ``"unsafe_rbg"``, and ``"rbg"`` random generation, set the XLA
+flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
+
+The XLA flag can be set using the ``XLA_FLAGS`` environment
+variable, e.g. as
+``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
+
+For more about ``jax_threefry_partitionable``, see
+https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
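
As a rough illustration of the flag setup described above, the following sketch (not part of the commit) turns on ``jax_threefry_partitionable`` through ``jax.config.update`` and then draws random numbers inside ``jax.jit``. The function name ``sample`` and the array shape are assumptions. The sketch runs on a single device, so it only shows how the option is set and does not demonstrate any sharding benefit.

```python
import jax

# Enable the partitionable Threefry mode named in the text. Depending on
# the JAX version, this may change the values produced for a given key.
jax.config.update("jax_threefry_partitionable", True)

# The XLA flag mentioned in the text is passed through the XLA_FLAGS
# environment variable before the program starts; it targets TPU and is
# deliberately not set in this sketch.


@jax.jit
def sample(key):
    # Hypothetical shape, chosen only for illustration.
    return jax.random.normal(key, shape=(4, 4))


key = jax.random.key(0)  # uses the default PRNG implementation
x = sample(key)
print(x.shape)
```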
+
+**Summary:**

 .. table::
 :widths: auto
@@ -153,22 +192,12 @@ Here is a short summary:
 efficiently shardable (w/ pjit) ✅ ✅ ✅
 identical across shardings ✅ ✅ ✅ ✅
 identical across CPU/GPU/TPU ✅ ✅
-identical across JAX/XLA versions ✅ ✅
+exact ``jax.vmap`` over keys ✅ ✅
 ================================= ======== ========= === ========== ===== ============

 (*): with ``jax_threefry_partitionable=1`` set

 (**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set
-
-The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less
-robust/studied hash function for random value generation (but not for
-`jax.random.split` or `jax.random.fold_in`), "unsafe_rbg" additionally uses less
-robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
-less safe in the sense that the quality of random streams it generates from
-different keys is less well understood.
-
-For more about `jax_threefry_partitionable`, see
-https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
 """

 # Note: import <name> as <name> is required for names to be exported.