mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
JAX tutorials: pseudorandom numbers
This commit is contained in:
parent
b48254e084
commit
cacadf43c0
@ -1,3 +1,4 @@
|
||||
(jep-9263)=
|
||||
# JEP 9263: Typed keys & pluggable RNGs
|
||||
|
||||
*Jake VanderPlas, Roy Frostig*
|
||||
|
@ -1,8 +1,195 @@
|
||||
---
|
||||
jupytext:
|
||||
formats: md:myst
|
||||
text_representation:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.15.2
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
name: python3
|
||||
---
|
||||
|
||||
(pseudorandom-numbers)=
|
||||
# Pseudorandom numbers
|
||||
|
||||
```{note}
|
||||
This is a placeholder for a section in the new {ref}`jax-tutorials`.
|
||||
In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.
|
||||
|
||||
For the time being, you may find some related content in the old documentation:
|
||||
- {doc}`../jax-101/05-random-numbers`
|
||||
```
|
||||
PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.
|
||||
|
||||
Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.
|
||||
|
||||
To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section.
|
||||
|
||||
## Random numbers in NumPy
|
||||
|
||||
Pseudo random number generation is natively supported in NumPy by the {mod}`numpy.random` module.
|
||||
In NumPy, pseudo random number generation is based on a global `state`, which can be set to a deterministic initial condition using {func}`np.random.seed`.
|
||||
|
||||
```{code-cell}
|
||||
import numpy as np
|
||||
np.random.seed(0)
|
||||
```
|
||||
|
||||
You can inspect the content of the state using the following command.
|
||||
|
||||
```{code-cell}
|
||||
def print_truncated_random_state():
|
||||
"""To avoid spamming the outputs, print only part of the state."""
|
||||
full_random_state = np.random.get_state()
|
||||
print(str(full_random_state)[:460], '...')
|
||||
|
||||
print_truncated_random_state()
|
||||
```
|
||||
|
||||
The `state` is updated by each call to a random function:
|
||||
|
||||
```{code-cell}
|
||||
np.random.seed(0)
|
||||
print_truncated_random_state()
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
_ = np.random.uniform()
|
||||
print_truncated_random_state()
|
||||
```
|
||||
|
||||
NumPy allows you to sample both individual numbers, or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:
|
||||
|
||||
```{code-cell}
|
||||
np.random.seed(0)
|
||||
print(np.random.uniform(size=3))
|
||||
```
|
||||
|
||||
NumPy provides a *sequential equivalent guarantee*, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequences:
|
||||
|
||||
```{code-cell}
|
||||
np.random.seed(0)
|
||||
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
|
||||
|
||||
np.random.seed(0)
|
||||
print("all at once: ", np.random.uniform(size=3))
|
||||
```
|
||||
|
||||
## Random numbers in JAX
|
||||
|
||||
JAX's random number generation differs from NumPy's in important ways, because NumPy's
|
||||
PRNG design makes it hard to simultaneously guarantee a number of desirable properties.
|
||||
Specifically, in JAX we want PRNG generation to be:
|
||||
|
||||
1. reproducible,
|
||||
2. parallelizable,
|
||||
3. vectorisable.
|
||||
|
||||
We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:
|
||||
|
||||
```{code-cell}
|
||||
import numpy as np
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
def bar(): return np.random.uniform()
|
||||
def baz(): return np.random.uniform()
|
||||
|
||||
def foo(): return bar() + 2 * baz()
|
||||
|
||||
print(foo())
|
||||
```
|
||||
|
||||
The function `foo` sums two scalars sampled from a uniform distribution.
|
||||
|
||||
The output of this code can only satisfy requirement #1 if we assume a predictable order of execution for `bar()` and `baz()`.
|
||||
This is not a problem in NumPy, which always evaluates code in the order defined by the Python interpreter.
|
||||
In JAX, however, this is more problematic: for efficient execution, we want the JIT compiler to be free to reorder, elide, and fuse various operations in the function we define.
|
||||
Further, when executing in multi-device environments, execution efficiency would be hampered by the need for each process to synchronize a global state.
|
||||
|
||||
### Explicit random state
|
||||
|
||||
To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`:
|
||||
|
||||
```{code-cell}
|
||||
from jax import random
|
||||
|
||||
key = random.key(42)
|
||||
print(key)
|
||||
```
|
||||
|
||||
```{note}
|
||||
This section uses the new-style typed PRNG keys produced by {func}`jax.random.key`, rather than the
|
||||
old-style raw PRNG keys produced by {func}`jax.random.PRNGKey`. For details, see {ref}`jep-9263`.
|
||||
```
|
||||
|
||||
A key is an array with a special dtype corresponding to the particular PRNG implementation being used; in the default implementation each key is backed by a pair of `uint32` values.
|
||||
|
||||
The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {func}`jax.random` functions.
|
||||
Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.
|
||||
|
||||
```{code-cell}
|
||||
print(random.normal(key))
|
||||
print(random.normal(key))
|
||||
```
|
||||
|
||||
Re-using the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable.
|
||||
|
||||
**The rule of thumb is: never reuse keys (unless you want identical outputs).**
|
||||
|
||||
In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function:
|
||||
|
||||
```{code-cell}
|
||||
for i in range(3):
|
||||
new_key, subkey = random.split(key)
|
||||
del key # The old key is consumed by split() -- we must never use it again.
|
||||
|
||||
val = random.normal(subkey)
|
||||
del subkey # The subkey is consumed by normal().
|
||||
|
||||
print(f"draw {i}: {val}")
|
||||
key = new_key # new_key is safe to use in the next iteration.
|
||||
```
|
||||
|
||||
(Calling `del` here is not required, but we do so to emphasize that the key should not be reused once consumed.)
|
||||
|
||||
{func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys.
|
||||
We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.
|
||||
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on: the crucial point is that you never use the same key twice.
|
||||
|
||||
It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`.
|
||||
They are all independent keys with equal status.
|
||||
The key/subkey naming convention is a typical usage pattern that helps track how keys are consumed:
|
||||
subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.
|
||||
|
||||
Usually, the above example would be written concisely as
|
||||
|
||||
```{code-cell}
|
||||
key, subkey = random.split(key)
|
||||
```
|
||||
|
||||
which discards the old key automatically.
|
||||
It's worth noting that {func}`~jax.random.split` can create as many keys as you need, not just 2:
|
||||
|
||||
```{code-cell}
|
||||
key, *forty_two_subkeys = random.split(key, num=43)
|
||||
```
|
||||
|
||||
### Lack of sequential equivalence
|
||||
|
||||
Another difference between NumPy's and JAX's random modules relates to the sequential equivalence guarantee mentioned above.
|
||||
|
||||
As in NumPy, JAX's random module also allows sampling of vectors of numbers.
|
||||
However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).
|
||||
|
||||
In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result to using giving a single key and specifying `shape=(3,)`:
|
||||
|
||||
```{code-cell}
|
||||
key = random.key(42)
|
||||
subkeys = random.split(key, 3)
|
||||
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
|
||||
print("individually:", sequence)
|
||||
|
||||
key = random.key(42)
|
||||
print("all at once: ", random.normal(key, shape=(3,)))
|
||||
```
|
||||
|
||||
Note that contrary to our recommendation above, we use `key` directly as an input to {func}`random.normal` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle.
|
||||
|
Loading…
x
Reference in New Issue
Block a user