Merge pull request #17741 from froystig:new-style-key-docs

PiperOrigin-RevId: 614080080
jax authors 2024-03-08 16:41:22 -08:00
commit 0302e4c34d
47 changed files with 168 additions and 128 deletions

View File

@ -304,7 +304,7 @@ per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
jax.random.PRNGKey(0),
jax.random.key(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.bfloat16,
)
@ -1049,7 +1049,7 @@ per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
jax.random.PRNGKey(0),
jax.random.key(0),
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
dtype=jnp.bfloat16,
)

View File

@ -9,7 +9,7 @@ program:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3. # doctest: +SKIP

View File

@ -59,7 +59,7 @@ def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
@ -107,14 +107,14 @@ import jax.numpy as jnp
import jax.profiler
def afunction():
return jax.random.normal(jax.random.PRNGKey(77), (1000000,))
return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")

View File

@ -282,7 +282,7 @@
"source": [
"from jax import random\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"\n",
"print(key)"
]
@ -293,7 +293,7 @@
"id": "XhFpKnW9F2nF"
},
"source": [
"A key is just an array of shape `(2,)`.\n",
"A single key is an array of scalar shape `()` and key element type.\n",
"\n",
"'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:"
]
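
For context on the hunk above, a minimal sketch of the new-style typed keys it documents (behavior as of JAX v0.4.16+):

```python
import jax

# New-style keys are scalar arrays with a key element type.
key = jax.random.key(42)
print(key.shape)  # ()
print(key.dtype)  # key<fry>

# Legacy keys are uint32 arrays of shape (2,) exposing the raw bits.
legacy = jax.random.PRNGKey(42)
print(legacy.shape, legacy.dtype)  # (2,) uint32
```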
@ -381,7 +381,7 @@
"source": [
"`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n",
"\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"\n",
"It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n",
"\n",
@ -460,12 +460,12 @@
}
],
"source": [
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"subkeys = random.split(key, 3)\n",
"sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n",
"print(\"individually:\", sequence)\n",
"\n",
"key = random.PRNGKey(42)\n",
"key = random.key(42)\n",
"print(\"all at once: \", random.normal(key, shape=(3,)))"
]
},
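
The `split()` discussion in this file pairs well with a short sketch of the carry-and-discard pattern it describes; the `key`/`subkey` names follow the convention in the text:

```python
import jax

key = jax.random.key(0)

# Split, keep one output as the carried key, consume the other.
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey)  # subkey is now spent

# For another sample, split the carried key again; never reuse a key.
key, subkey = jax.random.split(key)
another = jax.random.normal(subkey)
```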

View File

@ -150,14 +150,14 @@ To avoid this issue, JAX does not use a global state. Instead, random functions
from jax import random
key = random.PRNGKey(42)
key = random.key(42)
print(key)
```
+++ {"id": "XhFpKnW9F2nF"}
A key is just an array of shape `(2,)`.
A single key is an array of scalar shape `()` and key element type.
'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:
@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key.
`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.
@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall
:id: 4nB_TA54D-HT
:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56
key = random.PRNGKey(42)
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.PRNGKey(42)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
```

View File

@ -623,7 +623,7 @@
"ys = xs * true_w + true_b + noise\n",
"\n",
"# Initialise parameters and replicate across devices.\n",
"params = init(jax.random.PRNGKey(123))\n",
"params = init(jax.random.key(123))\n",
"n_devices = jax.local_device_count()\n",
"replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)"
]

View File

@ -291,7 +291,7 @@ noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise
# Initialise parameters and replicate across devices.
params = init(jax.random.PRNGKey(123))
params = init(jax.random.key(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
```

View File

@ -249,7 +249,7 @@
"\n",
"In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n",
"\n",
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey."
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key."
]
},
{
@ -351,7 +351,7 @@
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"rng = jax.random.PRNGKey(42)\n",
"rng = jax.random.key(42)\n",
"\n",
"# Generate true data from y = w*x + b + noise\n",
"true_w, true_b = 2, -1\n",

View File

@ -166,7 +166,7 @@ Notice that the need for a class becomes less clear once we have rewritten it th
In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.
+++ {"id": "I2SqRx14_z98"}
@ -233,7 +233,7 @@ Notice that we manually pipe the params in and out of the update function.
import matplotlib.pyplot as plt
rng = jax.random.PRNGKey(42)
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1

View File

@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
An initializer is a function that takes three arguments:
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
key used when generating random numbers to initialize the array.
data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
:func:`jax.random.key`), used to generate random numbers to initialize the array.
.. autosummary::
:toctree: _autosummary
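
To illustrate the `(key, shape, dtype)` convention documented in this hunk, a minimal custom initializer; `small_uniform` is a hypothetical example, not part of `jax.nn.initializers`:

```python
import jax
import jax.numpy as jnp

def small_uniform(key, shape, dtype=jnp.float32):
    # Any function with this signature can be used as an initializer.
    return jax.random.uniform(key, shape, dtype, minval=-0.01, maxval=0.01)

weights = small_uniform(jax.random.key(0), (2, 3))
```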

View File

@ -1006,7 +1006,7 @@
"source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by two unsigned-int32s that we call a __key__:"
"The random state is described by a special array element that we call a __key__:"
]
},
{
@ -1030,7 +1030,7 @@
],
"source": [
"from jax import random\n",
"key = random.PRNGKey(0)\n",
"key = random.key(0)\n",
"key"
]
},
@ -2121,7 +2121,7 @@
}
],
"source": [
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype"
]
},
@ -2188,7 +2188,7 @@
"source": [
"import jax.numpy as jnp\n",
"from jax import random\n",
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype # --> dtype('float64')"
]
},
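
A minimal sketch of the "splittable" property described in this file: one key is forked into independent per-stream keys, here consumed in parallel with `jax.vmap`:

```python
import jax

key = jax.random.key(0)

# Fork the PRNG state into eight independent keys...
keys = jax.random.split(key, 8)

# ...and generate from each stream in parallel.
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
print(samples.shape)  # (8, 3)
```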

View File

@ -463,14 +463,14 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by two unsigned-int32s that we call a __key__:
The random state is described by a special array element that we call a __key__:
```{code-cell} ipython3
:id: yPHE7KTWgAWs
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
from jax import random
key = random.PRNGKey(0)
key = random.key(0)
key
```
@ -1071,7 +1071,7 @@ At the moment, JAX by default enforces single-precision numbers to mitigate the
:id: CNNGtzM3NDkO
:outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
```
@ -1117,7 +1117,7 @@ We can then confirm that `x64` mode is enabled:
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
```

View File

@ -131,7 +131,7 @@
],
"source": [
"# Create an array of random values:\n",
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
"# and use jax.device_put to distribute it across devices:\n",
"y = jax.device_put(x, sharding.reshape(4, 2))\n",
"jax.debug.visualize_array_sharding(y)"
@ -272,7 +272,7 @@
"outputs": [],
"source": [
"import jax\n",
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))"
"x = jax.random.normal(jax.random.key(0), (8192, 8192))"
]
},
{
@ -1513,7 +1513,7 @@
},
"outputs": [],
"source": [
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
"x = jax.device_put(x, sharding.reshape(4, 2))"
]
},
@ -1738,7 +1738,7 @@
"layer_sizes = [784, 8192, 8192, 8192, 10]\n",
"batch_size = 8192\n",
"\n",
"params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)"
"params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)"
]
},
{
@ -2184,7 +2184,7 @@
" numbers = jax.random.uniform(key, x.shape)\n",
" return x + numbers\n",
"\n",
"key = jax.random.PRNGKey(42)\n",
"key = jax.random.key(42)\n",
"x_sharding = jax.sharding.PositionalSharding(jax.devices())\n",
"x = jax.device_put(jnp.arange(24), x_sharding)"
]

View File

@ -81,7 +81,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b
# Create an array of random values:
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y)
@ -144,7 +144,7 @@ For example, here's a value with a single-device `Sharding`:
:id: VmoX4SUp3vGJ
import jax
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
```
```{code-cell}
@ -609,7 +609,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```{code-cell}
:id: Q1wuDp-L3vGT
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))
```
@ -720,7 +720,7 @@ def init_model(key, layer_sizes, batch_size):
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
```
+++ {"id": "sJv_h0AS2drh"}
@ -902,7 +902,7 @@ def f(key, x):
numbers = jax.random.uniform(key, x.shape)
return x + numbers
key = jax.random.PRNGKey(42)
key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding)
```

View File

@ -84,7 +84,7 @@
"num_epochs = 8\n",
"batch_size = 128\n",
"n_targets = 10\n",
"params = init_network_params(layer_sizes, random.PRNGKey(0))"
"params = init_network_params(layer_sizes, random.key(0))"
]
},
{
@ -150,7 +150,7 @@
],
"source": [
"# This works on single examples\n",
"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
"preds = predict(params, random_flattened_image)\n",
"print(preds.shape)"
]
@ -173,7 +173,7 @@
],
"source": [
"# Doesn't work with a batch\n",
"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
"try:\n",
" preds = predict(params, random_flattened_images)\n",
"except TypeError:\n",

View File

@ -71,7 +71,7 @@ step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
params = init_network_params(layer_sizes, random.key(0))
```
+++ {"id": "BtoNk_yxWtIw"}
@ -109,7 +109,7 @@ Let's check that our prediction function only works on single images.
:outputId: 9d3b29e8-fab3-4ecb-9f63-bc8c092f9006
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
```
@ -119,7 +119,7 @@ print(preds.shape)
:outputId: d5d20211-b6da-44e9-f71e-946f2a9d0fc4
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:

View File

@ -66,7 +66,7 @@
},
"outputs": [],
"source": [
"x = random.normal(random.PRNGKey(0), (5000, 5000))\n",
"x = random.normal(random.key(0), (5000, 5000))\n",
"def f(w, b, x):\n",
" return jnp.tanh(jnp.dot(x, w) + b)\n",
"fast_f = jit(f)"

View File

@ -48,7 +48,7 @@ JAX provides a NumPy-like API for numerical computing which can be used as is, b
```{code-cell} ipython3
:id: HmlMcICOcSXR
x = random.normal(random.PRNGKey(0), (5000, 5000))
x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)

View File

@ -27,7 +27,7 @@
"from jax import grad, jit, vmap\n",
"from jax import random\n",
"\n",
"key = random.PRNGKey(0)"
"key = random.key(0)"
]
},
{
@ -1055,7 +1055,7 @@
" outs, = vmap(vjp_fun)(M)\n",
" return outs\n",
"\n",
"key = random.PRNGKey(0)\n",
"key = random.key(0)\n",
"num_covecs = 128\n",
"U = random.normal(key, (num_covecs,) + y.shape)\n",
"\n",
@ -1306,7 +1306,7 @@
"outputs": [],
"source": [
"def check(seed):\n",
" key = random.PRNGKey(seed)\n",
" key = random.key(seed)\n",
"\n",
" # random coeffs for u and v\n",
" key, subkey = random.split(key)\n",
@ -1399,7 +1399,7 @@
"outputs": [],
"source": [
"def check(seed):\n",
" key = random.PRNGKey(seed)\n",
" key = random.key(seed)\n",
"\n",
" # random coeffs for u and v\n",
" key, subkey = random.split(key)\n",

View File

@ -29,7 +29,7 @@ import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
key = random.key(0)
```
+++ {"id": "YxnjtAGN6vu2"}
@ -614,7 +614,7 @@ def vmap_mjp(f, x, M):
outs, = vmap(vjp_fun)(M)
return outs
key = random.PRNGKey(0)
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
@ -770,7 +770,7 @@ Here's a check:
:id: BGZV__zupIMS
def check(seed):
key = random.PRNGKey(seed)
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
@ -833,7 +833,7 @@ Here's a check of the VJP rules:
:id: 4J7edvIBttcU
def check(seed):
key = random.PRNGKey(seed)
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)

View File

@ -60,7 +60,7 @@
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"key = random.PRNGKey(1701)\n",
"key = random.key(1701)\n",
"\n",
"x = jnp.linspace(0, 10, 500)\n",
"y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))\n",
@ -130,7 +130,7 @@
"ax[0].set_title('original')\n",
"\n",
"# Create a noisy version by adding random Gaussian noise\n",
"key = random.PRNGKey(1701)\n",
"key = random.key(1701)\n",
"noisy_image = image + 50 * random.normal(key, image.shape)\n",
"ax[1].imshow(noisy_image, cmap='binary_r')\n",
"ax[1].set_title('noisy')\n",

View File

@ -43,7 +43,7 @@ from jax import random
import jax.numpy as jnp
import numpy as np
key = random.PRNGKey(1701)
key = random.key(1701)
x = jnp.linspace(0, 10, 500)
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))
@ -84,7 +84,7 @@ ax[0].imshow(image, cmap='binary_r')
ax[0].set_title('original')
# Create a noisy version by adding random Gaussian noise
key = random.PRNGKey(1701)
key = random.key(1701)
noisy_image = image + 50 * random.normal(key, image.shape)
ax[1].imshow(noisy_image, cmap='binary_r')
ax[1].set_title('noisy')

View File

@ -97,7 +97,7 @@
"num_epochs = 10\n",
"batch_size = 128\n",
"n_targets = 10\n",
"params = init_network_params(layer_sizes, random.PRNGKey(0))"
"params = init_network_params(layer_sizes, random.key(0))"
]
},
{
@ -163,7 +163,7 @@
],
"source": [
"# This works on single examples\n",
"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
"preds = predict(params, random_flattened_image)\n",
"print(preds.shape)"
]
@ -186,7 +186,7 @@
],
"source": [
"# Doesn't work with a batch\n",
"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
"try:\n",
" preds = predict(params, random_flattened_images)\n",
"except TypeError:\n",

View File

@ -79,7 +79,7 @@ step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
params = init_network_params(layer_sizes, random.key(0))
```
+++ {"id": "BtoNk_yxWtIw"}
@ -117,7 +117,7 @@ Let's check that our prediction function only works on single images.
:outputId: ce9d86ed-a830-4832-e04d-10d1abb1fb8a
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
```
@ -127,7 +127,7 @@ print(preds.shape)
:outputId: f43bbc9d-bc8f-4168-ee7b-79ee9d33f245
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:

View File

@ -81,7 +81,7 @@
}
],
"source": [
"key = random.PRNGKey(0)\n",
"key = random.key(0)\n",
"x = random.normal(key, (10,))\n",
"print(x)"
]

View File

@ -59,7 +59,7 @@ We'll be generating random data in the following examples. One big difference be
:id: u0nseKZNqOoH
:outputId: 03e20e21-376c-41bb-a6bb-57431823691b
key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (10,))
print(x)
```

View File

@ -483,7 +483,7 @@
"\n",
"normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n",
"\n",
"key = random.PRNGKey(10003)\n",
"key = random.key(10003)\n",
"\n",
"beta_loc = jnp.zeros(num_features, jnp.float32)\n",
"beta_log_scale = jnp.zeros(num_features, jnp.float32)\n",

View File

@ -210,7 +210,7 @@ def normal_sample(key, shape):
normal_sample = jax.jit(normal_sample, static_argnums=(1,))
key = random.PRNGKey(10003)
key = random.key(10003)
beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

View File

@ -338,7 +338,7 @@
" lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
" )\n",
" )(x, y)\n",
"k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
"k1, k2 = jax.random.split(jax.random.key(0))\n",
"x = jax.random.normal(k1, (1024, 1024))\n",
"y = jax.random.normal(k2, (1024, 1024))\n",
"z = matmul(x, y)\n",
@ -376,7 +376,7 @@
" lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
" ),\n",
" )(x, y)\n",
"k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
"k1, k2 = jax.random.split(jax.random.key(0))\n",
"x = jax.random.normal(k1, (1024, 1024))\n",
"y = jax.random.normal(k2, (1024, 1024))\n",
"z = matmul(x, y, activation=jax.nn.relu)\n",
@ -397,7 +397,7 @@
"metadata": {},
"outputs": [],
"source": [
"k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
"k1, k2 = jax.random.split(jax.random.key(0))\n",
"x = jax.random.normal(k1, (4, 1024, 1024))\n",
"y = jax.random.normal(k2, (4, 1024, 1024))\n",
"z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)\n",

View File

@ -226,7 +226,7 @@ def matmul(x: jax.Array, y: jax.Array):
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
)
)(x, y)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
@ -253,7 +253,7 @@ def matmul(x: jax.Array, y: jax.Array, *, activation):
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
@ -263,7 +263,7 @@ np.testing.assert_allclose(z, jax.nn.relu(x @ y))
To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! To turn this matrix multiplication into a batched version, we just need to `vmap` it.
```{code-cell} ipython3
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)

View File

@ -11,7 +11,7 @@ check out the Tensorboard profiler below.
```python
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# Run the operations to be profiled
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
@ -107,7 +107,7 @@ import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
@ -126,7 +126,7 @@ alternative to `start_trace` and `stop_trace`:
import jax
with jax.profiler.trace("/tmp/tensorboard"):
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

View File

@ -78,7 +78,7 @@ if __name__ == "__main__":
@jit
def objective(params, t):
rng = random.PRNGKey(t)
rng = random.key(t)
return -batch_elbo(funnel_log_density, rng, params, num_samples)
# Set up figure.
@ -107,7 +107,7 @@ if __name__ == "__main__":
# Plot random samples from variational distribution.
# Here we clone the rng used in computing the objective
# so that we can show exactly the same samples.
rngs = random.split(random.PRNGKey(t), num_samples)
rngs = random.split(random.key(t), num_samples)
samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
ax.plot(samples[:, 0], samples[:, 1], 'b.')
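
The "clone the rng" comment above relies on a basic property of JAX's functional PRNG, sketched here: the same key always yields the same samples.

```python
import jax

key = jax.random.key(1701)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
assert (a == b).all()  # same key, same samples
```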

View File

@ -182,7 +182,7 @@ def main(_):
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
num_batches = num_complete_batches + bool(leftover)
key = random.PRNGKey(_SEED.value)
key = random.key(_SEED.value)
def data_stream():
rng = npr.RandomState(_SEED.value)

View File

@ -35,7 +35,7 @@ config.parse_flags_with_absl()
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
jax_rng = random.PRNGKey(0)
jax_rng = random.key(0)
result_shape, params = init_fun(jax_rng, input_shape)
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
test_case.assertEqual(result.shape, result_shape)

View File

@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
def main(unused_argv):
numpts = 7
key = random.PRNGKey(0)
key = random.key(0)
eye = jnp.eye(numpts)
def cov_map(cov_func, xs, xs2=None):

View File

@ -50,7 +50,7 @@ init_random_params, predict = stax.serial(
Dense(10), LogSoftmax)
if __name__ == "__main__":
rng = random.PRNGKey(0)
rng = random.key(0)
step_size = 0.001
num_epochs = 10

View File

@ -87,14 +87,14 @@ if __name__ == "__main__":
batch_size = 32
nrow, ncol = 10, 10 # sampled image grid size
test_rng = random.PRNGKey(1) # fixed prng key for evaluation
test_rng = random.key(1) # fixed prng key for evaluation
imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")
train_images, _, test_images, _ = datasets.mnist(permute_train=True)
num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
num_batches = num_complete_batches + bool(leftover)
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
enc_init_rng, dec_init_rng = random.split(random.key(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params
@ -131,7 +131,7 @@ if __name__ == "__main__":
opt_state = opt_init(init_params)
for epoch in range(num_epochs):
tic = time.time()
opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
opt_state = run_epoch(random.key(epoch), opt_state, train_images)
test_elbo, sampled_images = evaluate(opt_state, test_images)
print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)

View File

@ -280,7 +280,7 @@ def jit(
... def selu(x, alpha=1.67, lmbda=1.05):
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x)) # doctest: +SKIP
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748

View File

@ -1277,7 +1277,7 @@ def ensure_compile_time_eval():
@jax.jit
def jax_fn(x):
with jax.ensure_compile_time_eval():
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
y = random.randint(random.key(0), (1000,1000), 0, 100)
y2 = y @ y
x2 = jnp.sum(y2) * x
return x2
@ -1285,7 +1285,7 @@ def ensure_compile_time_eval():
A similar behavior can often be achieved simply by 'hoisting' the constant
expression out of the corresponding staging API::
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
y = random.randint(random.key(0), (1000,1000), 0, 100)
@jax.jit
def jax_fn(x):

View File

@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
Example 2: partial products of an array of matrices
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)

View File

@ -62,7 +62,7 @@ def zeros(key: KeyArray,
The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
"""
@ -77,7 +77,7 @@ def ones(key: KeyArray,
The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)
@ -96,7 +96,7 @@ def constant(value: ArrayLike,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.50350785, 0.8088631 , 0.81566876],
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
@ -638,7 +638,7 @@ def delta_orthogonal(
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
Array([[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]],

View File

@ -245,7 +245,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
data: a 32bit integer representing data to be folded in to the key.
Returns:
@ -275,7 +275,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
num: optional, a positive integer (or tuple of integers) indicating
the number (or shape) of keys to produce. Defaults to 2.
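
As a usage note for the two docstrings touched here, a short sketch assuming new-style keys:

```python
import jax

key = jax.random.key(0)

# fold_in: derive a new key from a key plus integer data,
# e.g. a per-step key derived from a loop counter.
step_key = jax.random.fold_in(key, 7)

# split: produce several independent keys from one.
k1, k2 = jax.random.split(key)
```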

View File

@ -44,7 +44,7 @@ net_init, net_apply = stax.serial(
)
# Initialize parameters, not committing to a batch shape
rng = random.PRNGKey(0)
rng = random.key(0)
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(rng, in_shape)

View File

@ -268,7 +268,7 @@ def Dropout(rate, mode='train'):
msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
"argument. That is, instead of `apply_fun(params, inputs)`, call "
"it like `apply_fun(params, inputs, rng)` where `rng` is a "
"jax.random.PRNGKey value.")
"PRNG key (e.g. from `jax.random.key`).")
raise ValueError(msg)
if mode == 'train':
keep = random.bernoulli(rng, rate, inputs.shape)
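
A sketch of the calling convention the error message asks for, assuming the `jax.example_libraries.stax` API:

```python
import jax
from jax.example_libraries import stax

# In stax, Dropout's rate is the keep-rate (see `bernoulli(rng, rate, ...)` above).
init_fun, apply_fun = stax.serial(stax.Dense(4), stax.Dropout(0.5, mode='train'))
_, params = init_fun(jax.random.key(0), (-1, 8))

x = jax.numpy.ones((2, 8))
out = apply_fun(params, x, rng=jax.random.key(1))  # Dropout consumes the key
```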

View File

@ -29,7 +29,7 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None,
"""Generate a random BCOO matrix.
Args:
key : random.PRNGKey to be passed to ``generator`` function.
key : PRNG key to be passed to ``generator`` function.
shape : tuple specifying the shape of the array to be generated.
dtype : dtype of the array to be generated.
indices_dtype: dtype of the BCOO indices.

View File

@ -27,9 +27,9 @@ For example:
>>> from jax import random
>>> from jax.experimental.sparse import BCOO, sparsify
>>> mat = random.uniform(random.PRNGKey(1701), (5, 5))
>>> mat = random.uniform(random.key(1701), (5, 5))
>>> mat = mat.at[mat < 0.5].set(0)
>>> vec = random.uniform(random.PRNGKey(42), (5,))
>>> vec = random.uniform(random.key(42), (5,))
>>> def f(mat, vec):
... return -(jnp.sin(mat) @ vec)

View File

@ -22,24 +22,25 @@ Basic usage
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.PRNGKey(seed)
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
PRNG Keys
PRNG keys
---------
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
be passed as a first argument.
The random state is described by two unsigned 32-bit integers that we call a **key**,
usually generated by the :py:func:`jax.random.PRNGKey` function::
The random state is described by a special array element type that we call a **key**,
usually generated by the :py:func:`jax.random.key` function::
>>> from jax import random
>>> key = random.PRNGKey(0)
>>> key = random.key(0)
>>> key
Array([0, 0], dtype=uint32)
Array((), dtype=key<fry>) overlaying:
[0 0]
This key can then be used in any of JAX's random number generation routines::
@ -57,11 +58,47 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32)
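
An aside on recognizing new-style keys programmatically (not part of the diffed docstring): a key is identified by its element type rather than its shape or bit width.

```python
import jax

key = jax.random.key(0)
print(key.shape)                                              # ()
print(jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key))  # True
```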
.. note::
Typed key arrays, with element types such as ``key<fry>`` above,
were introduced in JAX v0.4.16. Before then, keys were
conventionally represented in ``uint32`` arrays, whose final
dimension represented the key's bit-level representation.
Both forms of key array can still be created and used with the
:mod:`jax.random` module. New-style typed key arrays are made with
:py:func:`jax.random.key`. Legacy ``uint32`` key arrays are made
with :py:func:`jax.random.PRNGKey`.
To convert between the two, use :py:func:`jax.random.key_data` and
:py:func:`jax.random.wrap_key_data`. The legacy key format may be
needed when interfacing with systems outside of JAX (e.g. exporting
arrays to a serializable format), or when passing keys to JAX-based
libraries that assume the legacy format.
Otherwise, typed keys are recommended. Caveats of legacy keys
relative to typed ones include:
* They have an extra trailing dimension.
* They have a numeric dtype (``uint32``), allowing for operations
that are typically not meant to be carried out over keys, such as
integer arithmetic.
* They do not carry information about the RNG implementation. When
legacy keys are passed to :mod:`jax.random` functions, a global
configuration setting determines the RNG implementation (see
"Advanced RNG configuration" below).
To learn more about this upgrade, and the design of key types, see
`JEP 9263
<https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html>`_.
Advanced
--------
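
As a companion to the note above, a minimal round-trip through the conversion functions it names, `jax.random.key_data` and `jax.random.wrap_key_data`:

```python
import jax

key = jax.random.key(0)

# Typed key -> raw key data (uint32, shape (2,) for threefry2x32)...
raw = jax.random.key_data(key)

# ...and back to a typed key; the implementation may be given explicitly.
key2 = jax.random.wrap_key_data(raw, impl='threefry2x32')
```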
Design and Context
==================
Design and background
=====================
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
@ -79,16 +116,19 @@ To summarize, among other requirements, the JAX PRNG aims to:
Advanced RNG configuration
==========================
JAX provides several PRNG implementations (controlled by the
`jax_default_prng_impl` flag).
JAX provides several PRNG implementations. A specific one can be
selected with the optional `impl` keyword argument to
`jax.random.key`. When no `impl` option is passed to the `key`
constructor, the implementation is determined by the global
`jax_default_prng_impl` configuration flag.
- **default**
- **default**, `"threefry2x32"`:
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
- "rbg" uses ThreeFry for splitting, and XLA RBG for data generation.
- "unsafe_rbg" exists only for demonstration purposes, using RBG both for
- `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
splitting (using an untested made up algorithm) and generating.
The random streams generated by these experimental implementations haven't
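
To illustrate the per-key `impl` selection described above, a short sketch:

```python
import jax

# Default implementation (threefry2x32 unless configured otherwise):
key = jax.random.key(0)

# Explicitly select an implementation for this key:
rbg_key = jax.random.key(0, impl='rbg')
print(rbg_key.dtype)  # key<rbg>
```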
@ -127,7 +167,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
less safe in the sense that the quality of random streams it generates from
different keys is less well understood.
For more about jax_threefry_partitionable, see
For more about `jax_threefry_partitionable`, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
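
For reference, `jax_threefry_partitionable` is an ordinary JAX configuration option; a sketch of enabling it (not part of the diff):

```python
import jax

# Opt in to the partitionable Threefry variant discussed above.
jax.config.update('jax_threefry_partitionable', True)
```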
"""