Consolidate material on PRNGs and add a short summary to Key Concepts.

Emily Fertig 2024-11-15 10:30:13 -08:00
parent d8085008b7
commit 225a2a5f8b
6 changed files with 65 additions and 456 deletions

@ -348,7 +348,7 @@ Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
they're in the `jax.lax` package.

@ -189,3 +189,43 @@ tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the le
in a tree.
You can learn more in the {ref}`working-with-pytrees` tutorial.
(key-concepts-prngs)=
## Pseudorandom numbers
Generally, JAX strives to be compatible with NumPy, but pseudorandom number generation is a notable exception. NumPy supports a method of pseudorandom number generation that is based on a global `state`, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`:
```{code-cell}
from jax import random
key = random.key(43)
print(key)
```
The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {mod}`jax.random` functions.
Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.
```{code-cell}
print(random.normal(key))
print(random.normal(key))
```
**The rule of thumb is: never reuse keys (unless you want identical outputs).**
In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function:
```{code-cell}
for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.
  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().
  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
```
Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys.
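As a minimal sketch of the determinism described above, the snippet below splits the same parent key twice and compares the resulting child keys through their raw key data. The variable names (`first`, `second`, `same_children`, `draws`) are illustrative only, and the seed `0` is an arbitrary choice.

```python
import jax.numpy as jnp
from jax import random

key = random.key(0)  # arbitrary seed, chosen for illustration

# split() is deterministic: the same parent key yields the same children.
first = random.split(key, 3)
second = random.split(key, 3)
same_children = bool(
    jnp.all(random.key_data(first) == random.key_data(second))
)

# Each child key drives its own draw.
draws = jnp.stack([random.normal(k) for k in first])

print(same_children)
print(draws.shape)
```

Splitting once into several keys, as here, is an alternative to splitting repeatedly inside a loop when the number of draws is known up front.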
For more on pseudorandom numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial.

@ -865,312 +865,9 @@
"id": "MUycRNh6e50W"
},
"source": [
"## 🔪 Random numbers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O8vvaVt3MRG2"
},
"source": [
"> _If all scientific papers whose results are in doubt because of bad\n",
"> `rand()`s were to disappear from library shelves, there would be a\n",
"> gap on each shelf about as big as your fist._ - Numerical Recipes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qikt9pPW9L5K"
},
"source": [
"### RNGs and state\n",
"You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "rr9FeP41fynt",
"outputId": "df0ceb15-96ec-4a78-e327-c77f7ea3a745"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.2726690048900553\n",
"0.6304191979771206\n",
"0.6933648856441533\n"
]
}
],
"source": [
"print(np.random.random())\n",
"print(np.random.random())\n",
"print(np.random.random())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ORMVVGZJgSVi"
},
"source": [
"Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "7Pyp2ajzfPO2"
},
"outputs": [],
"source": [
"np.random.seed(0)\n",
"rng_state = np.random.get_state()\n",
"# print(rng_state)\n",
"# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
"# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n",
"# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aJIxHVXCiM6m"
},
"source": [
"This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "GAHaDCYafpAF"
},
"outputs": [],
"source": [
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state)\n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
"## 🔪 Random numbers\n",
"\n",
"# Let's exhaust the entropy in this PRNG statevector\n",
"for i in range(311):\n",
" _ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state)\n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
"\n",
"# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"# print(rng_state)\n",
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n",
"# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N_mWnleNogps"
},
"source": [
"The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n",
"\n",
"The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uvq7nV-j4vKK"
},
"source": [
"### JAX PRNG"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "COjzGBpO4tzL"
},
"source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by a special array element that we call a __key__:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "yPHE7KTWgAWs",
"outputId": "ae8af0ee-f19e-474e-81b6-45e894eb2fc3"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([0, 0], dtype=uint32)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"key = random.key(0)\n",
"key"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XjYyWYNfq0hW"
},
"source": [
"JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n",
"\n",
"Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "7zUdQMynoE5e",
"outputId": "23a86b72-dfb9-410a-8e68-22b48dc10805"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.20584226]\n",
"[0 0]\n",
"[-0.20584226]\n",
"[0 0]\n"
]
}
],
"source": [
"print(random.normal(key, shape=(1,)))\n",
"print(key)\n",
"# No no no!\n",
"print(random.normal(key, shape=(1,)))\n",
"print(key)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hQN9van8rJgd"
},
"source": [
"Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "ASj0_rSzqgGh",
"outputId": "2f13f249-85d1-47bb-d503-823eca6961aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old key [0 0]\n",
" \\---SPLIT --> new key [4146024105 967050713]\n",
" \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n"
]
}
],
"source": [
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tqtFVE4MthO3"
},
"source": [
"We propagate the __key__ and make new __subkeys__ whenever we need a new random number:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "jbC34XLor2Ek",
"outputId": "4059a2e2-0205-40bc-ad55-17709d538871"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old key [4146024105 967050713]\n",
" \\---SPLIT --> new key [2384771982 3928867769]\n",
" \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n"
]
}
],
"source": [
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0KLYUluz3lN3"
},
"source": [
"We can generate more than one __subkey__ at a time:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "lEi08PJ4tfkX",
"outputId": "1f280560-155d-4c04-98e8-c41d72ee5b01"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.37533438]\n",
"[0.98645043]\n",
"[0.14553197]\n"
]
}
],
"source": [
"key, *subkeys = random.split(key, 4)\n",
"for subkey in subkeys:\n",
" print(random.normal(subkey, shape=(1,)))"
"JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial."
]
},
{

@ -384,153 +384,7 @@ jnp.sum(jnp.array(x))
## 🔪 Random numbers
+++ {"id": "O8vvaVt3MRG2"}
> _If all scientific papers whose results are in doubt because of bad
> `rand()`s were to disappear from library shelves, there would be a
> gap on each shelf about as big as your fist._ - Numerical Recipes
+++ {"id": "Qikt9pPW9L5K"}
### RNGs and state
You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
```{code-cell} ipython3
:id: rr9FeP41fynt
:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745
print(np.random.random())
print(np.random.random())
print(np.random.random())
```
+++ {"id": "ORMVVGZJgSVi"}
Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up.
```{code-cell} ipython3
:id: 7Pyp2ajzfPO2
np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
```
+++ {"id": "aJIxHVXCiM6m"}
This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne twister state vector:
```{code-cell} ipython3
:id: GAHaDCYafpAF
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
```
+++ {"id": "N_mWnleNogps"}
The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.
The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems: it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow.
+++ {"id": "Uvq7nV-j4vKK"}
### JAX PRNG
+++ {"id": "COjzGBpO4tzL"}
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by a special array element that we call a __key__:
```{code-cell} ipython3
:id: yPHE7KTWgAWs
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
key = random.key(0)
key
```
+++ {"id": "XjYyWYNfq0hW"}
JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!
Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:
```{code-cell} ipython3
:id: 7zUdQMynoE5e
:outputId: 23a86b72-dfb9-410a-8e68-22b48dc10805
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
```
+++ {"id": "hQN9van8rJgd"}
Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:
```{code-cell} ipython3
:id: ASj0_rSzqgGh
:outputId: 2f13f249-85d1-47bb-d503-823eca6961aa
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
+++ {"id": "tqtFVE4MthO3"}
We propagate the __key__ and make new __subkeys__ whenever we need a new random number:
```{code-cell} ipython3
:id: jbC34XLor2Ek
:outputId: 4059a2e2-0205-40bc-ad55-17709d538871
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
+++ {"id": "0KLYUluz3lN3"}
We can generate more than one __subkey__ at a time:
```{code-cell} ipython3
:id: lEi08PJ4tfkX
:outputId: 1f280560-155d-4c04-98e8-c41d72ee5b01
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
```
JAX's pseudo-random number generation differs from NumPy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial.
+++ {"id": "rg4CpMZ8c3ri"}

@ -17,6 +17,10 @@ kernelspec:
<!--* freshness: { reviewed: '2024-05-03' } *-->
> _If all scientific papers whose results are in doubt because of bad
> `rand()`s were to disappear from library shelves, there would be a
> gap on each shelf about as big as your fist._ - Numerical Recipes
In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.
PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.
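To make the role of the `seed` concrete, here is a small sketch using NumPy's global generator: resetting the seed to the same value replays the same sequence. The seed value `0` and the variable names are arbitrary choices for illustration.

```python
import numpy as np

# Seeding fixes the starting state, so the sequence that follows is fixed too.
np.random.seed(0)
first_run = np.random.random(3)

# Resetting to the same seed replays exactly the same numbers.
np.random.seed(0)
second_run = np.random.random(3)

print(np.array_equal(first_run, second_run))
```

Note that this relies on mutating hidden global state between calls, which is the pattern JAX's explicit keys are designed to avoid.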
@ -35,6 +39,19 @@ import numpy as np
np.random.seed(0)
```
Repeated calls to NumPy's stateful pseudorandom number generators (PRNGs) mutate the global state and give a stream of pseudorandom numbers:
```{code-cell}
:id: rr9FeP41fynt
:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745
print(np.random.random())
print(np.random.random())
print(np.random.random())
```
Underneath the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up.
You can inspect the content of the state using the following command.
```{code-cell}
@ -109,7 +126,7 @@ Further, when executing in multi-device environments, execution efficiency would
### Explicit random state
To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`:
To avoid these issues, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`:
```{code-cell}
from jax import random
@ -137,6 +154,7 @@ Re-using the same key, even with different {mod}`~jax.random` APIs, can result i
**The rule of thumb is: never reuse keys (unless you want identical outputs).**
JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.
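As a rough sketch of what forking the state for parallel generation can look like, the snippet below splits one key into a batch of keys and maps a sampler over them with `jax.vmap`. The names `root`, `stream_keys`, and `sample_stream` are hypothetical helpers for this example, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax import random

root = random.key(0)  # arbitrary seed, chosen for illustration

# Fork the root key into four keys, one per parallel stream.
stream_keys = random.split(root, 4)

def sample_stream(k):
    # Each stream draws three values from its own key.
    return random.normal(k, shape=(3,))

# vmap maps the sampler over the leading axis of the key array.
samples = jax.vmap(sample_stream)(stream_keys)

# Re-deriving the keys from the same root reproduces the same samples.
again = jax.vmap(sample_stream)(random.split(root, 4))

print(samples.shape)
print(bool(jnp.array_equal(samples, again)))
```

Because every stream depends only on its own key, the result does not depend on the order in which the streams are evaluated.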
In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function:
```{code-cell}

@ -677,7 +677,7 @@ class KeyReuseError(JAXTypeError):
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
This sort of key reuse is problematic because the JAX PRNG is stateless, and keys
must be manually split; for more information on this see `Sharp Bits: Random Numbers
<https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers>`_.
must be manually split; for more information on this see `the Pseudorandom Numbers
tutorial <https://jax.readthedocs.io/en/latest/random-numbers.html>`_.
"""
pass