update prng docs to mention jax_threefry_partitionable

fixes #15484
This commit is contained in:
Matthew Johnson 2023-04-07 22:53:53 -07:00
parent 053affd173
commit a7f5e07549
2 changed files with 15 additions and 11 deletions

View File

@ -14,7 +14,7 @@ List of Available Functions
.. autosummary::
:toctree: _autosummary
PRNGKey
ball
bernoulli

View File

@ -106,17 +106,18 @@ Here is a short summary:
.. table::
:widths: auto
================================= ======== === ========== ======= ==============
Property Threefry rbg unsafe_rbg rbg (*) unsafe_rbg (*)
================================= ======== === ========== ======= ==============
Fastest on TPU
efficiently shardable (w/ pjit)
identical across shardings
identical across CPU/GPU/TPU
identical across JAX/XLA versions
================================= ======== === ========== ======= ==============
================================= ======== ========= === ========== ===== ============
Property Threefry Threefry* rbg unsafe_rbg rbg** unsafe_rbg**
================================= ======== ========= === ========== ===== ============
Fastest on TPU
efficiently shardable (w/ pjit)
identical across shardings
identical across CPU/GPU/TPU
identical across JAX/XLA versions
================================= ======== ========= === ========== ===== ============
(*): with XLA_FLAGS=xla_tpu_spmd_rng_bit_generator_unsafe=1 set
(*): with jax_threefry_partitionable=1 set
(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set
The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less
robust/studied hash function for random value generation (but not for
@ -124,6 +125,9 @@ robust/studied hash function for random value generation (but not for
robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
less safe in the sense that the quality of random streams it generates from
different keys is less well understood.
For more about jax_threefry_partitionable, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
"""
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray