# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for pseudo-random number generation.

The ``jax.random`` package provides a number of routines for deterministic
generation of sequences of pseudorandom numbers.

Basic usage
-----------

>>> import jax
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.PRNGKey(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))  # doctest: +SKIP
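
The loop above assumes that ``compiled_update``, ``params`` and ``batches`` are
defined elsewhere. The following is a minimal self-contained sketch of the same
key-threading pattern. It assumes a hypothetical ``compiled_update`` that takes one
noisy gradient step on the toy loss ``0.5 * sum((params - batch) ** 2)``, and a
hypothetical fixed ``batch`` in place of ``next(batches)``::

    import jax
    import jax.numpy as jnp

    # Hypothetical stand-in for compiled_update (not part of jax.random):
    # one gradient step on 0.5 * sum((params - batch) ** 2), with Gaussian
    # noise drawn from subkey added to the gradient.
    @jax.jit
    def compiled_update(subkey, params, batch):
      noise = jax.random.normal(subkey, params.shape)
      grad = (params - batch) + 0.1 * noise
      return params - 0.1 * grad

    def train(seed, num_steps=100):
      key = jax.random.PRNGKey(seed)
      params = jnp.zeros(3)
      batch = jnp.ones(3)  # hypothetical fixed batch, reused every step
      for _ in range(num_steps):
        # Split before each use so that no key is ever consumed twice.
        key, subkey = jax.random.split(key)
        params = compiled_update(subkey, params, batch)
      return params

    params_a = train(1701)
    params_b = train(1701)  # same seed as params_a

All randomness flows from ``seed`` through explicit keys. As a result, rerunning
``train`` with the same seed on the same backend reproduces the same parameters.
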

PRNG Keys
---------

Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
be passed as their first argument.
The random state is described by two unsigned 32-bit integers that we call a **key**,
usually generated by the :py:func:`jax.random.PRNGKey` function::

    >>> from jax import random
    >>> key = random.PRNGKey(0)
    >>> key
    DeviceArray([0, 0], dtype=uint32)

This key can then be used in any of JAX's random number generation routines::

    >>> random.uniform(key)
    DeviceArray(0.41845703, dtype=float32)

Note that using a key does not modify it, so reusing the same key will lead to the same result::

    >>> random.uniform(key)
    DeviceArray(0.41845703, dtype=float32)

If you need a new random number, you can use :func:`jax.random.split` to generate new subkeys::

    >>> key, subkey = random.split(key)
    >>> random.uniform(subkey)
    DeviceArray(0.10536897, dtype=float32)

Design and Context
------------------

Among other requirements, the JAX PRNG aims to:

(a) ensure reproducibility,
(b) parallelize well, both in terms of vectorization (generating array values)
    and multi-replica, multi-core computation. In particular, it should not impose
    sequencing constraints between random function calls.
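
The following rough sketch shows what these two goals can look like in user code.
It is an illustration built from the public :func:`jax.random.split` and
:func:`jax.random.fold_in` functions, not a description of how JAX itself assigns
keys to devices. The "replicas" here are hypothetical and simply numbered 0 to 3.
The point is that keys for parallel work can be derived without any sequencing
between the consumers::

    import jax
    from jax import random

    key = random.PRNGKey(0)

    # Vectorization: one split() call yields an (8, 2) array of raw uint32
    # keys, one per example, and vmap draws from all of them in one batched call.
    keys = random.split(key, 8)
    samples = jax.vmap(lambda k: random.normal(k, (3,)))(keys)

    # Multi-replica: each hypothetical replica folds its own index into the
    # shared key, so no replica needs a key handed over by another one.
    replica_keys = [random.fold_in(key, i) for i in range(4)]
    replica_draws = [random.uniform(k) for k in replica_keys]

Both derivations are deterministic. The same ``key`` and the same index always
produce the same subkey, which is what keeps parallel results reproducible.
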

The approach is based on:

1. "Parallel random numbers: as easy as 1, 2, 3" (Salmon et al. 2011)
2. "Splittable pseudorandom number generators using cryptographic hashing"
   (Claessen et al. 2013)

See also https://github.com/google/jax/blob/main/design_notes/prng.md
for the design and its motivation.
"""
# flake8: noqa: F401
from jax._src.random import (
  PRNGKey,
  bernoulli,
  beta,
  categorical,
  cauchy,
  choice,
  dirichlet,
  double_sided_maxwell,
  exponential,
  fold_in,
  gamma,
  gumbel,
  laplace,
  logistic,
  maxwell,
  multivariate_normal,
  normal,
  pareto,
  permutation,
  poisson,
  rademacher,
  randint,
  random_gamma_p,
  shuffle,
  split,
  t,
  threefry2x32_p,
  threefry_2x32,
  truncated_normal,
  uniform,
  weibull_min,
)