diff --git a/README.md b/README.md index 1395ae23a..b001a8cee 100644 --- a/README.md +++ b/README.md @@ -348,7 +348,7 @@ Some standouts: 1. [In-place mutating updates of arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. diff --git a/docs/key-concepts.md b/docs/key-concepts.md index daab2c9fd..91f0c9534 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -189,3 +189,43 @@ tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the le in a tree. You can learn more in the {ref}`working-with-pytrees` tutorial. + +(key-concepts-prngs)= +## Pseudorandom numbers + +Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`: + +```{code-cell} +from jax import random + +key = random.key(43) +print(key) +``` + +The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {func}`jax.random` functions. +Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated. + +```{code-cell} +print(random.normal(key)) +print(random.normal(key)) +``` + +**The rule of thumb is: never reuse keys (unless you want identical outputs).** + +In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: + +```{code-cell} +for i in range(3): + new_key, subkey = random.split(key) + del key # The old key is consumed by split() -- we must never use it again. + + val = random.normal(subkey) + del subkey # The subkey is consumed by normal(). + + print(f"draw {i}: {val}") + key = new_key # new_key is safe to use in the next iteration. +``` + +Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. + +For more on pseudo random numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 92c736957..02077d2a6 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -865,312 +865,9 @@ "id": "MUycRNh6e50W" }, "source": [ - "## 🔪 Random numbers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O8vvaVt3MRG2" - }, - "source": [ - "> _If all scientific papers whose results are in doubt because of bad\n", - "> `rand()`s were to disappear from library shelves, there would be a\n", - "> gap on each shelf about as big as your fist._ - Numerical Recipes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qikt9pPW9L5K" - }, - "source": [ - "### RNGs and state\n", - "You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "rr9FeP41fynt", - "outputId": "df0ceb15-96ec-4a78-e327-c77f7ea3a745" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.2726690048900553\n", - "0.6304191979771206\n", - "0.6933648856441533\n" - ] - } - ], - "source": [ - "print(np.random.random())\n", - "print(np.random.random())\n", - "print(np.random.random())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ORMVVGZJgSVi" - }, - "source": [ - "Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "7Pyp2ajzfPO2" - }, - "outputs": [], - "source": [ - "np.random.seed(0)\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n", - "# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n", - "# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aJIxHVXCiM6m" - }, - "source": [ - "This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "GAHaDCYafpAF" - }, - "outputs": [], - "source": [ - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n", + "## 🔪 Random numbers\n", "\n", - "# Let's exhaust the entropy in this PRNG statevector\n", - "for i in range(311):\n", - " _ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n", - "\n", - "# Next call iterates the RNG state for a new batch of fake \"entropy\".\n", - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n", - "# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N_mWnleNogps" - }, - "source": [ - "The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n", - "\n", - "The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Uvq7nV-j4vKK" - }, - "source": [ - "### JAX PRNG" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "COjzGBpO4tzL" - }, - "source": [ - "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", - "\n", - "The random state is described by a special array element that we call a __key__:" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "id": "yPHE7KTWgAWs", - "outputId": "ae8af0ee-f19e-474e-81b6-45e894eb2fc3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([0, 0], dtype=uint32)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = random.key(0)\n", - "key" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XjYyWYNfq0hW" - }, - "source": [ - "JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n", - "\n", - "Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "7zUdQMynoE5e", - "outputId": "23a86b72-dfb9-410a-8e68-22b48dc10805" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.20584226]\n", - "[0 0]\n", - "[-0.20584226]\n", - "[0 0]\n" - ] - } - ], - "source": [ - "print(random.normal(key, shape=(1,)))\n", - "print(key)\n", - "# No no no!\n", - "print(random.normal(key, shape=(1,)))\n", - "print(key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hQN9van8rJgd" - }, - "source": [ - "Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "ASj0_rSzqgGh", - "outputId": "2f13f249-85d1-47bb-d503-823eca6961aa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [0 0]\n", - " \\---SPLIT --> new key [4146024105 967050713]\n", - " \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tqtFVE4MthO3" - }, - "source": [ - "We propagate the __key__ and make new __subkeys__ whenever we need a new random number:" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "jbC34XLor2Ek", - "outputId": "4059a2e2-0205-40bc-ad55-17709d538871" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [4146024105 967050713]\n", - " \\---SPLIT --> new key [2384771982 3928867769]\n", - " \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0KLYUluz3lN3" - }, - "source": [ - "We can generate more than one __subkey__ at a time:" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "id": "lEi08PJ4tfkX", - "outputId": "1f280560-155d-4c04-98e8-c41d72ee5b01" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.37533438]\n", - "[0.98645043]\n", - "[0.14553197]\n" - ] - } - ], - "source": [ - "key, *subkeys = random.split(key, 4)\n", - "for subkey in subkeys:\n", - " print(random.normal(subkey, shape=(1,)))" + "JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial." ] }, { diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 00955de23..f35c5ead1 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -384,153 +384,7 @@ jnp.sum(jnp.array(x)) ## 🔪 Random numbers -+++ {"id": "O8vvaVt3MRG2"} - -> _If all scientific papers whose results are in doubt because of bad -> `rand()`s were to disappear from library shelves, there would be a -> gap on each shelf about as big as your fist._ - Numerical Recipes - -+++ {"id": "Qikt9pPW9L5K"} - -### RNGs and state -You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness: - -```{code-cell} ipython3 -:id: rr9FeP41fynt -:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 - -print(np.random.random()) -print(np.random.random()) -print(np.random.random()) -``` - -+++ {"id": "ORMVVGZJgSVi"} - -Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up. - -```{code-cell} ipython3 -:id: 7Pyp2ajzfPO2 - -np.random.seed(0) -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044, -# 2481403966, 4042607538, 337614300, ... 614 more numbers..., -# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0) -``` - -+++ {"id": "aJIxHVXCiM6m"} - -This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne twister state vector: - -```{code-cell} ipython3 -:id: GAHaDCYafpAF - -_ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0) - -# Let's exhaust the entropy in this PRNG statevector -for i in range(311): - _ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0) - -# Next call iterates the RNG state for a new batch of fake "entropy". -_ = np.random.uniform() -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([1499117434, 2949980591, 2242547484, -# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0) -``` - -+++ {"id": "N_mWnleNogps"} - -The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user. - -The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow. - -+++ {"id": "Uvq7nV-j4vKK"} - -### JAX PRNG - -+++ {"id": "COjzGBpO4tzL"} - -JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. - -The random state is described by a special array element that we call a __key__: - -```{code-cell} ipython3 -:id: yPHE7KTWgAWs -:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 - -key = random.key(0) -key -``` - -+++ {"id": "XjYyWYNfq0hW"} - -JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! - -Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__: - -```{code-cell} ipython3 -:id: 7zUdQMynoE5e -:outputId: 23a86b72-dfb9-410a-8e68-22b48dc10805 - -print(random.normal(key, shape=(1,))) -print(key) -# No no no! -print(random.normal(key, shape=(1,))) -print(key) -``` - -+++ {"id": "hQN9van8rJgd"} - -Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number: - -```{code-cell} ipython3 -:id: ASj0_rSzqgGh -:outputId: 2f13f249-85d1-47bb-d503-823eca6961aa - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "tqtFVE4MthO3"} - -We propagate the __key__ and make new __subkeys__ whenever we need a new random number: - -```{code-cell} ipython3 -:id: jbC34XLor2Ek -:outputId: 4059a2e2-0205-40bc-ad55-17709d538871 - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "0KLYUluz3lN3"} - -We can generate more than one __subkey__ at a time: - -```{code-cell} ipython3 -:id: lEi08PJ4tfkX -:outputId: 1f280560-155d-4c04-98e8-c41d72ee5b01 - -key, *subkeys = random.split(key, 4) -for subkey in subkeys: - print(random.normal(subkey, shape=(1,))) -``` +JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial. +++ {"id": "rg4CpMZ8c3ri"} diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 2ad1eadb0..00f77e347 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -17,6 +17,10 @@ kernelspec: +> _If all scientific papers whose results are in doubt because of bad +> `rand()`s were to disappear from library shelves, there would be a +> gap on each shelf about as big as your fist._ - Numerical Recipes + In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. @@ -35,6 +39,19 @@ import numpy as np np.random.seed(0) ``` +Repeated calls to NumPy's stateful pseudorandom number generators (PRNGs) mutate the global state and give a stream of pseudorandom numbers: + +```{code-cell} +:id: rr9FeP41fynt +:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 + +print(np.random.random()) +print(np.random.random()) +print(np.random.random()) +``` + +Underneath the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up. + You can inspect the content of the state using the following command. ```{code-cell} @@ -109,7 +126,7 @@ Further, when executing in multi-device environments, execution efficiency would ### Explicit random state -To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: +To avoid these issues, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: ```{code-cell} from jax import random @@ -137,6 +154,7 @@ Re-using the same key, even with different {mod}`~jax.random` APIs, can result i **The rule of thumb is: never reuse keys (unless you want identical outputs).** +JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: ```{code-cell} diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 590f68ac0..6540fd1f5 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -677,7 +677,7 @@ class KeyReuseError(JAXTypeError): KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 This sort of key reuse is problematic because the JAX PRNG is stateless, and keys - must be manually split; For more information on this see `Sharp Bits: Random Numbers - `_. + must be manually split; For more information on this see `the Pseudorandom Numbers + tutorial `_. """ pass