mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
255 lines
9.2 KiB
Python
255 lines
9.2 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Utilities for pseudo-random number generation.
|
|
|
|
The :mod:`jax.random` package provides a number of routines for deterministic
|
|
generation of sequences of pseudorandom numbers.
|
|
|
|
Basic usage
|
|
-----------
|
|
|
|
>>> seed = 1701
|
|
>>> num_steps = 100
|
|
>>> key = jax.random.key(seed)
|
|
>>> for i in range(num_steps):
|
|
... key, subkey = jax.random.split(key)
|
|
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
|
|
|
|
PRNG keys
|
|
---------
|
|
|
|
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
|
|
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
|
|
be passed as a first argument.
|
|
The random state is described by a special array element type that we call a **key**,
|
|
usually generated by the :py:func:`jax.random.key` function::
|
|
|
|
>>> from jax import random
|
|
>>> key = random.key(0)
|
|
>>> key
|
|
Array((), dtype=key<fry>) overlaying:
|
|
[0 0]
|
|
|
|
This key can then be used in any of JAX's random number generation routines::
|
|
|
|
>>> random.uniform(key)
|
|
Array(0.947667, dtype=float32)
|
|
|
|
Note that using a key does not modify it, so reusing the same key will lead to the same result::
|
|
|
|
>>> random.uniform(key)
|
|
Array(0.947667, dtype=float32)
|
|
|
|
If you need a new random number, you can use :py:func:`jax.random.split` to generate new subkeys::
|
|
|
|
>>> key, subkey = random.split(key)
|
|
>>> random.uniform(subkey)
|
|
Array(0.00729382, dtype=float32)
|
|
|
|
.. note::
|
|
|
|
Typed key arrays, with element types such as ``key<fry>`` above,
|
|
were introduced in JAX v0.4.16. Before then, keys were
|
|
conventionally represented in ``uint32`` arrays, whose final
|
|
dimension represented the key's bit-level representation.
|
|
|
|
Both forms of key array can still be created and used with the
|
|
:mod:`jax.random` module. New-style typed key arrays are made with
|
|
:py:func:`jax.random.key`. Legacy ``uint32`` key arrays are made
|
|
with :py:func:`jax.random.PRNGKey`.
|
|
|
|
To convert between the two, use :py:func:`jax.random.key_data` and
|
|
:py:func:`jax.random.wrap_key_data`. The legacy key format may be
|
|
needed when interfacing with systems outside of JAX (e.g. exporting
|
|
arrays to a serializable format), or when passing keys to JAX-based
|
|
libraries that assume the legacy format.
|
|
|
|
Otherwise, typed keys are recommended. Caveats of legacy keys
|
|
relative to typed ones include:
|
|
|
|
* They have an extra trailing dimension.
|
|
|
|
* They have a numeric dtype (``uint32``), allowing for operations
|
|
that are typically not meant to be carried out over keys, such as
|
|
integer arithmetic.
|
|
|
|
* They do not carry information about the RNG implementation. When
|
|
legacy keys are passed to :mod:`jax.random` functions, a global
|
|
configuration setting determines the RNG implementation (see
|
|
"Advanced RNG configuration" below).
|
|
|
|
To learn more about this upgrade, and the design of key types, see
|
|
`JEP 9263
|
|
<https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html>`_.
|
|
|
|
Advanced
|
|
--------
|
|
|
|
Design and background
|
|
=====================
|
|
|
|
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
|
|
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
|
|
|
|
See `docs/jep/263-prng.md <https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md>`_
|
|
for more details.
|
|
|
|
To summarize, among other requirements, the JAX PRNG aims to:
|
|
|
|
1. ensure reproducibility,
|
|
2. parallelize well, both in terms of vectorization (generating array values)
|
|
and multi-replica, multi-core computation. In particular it should not use
|
|
sequencing constraints between random function calls.
|
|
|
|
Advanced RNG configuration
|
|
==========================
|
|
|
|
JAX provides several PRNG implementations. A specific one can be
|
|
selected with the optional ``impl`` keyword argument to
|
|
``jax.random.key``. When no ``impl`` option is passed to the ``key``
|
|
constructor, the implementation is determined by the global
|
|
``jax_default_prng_impl`` configuration flag. The string names of
|
|
available implementations are:
|
|
|
|
- ``"threefry2x32"`` (**default**):
|
|
A counter-based PRNG based on a variant of the Threefry hash function,
|
|
as described in `this paper by Salmon et al., 2011
|
|
<http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
|
|
|
|
- ``"rbg"`` and ``"unsafe_rbg"`` (**experimental**): PRNGs built atop
|
|
`XLA's Random Bit Generator (RBG) algorithm
|
|
<https://openxla.org/xla/operation_semantics#rngbitgenerator>`_.
|
|
|
|
- ``"rbg"`` uses XLA RBG for random number generation, whereas for
|
|
key derivation (as in ``jax.random.split`` and
|
|
``jax.random.fold_in``) it uses the same method as
|
|
``"threefry2x32"``.
|
|
|
|
- ``"unsafe_rbg"`` uses XLA RBG for both generation as well as key
|
|
derivation.
|
|
|
|
Random numbers generated by these experimental schemes have not
|
|
been subject to empirical randomness testing (e.g. BigCrush).
|
|
|
|
Key derivation in ``"unsafe_rbg"`` has also not been empirically
|
|
tested. The name emphasizes "unsafe" because key derivation
|
|
quality and generation quality are not well understood.
|
|
|
|
Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually
|
|
under ``jax.vmap``. When vmapping a random function over a batch
|
|
of keys, its output values can differ from those of a true map over the
|
|
same keys. Instead, under ``vmap``, the entire batch of output
|
|
random numbers is generated from only the first key in the input
|
|
key batch. For example, if ``keys`` is a vector of 8 keys, then
|
|
``jax.vmap(jax.random.normal)(keys)`` equals
|
|
``jax.random.normal(keys[0], shape=(8,))``. This peculiarity
|
|
reflects a workaround to XLA RBG's limited batching support.
|
|
|
|
Reasons to use an alternative to the default RNG include that:
|
|
|
|
1. It may be slow to compile for TPUs.
|
|
2. It is relatively slower to execute on TPUs.
|
|
|
|
**Automatic partitioning:**
|
|
|
|
In order for ``jax.jit`` to efficiently auto-partition functions that
|
|
generate sharded random number arrays (or key arrays), all PRNG
|
|
implementations require extra flags:
|
|
|
|
- For ``"threefry2x32"`` and ``"rbg"`` key derivation, set
|
|
``jax_threefry_partitionable=True``.
|
|
- For ``"unsafe_rbg"`` and ``"rbg"`` random generation, set the XLA
|
|
flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
|
|
|
|
The XLA flag can be set using the ``XLA_FLAGS`` environment
|
|
variable, e.g. as
|
|
``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
|
|
|
|
For more about ``jax_threefry_partitionable``, see
|
|
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
|
|
|
|
**Summary:**
|
|
|
|
.. table::
   :widths: auto

   =================================   ========  =========  ===  ==========  =====  ============
   Property                            Threefry  Threefry*  rbg  unsafe_rbg  rbg**  unsafe_rbg**
   =================================   ========  =========  ===  ==========  =====  ============
   Fastest on TPU                                           ✅   ✅          ✅     ✅
   efficiently shardable (w/ pjit)               ✅                          ✅     ✅
   identical across shardings          ✅        ✅         ✅   ✅
   identical across CPU/GPU/TPU        ✅        ✅
   exact ``jax.vmap`` over keys        ✅        ✅
   =================================   ========  =========  ===  ==========  =====  ============
