mirror of
https://github.com/ROCm/jax.git
synced 2025-04-23 20:26:05 +00:00

These APIs were all removed 3 or more months ago, and the registrations here cause them to raise informative AttributeErrors. Enough time has passed now that we can remove these.
254 lines
9.2 KiB
Python
254 lines
9.2 KiB
Python
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for pseudo-random number generation.

The :mod:`jax.random` package provides a number of routines for deterministic
generation of sequences of pseudorandom numbers.

Basic usage
-----------

>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))  # doctest: +SKIP

PRNG keys
---------

Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
be passed as the first argument.
The random state is described by a special array element type that we call a **key**,
usually generated by the :py:func:`jax.random.key` function::

    >>> from jax import random
    >>> key = random.key(0)
    >>> key
    Array((), dtype=key<fry>) overlaying:
    [0 0]

This key can then be used in any of JAX's random number generation routines::

    >>> random.uniform(key)
    Array(0.41845703, dtype=float32)

Note that using a key does not modify it, so reusing the same key will lead to the same result::

    >>> random.uniform(key)
    Array(0.41845703, dtype=float32)

If you need a new random number, you can use :func:`jax.random.split` to generate new subkeys::

    >>> key, subkey = random.split(key)
    >>> random.uniform(subkey)
    Array(0.10536897, dtype=float32)

.. note::

   Typed key arrays, with element types such as ``key<fry>`` above,
   were introduced in JAX v0.4.16. Before then, keys were
   conventionally represented in ``uint32`` arrays, whose final
   dimension held the key's bit-level representation.

   Both forms of key array can still be created and used with the
   :mod:`jax.random` module. New-style typed key arrays are made with
   :py:func:`jax.random.key`. Legacy ``uint32`` key arrays are made
   with :py:func:`jax.random.PRNGKey`.

   To convert between the two, use :py:func:`jax.random.key_data` and
   :py:func:`jax.random.wrap_key_data`. The legacy key format may be
   needed when interfacing with systems outside of JAX (e.g. exporting
   arrays to a serializable format), or when passing keys to JAX-based
   libraries that assume the legacy format.

   Otherwise, typed keys are recommended. Caveats of legacy keys
   relative to typed ones include:

   * They have an extra trailing dimension.

   * They have a numeric dtype (``uint32``), allowing for operations
     that are typically not meant to be carried out over keys, such as
     integer arithmetic.

   * They do not carry information about the RNG implementation. When
     legacy keys are passed to :mod:`jax.random` functions, a global
     configuration setting determines the RNG implementation (see
     "Advanced RNG configuration" below).

   To learn more about this upgrade and the design of key types, see
   `JEP 9263
   <https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html>`_.

Advanced
--------

Design and background
=====================

**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_

See `docs/jep/263-prng.md <https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md>`_
for more details.

To summarize, among other requirements, the JAX PRNG aims to:

1. ensure reproducibility,
2. parallelize well, both in terms of vectorization (generating array values)
   and multi-replica, multi-core computation. In particular, it should not impose
   sequencing constraints between random function calls.

Advanced RNG configuration
==========================

JAX provides several PRNG implementations. A specific one can be
selected with the optional ``impl`` keyword argument to
``jax.random.key``. When no ``impl`` option is passed to the ``key``
constructor, the implementation is determined by the global
``jax_default_prng_impl`` configuration flag. The string names of
available implementations are:

