mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
parent
053affd173
commit
a7f5e07549
@ -14,7 +14,7 @@ List of Available Functions
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
|
||||
PRNGKey
|
||||
ball
|
||||
bernoulli
|
||||
|
@ -106,17 +106,18 @@ Here is a short summary:
|
||||
.. table::
|
||||
:widths: auto
|
||||
|
||||
================================= ======== === ========== ======= ==============
|
||||
Property Threefry rbg unsafe_rbg rbg (*) unsafe_rbg (*)
|
||||
================================= ======== === ========== ======= ==============
|
||||
Fastest on TPU ✅ ✅ ✅ ✅
|
||||
efficiently shardable (w/ pjit) ✅ ✅
|
||||
identical across shardings ✅ ✅ ✅
|
||||
identical across CPU/GPU/TPU ✅
|
||||
identical across JAX/XLA versions ✅
|
||||
================================= ======== === ========== ======= ==============
|
||||
================================= ======== ========= === ========== ===== ============
|
||||
Property Threefry Threefry* rbg unsafe_rbg rbg** unsafe_rbg**
|
||||
================================= ======== ========= === ========== ===== ============
|
||||
Fastest on TPU ✅ ✅ ✅ ✅
|
||||
efficiently shardable (w/ pjit) ✅ ✅ ✅
|
||||
identical across shardings ✅ ✅ ✅ ✅
|
||||
identical across CPU/GPU/TPU ✅ ✅
|
||||
identical across JAX/XLA versions ✅ ✅
|
||||
================================= ======== ========= === ========== ===== ============
|
||||
|
||||
(*): with XLA_FLAGS=xla_tpu_spmd_rng_bit_generator_unsafe=1 set
|
||||
(*): with jax_threefry_partitionable=1 set
|
||||
(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set
|
||||
|
||||
The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less
|
||||
robust/studied hash function for random value generation (but not for
|
||||
@ -124,6 +125,9 @@ robust/studied hash function for random value generation (but not for
|
||||
robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
|
||||
less safe in the sense that the quality of random streams it generates from
|
||||
different keys is less well understood.
|
||||
|
||||
For more about jax_threefry_partitionable, see
|
||||
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
|
||||
"""
|
||||
|
||||
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray
|
||||
|
Loading…
x
Reference in New Issue
Block a user