|
|
|
|
(*): with ``jax_threefry_partitionable=1`` set
|
|
|
|
(**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set
|
|
"""
|
|
|
|
# Note: import <name> as <name> is required for names to be exported.
|
|
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
|
|
|
from jax._src.random import (
|
|
PRNGKey as PRNGKey,
|
|
ball as ball,
|
|
bernoulli as bernoulli,
|
|
binomial as binomial,
|
|
beta as beta,
|
|
bits as bits,
|
|
categorical as categorical,
|
|
cauchy as cauchy,
|
|
chisquare as chisquare,
|
|
choice as choice,
|
|
clone as clone,
|
|
dirichlet as dirichlet,
|
|
double_sided_maxwell as double_sided_maxwell,
|
|
exponential as exponential,
|
|
f as f,
|
|
fold_in as fold_in,
|
|
gamma as gamma,
|
|
generalized_normal as generalized_normal,
|
|
geometric as geometric,
|
|
gumbel as gumbel,
|
|
key as key,
|
|
key_data as key_data,
|
|
key_impl as key_impl,
|
|
laplace as laplace,
|
|
logistic as logistic,
|
|
loggamma as loggamma,
|
|
lognormal as lognormal,
|
|
maxwell as maxwell,
|
|
multinomial as multinomial,
|
|
multivariate_normal as multivariate_normal,
|
|
normal as normal,
|
|
orthogonal as orthogonal,
|
|
pareto as pareto,
|
|
permutation as permutation,
|
|
poisson as poisson,
|
|
rademacher as rademacher,
|
|
randint as randint,
|
|
random_gamma_p as random_gamma_p,
|
|
rayleigh as rayleigh,
|
|
split as split,
|
|
t as t,
|
|
triangular as triangular,
|
|
truncated_normal as truncated_normal,
|
|
uniform as uniform,
|
|
wald as wald,
|
|
weibull_min as weibull_min,
|
|
wrap_key_data as wrap_key_data,
|
|
)
|