- ``"threefry2x32"`` (**default**):
  A counter-based PRNG built on a variant of the Threefry hash function,
  as described in `this paper by Salmon et al., 2011
  <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.

- ``"rbg"`` and ``"unsafe_rbg"`` (**experimental**): PRNGs built atop
  `XLA's Random Bit Generator (RBG) algorithm
  <https://openxla.org/xla/operation_semantics#rngbitgenerator>`_.

  - ``"rbg"`` uses XLA RBG for random number generation, whereas for
    key derivation (as in ``jax.random.split`` and
    ``jax.random.fold_in``) it uses the same method as
    ``"threefry2x32"``.

  - ``"unsafe_rbg"`` uses XLA RBG for both generation and key
    derivation.

  Random numbers generated by these experimental schemes have not
  been subject to empirical randomness testing (e.g. BigCrush).

  Key derivation in ``"unsafe_rbg"`` has also not been empirically
  tested. The name emphasizes "unsafe" because key derivation
  quality and generation quality are not well understood.

  Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually
  under ``jax.vmap``. When vmapping a random function over a batch
  of keys, its output values can differ from those obtained by applying
  the function to each key separately. Instead, under ``vmap``, the entire
  batch of output random numbers is generated from only the first key in
  the input key batch. For example, if ``keys`` is a vector of 8 keys, then
  ``jax.vmap(jax.random.normal)(keys)`` equals
  ``jax.random.normal(keys[0], shape=(8,))``. This peculiarity
  reflects a workaround for XLA RBG's limited batching support.

Reasons to use an alternative to the default RNG include:

1. It may be slow to compile for TPUs.
2. It is relatively slow to execute on TPUs.

**Automatic partitioning:**

In order for ``jax.jit`` to efficiently auto-partition functions that
generate sharded random number arrays (or key arrays), all PRNG
implementations require extra flags:

- For ``"threefry2x32"``, and for ``"rbg"`` key derivation, set
  ``jax_threefry_partitionable=True``.
- For ``"unsafe_rbg"``, and for ``"rbg"`` random generation, set the XLA
  flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``.

The XLA flag can be set using the ``XLA_FLAGS`` environment
variable, e.g. as
``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``.

For more about ``jax_threefry_partitionable``, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

**Summary:**

.. table::
   :widths: auto

   =================================   ========  =========  ===  ==========  =====  ============
   Property                            Threefry  Threefry*  rbg  unsafe_rbg  rbg**  unsafe_rbg**
   =================================   ========  =========  ===  ==========  =====  ============
   Fastest on TPU                                           ✅   ✅          ✅     ✅
   efficiently shardable (w/ pjit)                ✅                          ✅     ✅
   identical across shardings          ✅        ✅         ✅   ✅
   identical across CPU/GPU/TPU        ✅        ✅
   exact ``jax.vmap`` over keys        ✅        ✅
   =================================   ========  =========  ===  ==========  =====  ============

(*): with ``jax_threefry_partitionable=1`` set

(**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set
"""

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

from jax._src.random import (
|
|
PRNGKey as PRNGKey,
|
|
ball as ball,
|
|
bernoulli as bernoulli,
|
|
binomial as binomial,
|
|
beta as beta,
|
|
bits as bits,
|
|
categorical as categorical,
|
|
cauchy as cauchy,
|
|
chisquare as chisquare,
|
|
choice as choice,
|
|
clone as clone,
|
|
dirichlet as dirichlet,
|
|
double_sided_maxwell as double_sided_maxwell,
|
|
exponential as exponential,
|
|
f as f,
|
|
fold_in as fold_in,
|
|
gamma as gamma,
|
|
generalized_normal as generalized_normal,
|
|
geometric as geometric,
|
|
gumbel as gumbel,
|
|
key as key,
|
|
key_data as key_data,
|
|
key_impl as key_impl,
|
|
laplace as laplace,
|
|
logistic as logistic,
|
|
loggamma as loggamma,
|
|
lognormal as lognormal,
|
|
maxwell as maxwell,
|
|
multivariate_normal as multivariate_normal,
|
|
normal as normal,
|
|
orthogonal as orthogonal,
|
|
pareto as pareto,
|
|
permutation as permutation,
|
|
poisson as poisson,
|
|
rademacher as rademacher,
|
|
randint as randint,
|
|
random_gamma_p as random_gamma_p,
|
|
rayleigh as rayleigh,
|
|
split as split,
|
|
t as t,
|
|
triangular as triangular,
|
|
truncated_normal as truncated_normal,
|
|
uniform as uniform,
|
|
wald as wald,
|
|
weibull_min as weibull_min,
|
|
wrap_key_data as wrap_key_data,
|
|
)
|