mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #17741 from froystig:new-style-key-docs
PiperOrigin-RevId: 614080080
This commit is contained in:
commit
0302e4c34d
@ -304,7 +304,7 @@ per_core_batch_size=4
|
||||
seq_len=512
|
||||
emb_dim=512
|
||||
x = jax.random.normal(
|
||||
jax.random.PRNGKey(0),
|
||||
jax.random.key(0),
|
||||
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
@ -1049,7 +1049,7 @@ per_core_batch_size=4
|
||||
seq_len=512
|
||||
emb_dim=512
|
||||
x = jax.random.normal(
|
||||
jax.random.PRNGKey(0),
|
||||
jax.random.key(0),
|
||||
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
|
@ -9,7 +9,7 @@ program:
|
||||
>>> import numpy as np
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax import random
|
||||
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
|
||||
>>> x = random.uniform(random.key(0), (1000, 1000))
|
||||
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
|
||||
>>> # will block until the value is ready.
|
||||
>>> jnp.dot(x, x) + 3. # doctest: +SKIP
|
||||
|
@ -59,7 +59,7 @@ def func2(x):
|
||||
y = func1(x)
|
||||
return y, jnp.tile(x, 10) + 1
|
||||
|
||||
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
|
||||
x = jax.random.normal(jax.random.key(42), (1000, 1000))
|
||||
y, z = func2(x)
|
||||
|
||||
z.block_until_ready()
|
||||
@ -107,14 +107,14 @@ import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
|
||||
def afunction():
|
||||
return jax.random.normal(jax.random.PRNGKey(77), (1000000,))
|
||||
return jax.random.normal(jax.random.key(77), (1000000,))
|
||||
|
||||
z = afunction()
|
||||
|
||||
def anotherfunc():
|
||||
arrays = []
|
||||
for i in range(1, 10):
|
||||
x = jax.random.normal(jax.random.PRNGKey(42), (i, 10000))
|
||||
x = jax.random.normal(jax.random.key(42), (i, 10000))
|
||||
arrays.append(x)
|
||||
x.block_until_ready()
|
||||
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
|
||||
|
@ -282,7 +282,7 @@
|
||||
"source": [
|
||||
"from jax import random\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(42)\n",
|
||||
"key = random.key(42)\n",
|
||||
"\n",
|
||||
"print(key)"
|
||||
]
|
||||
@ -293,7 +293,7 @@
|
||||
"id": "XhFpKnW9F2nF"
|
||||
},
|
||||
"source": [
|
||||
"A key is just an array of shape `(2,)`.\n",
|
||||
"A single key is an array of scalar shape `()` and key element type.\n",
|
||||
"\n",
|
||||
"'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:"
|
||||
]
|
||||
@ -381,7 +381,7 @@
|
||||
"source": [
|
||||
"`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n",
|
||||
"\n",
|
||||
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
|
||||
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
|
||||
"\n",
|
||||
"It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n",
|
||||
"\n",
|
||||
@ -460,12 +460,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"key = random.PRNGKey(42)\n",
|
||||
"key = random.key(42)\n",
|
||||
"subkeys = random.split(key, 3)\n",
|
||||
"sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n",
|
||||
"print(\"individually:\", sequence)\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(42)\n",
|
||||
"key = random.key(42)\n",
|
||||
"print(\"all at once: \", random.normal(key, shape=(3,)))"
|
||||
]
|
||||
},
|
||||
|
@ -150,14 +150,14 @@ To avoid this issue, JAX does not use a global state. Instead, random functions
|
||||
|
||||
from jax import random
|
||||
|
||||
key = random.PRNGKey(42)
|
||||
key = random.key(42)
|
||||
|
||||
print(key)
|
||||
```
|
||||
|
||||
+++ {"id": "XhFpKnW9F2nF"}
|
||||
|
||||
A key is just an array of shape `(2,)`.
|
||||
A single key is an array of scalar shape `()` and key element type.
|
||||
|
||||
'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:
|
||||
|
||||
@ -201,7 +201,7 @@ key = new_key # If we wanted to do this again, we would use new_key as the key.
|
||||
|
||||
`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.
|
||||
|
||||
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
|
||||
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
|
||||
|
||||
It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.
|
||||
|
||||
@ -240,12 +240,12 @@ In the example below, sampling 3 values out of a normal distribution individuall
|
||||
:id: 4nB_TA54D-HT
|
||||
:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56
|
||||
|
||||
key = random.PRNGKey(42)
|
||||
key = random.key(42)
|
||||
subkeys = random.split(key, 3)
|
||||
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
|
||||
print("individually:", sequence)
|
||||
|
||||
key = random.PRNGKey(42)
|
||||
key = random.key(42)
|
||||
print("all at once: ", random.normal(key, shape=(3,)))
|
||||
```
|
||||
|
||||
|
@ -623,7 +623,7 @@
|
||||
"ys = xs * true_w + true_b + noise\n",
|
||||
"\n",
|
||||
"# Initialise parameters and replicate across devices.\n",
|
||||
"params = init(jax.random.PRNGKey(123))\n",
|
||||
"params = init(jax.random.key(123))\n",
|
||||
"n_devices = jax.local_device_count()\n",
|
||||
"replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)"
|
||||
]
|
||||
|
@ -291,7 +291,7 @@ noise = 0.5 * np.random.normal(size=(128, 1))
|
||||
ys = xs * true_w + true_b + noise
|
||||
|
||||
# Initialise parameters and replicate across devices.
|
||||
params = init(jax.random.PRNGKey(123))
|
||||
params = init(jax.random.key(123))
|
||||
n_devices = jax.local_device_count()
|
||||
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
|
||||
```
|
||||
|
@ -249,7 +249,7 @@
|
||||
"\n",
|
||||
"In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n",
|
||||
"\n",
|
||||
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey."
|
||||
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -351,7 +351,7 @@
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"rng = jax.random.PRNGKey(42)\n",
|
||||
"rng = jax.random.key(42)\n",
|
||||
"\n",
|
||||
"# Generate true data from y = w*x + b + noise\n",
|
||||
"true_w, true_b = 2, -1\n",
|
||||
|
@ -166,7 +166,7 @@ Notice that the need for a class becomes less clear once we have rewritten it th
|
||||
|
||||
In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?
|
||||
|
||||
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNGKey.
|
||||
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.
|
||||
|
||||
+++ {"id": "I2SqRx14_z98"}
|
||||
|
||||
@ -233,7 +233,7 @@ Notice that we manually pipe the params in and out of the update function.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
rng = jax.random.PRNGKey(42)
|
||||
rng = jax.random.key(42)
|
||||
|
||||
# Generate true data from y = w*x + b + noise
|
||||
true_w, true_b = 2, -1
|
||||
|
@ -14,8 +14,8 @@ consistent with definitions used in Keras and Sonnet.
|
||||
|
||||
An initializer is a function that takes three arguments:
|
||||
``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
|
||||
data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
|
||||
key used when generating random numbers to initialize the array.
|
||||
data type ``dtype``. Argument ``key`` is a PRNG key (e.g. from
|
||||
:func:`jax.random.key`), used to generate random numbers to initialize the array.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
@ -1006,7 +1006,7 @@
|
||||
"source": [
|
||||
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
|
||||
"\n",
|
||||
"The random state is described by two unsigned-int32s that we call a __key__:"
|
||||
"The random state is described by a special array element that we call a __key__:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1030,7 +1030,7 @@
|
||||
],
|
||||
"source": [
|
||||
"from jax import random\n",
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"key"
|
||||
]
|
||||
},
|
||||
@ -2121,7 +2121,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
|
||||
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
|
||||
"x.dtype"
|
||||
]
|
||||
},
|
||||
@ -2188,7 +2188,7 @@
|
||||
"source": [
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import random\n",
|
||||
"x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)\n",
|
||||
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
|
||||
"x.dtype # --> dtype('float64')"
|
||||
]
|
||||
},
|
||||
|
@ -463,14 +463,14 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
|
||||
|
||||
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
|
||||
|
||||
The random state is described by two unsigned-int32s that we call a __key__:
|
||||
The random state is described by a special array element that we call a __key__:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: yPHE7KTWgAWs
|
||||
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
|
||||
|
||||
from jax import random
|
||||
key = random.PRNGKey(0)
|
||||
key = random.key(0)
|
||||
key
|
||||
```
|
||||
|
||||
@ -1071,7 +1071,7 @@ At the moment, JAX by default enforces single-precision numbers to mitigate the
|
||||
:id: CNNGtzM3NDkO
|
||||
:outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8
|
||||
|
||||
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
|
||||
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
|
||||
x.dtype
|
||||
```
|
||||
|
||||
@ -1117,7 +1117,7 @@ We can then confirm that `x64` mode is enabled:
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
|
||||
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
|
||||
x.dtype # --> dtype('float64')
|
||||
```
|
||||
|
||||
|
@ -131,7 +131,7 @@
|
||||
],
|
||||
"source": [
|
||||
"# Create an array of random values:\n",
|
||||
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
|
||||
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
|
||||
"# and use jax.device_put to distribute it across devices:\n",
|
||||
"y = jax.device_put(x, sharding.reshape(4, 2))\n",
|
||||
"jax.debug.visualize_array_sharding(y)"
|
||||
@ -272,7 +272,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))"
|
||||
"x = jax.random.normal(jax.random.key(0), (8192, 8192))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1513,7 +1513,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))\n",
|
||||
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
|
||||
"x = jax.device_put(x, sharding.reshape(4, 2))"
|
||||
]
|
||||
},
|
||||
@ -1738,7 +1738,7 @@
|
||||
"layer_sizes = [784, 8192, 8192, 8192, 10]\n",
|
||||
"batch_size = 8192\n",
|
||||
"\n",
|
||||
"params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)"
|
||||
"params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2184,7 +2184,7 @@
|
||||
" numbers = jax.random.uniform(key, x.shape)\n",
|
||||
" return x + numbers\n",
|
||||
"\n",
|
||||
"key = jax.random.PRNGKey(42)\n",
|
||||
"key = jax.random.key(42)\n",
|
||||
"x_sharding = jax.sharding.PositionalSharding(jax.devices())\n",
|
||||
"x = jax.device_put(jnp.arange(24), x_sharding)"
|
||||
]
|
||||
|
@ -81,7 +81,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
|
||||
:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b
|
||||
|
||||
# Create an array of random values:
|
||||
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
|
||||
x = jax.random.normal(jax.random.key(0), (8192, 8192))
|
||||
# and use jax.device_put to distribute it across devices:
|
||||
y = jax.device_put(x, sharding.reshape(4, 2))
|
||||
jax.debug.visualize_array_sharding(y)
|
||||
@ -144,7 +144,7 @@ For example, here's a value with a single-device `Sharding`:
|
||||
:id: VmoX4SUp3vGJ
|
||||
|
||||
import jax
|
||||
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
|
||||
x = jax.random.normal(jax.random.key(0), (8192, 8192))
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
@ -609,7 +609,7 @@ sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
|
||||
```{code-cell}
|
||||
:id: Q1wuDp-L3vGT
|
||||
|
||||
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
|
||||
x = jax.random.normal(jax.random.key(0), (8192, 8192))
|
||||
x = jax.device_put(x, sharding.reshape(4, 2))
|
||||
```
|
||||
|
||||
@ -720,7 +720,7 @@ def init_model(key, layer_sizes, batch_size):
|
||||
layer_sizes = [784, 8192, 8192, 8192, 10]
|
||||
batch_size = 8192
|
||||
|
||||
params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)
|
||||
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
|
||||
```
|
||||
|
||||
+++ {"id": "sJv_h0AS2drh"}
|
||||
@ -902,7 +902,7 @@ def f(key, x):
|
||||
numbers = jax.random.uniform(key, x.shape)
|
||||
return x + numbers
|
||||
|
||||
key = jax.random.PRNGKey(42)
|
||||
key = jax.random.key(42)
|
||||
x_sharding = jax.sharding.PositionalSharding(jax.devices())
|
||||
x = jax.device_put(jnp.arange(24), x_sharding)
|
||||
```
|
||||
|
@ -84,7 +84,7 @@
|
||||
"num_epochs = 8\n",
|
||||
"batch_size = 128\n",
|
||||
"n_targets = 10\n",
|
||||
"params = init_network_params(layer_sizes, random.PRNGKey(0))"
|
||||
"params = init_network_params(layer_sizes, random.key(0))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -150,7 +150,7 @@
|
||||
],
|
||||
"source": [
|
||||
"# This works on single examples\n",
|
||||
"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
|
||||
"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
|
||||
"preds = predict(params, random_flattened_image)\n",
|
||||
"print(preds.shape)"
|
||||
]
|
||||
@ -173,7 +173,7 @@
|
||||
],
|
||||
"source": [
|
||||
"# Doesn't work with a batch\n",
|
||||
"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
|
||||
"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
|
||||
"try:\n",
|
||||
" preds = predict(params, random_flattened_images)\n",
|
||||
"except TypeError:\n",
|
||||
|
@ -71,7 +71,7 @@ step_size = 0.01
|
||||
num_epochs = 8
|
||||
batch_size = 128
|
||||
n_targets = 10
|
||||
params = init_network_params(layer_sizes, random.PRNGKey(0))
|
||||
params = init_network_params(layer_sizes, random.key(0))
|
||||
```
|
||||
|
||||
+++ {"id": "BtoNk_yxWtIw"}
|
||||
@ -109,7 +109,7 @@ Let's check that our prediction function only works on single images.
|
||||
:outputId: 9d3b29e8-fab3-4ecb-9f63-bc8c092f9006
|
||||
|
||||
# This works on single examples
|
||||
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
|
||||
random_flattened_image = random.normal(random.key(1), (28 * 28,))
|
||||
preds = predict(params, random_flattened_image)
|
||||
print(preds.shape)
|
||||
```
|
||||
@ -119,7 +119,7 @@ print(preds.shape)
|
||||
:outputId: d5d20211-b6da-44e9-f71e-946f2a9d0fc4
|
||||
|
||||
# Doesn't work with a batch
|
||||
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
|
||||
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
|
||||
try:
|
||||
preds = predict(params, random_flattened_images)
|
||||
except TypeError:
|
||||
|
@ -66,7 +66,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = random.normal(random.PRNGKey(0), (5000, 5000))\n",
|
||||
"x = random.normal(random.key(0), (5000, 5000))\n",
|
||||
"def f(w, b, x):\n",
|
||||
" return jnp.tanh(jnp.dot(x, w) + b)\n",
|
||||
"fast_f = jit(f)"
|
||||
|
@ -48,7 +48,7 @@ JAX provides a NumPy-like API for numerical computing which can be used as is, b
|
||||
```{code-cell} ipython3
|
||||
:id: HmlMcICOcSXR
|
||||
|
||||
x = random.normal(random.PRNGKey(0), (5000, 5000))
|
||||
x = random.normal(random.key(0), (5000, 5000))
|
||||
def f(w, b, x):
|
||||
return jnp.tanh(jnp.dot(x, w) + b)
|
||||
fast_f = jit(f)
|
||||
|
@ -27,7 +27,7 @@
|
||||
"from jax import grad, jit, vmap\n",
|
||||
"from jax import random\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(0)"
|
||||
"key = random.key(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1055,7 +1055,7 @@
|
||||
" outs, = vmap(vjp_fun)(M)\n",
|
||||
" return outs\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"num_covecs = 128\n",
|
||||
"U = random.normal(key, (num_covecs,) + y.shape)\n",
|
||||
"\n",
|
||||
@ -1306,7 +1306,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def check(seed):\n",
|
||||
" key = random.PRNGKey(seed)\n",
|
||||
" key = random.key(seed)\n",
|
||||
"\n",
|
||||
" # random coeffs for u and v\n",
|
||||
" key, subkey = random.split(key)\n",
|
||||
@ -1399,7 +1399,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def check(seed):\n",
|
||||
" key = random.PRNGKey(seed)\n",
|
||||
" key = random.key(seed)\n",
|
||||
"\n",
|
||||
" # random coeffs for u and v\n",
|
||||
" key, subkey = random.split(key)\n",
|
||||
|
@ -29,7 +29,7 @@ import jax.numpy as jnp
|
||||
from jax import grad, jit, vmap
|
||||
from jax import random
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
key = random.key(0)
|
||||
```
|
||||
|
||||
+++ {"id": "YxnjtAGN6vu2"}
|
||||
@ -614,7 +614,7 @@ def vmap_mjp(f, x, M):
|
||||
outs, = vmap(vjp_fun)(M)
|
||||
return outs
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
key = random.key(0)
|
||||
num_covecs = 128
|
||||
U = random.normal(key, (num_covecs,) + y.shape)
|
||||
|
||||
@ -770,7 +770,7 @@ Here's a check:
|
||||
:id: BGZV__zupIMS
|
||||
|
||||
def check(seed):
|
||||
key = random.PRNGKey(seed)
|
||||
key = random.key(seed)
|
||||
|
||||
# random coeffs for u and v
|
||||
key, subkey = random.split(key)
|
||||
@ -833,7 +833,7 @@ Here's a check of the VJP rules:
|
||||
:id: 4J7edvIBttcU
|
||||
|
||||
def check(seed):
|
||||
key = random.PRNGKey(seed)
|
||||
key = random.key(seed)
|
||||
|
||||
# random coeffs for u and v
|
||||
key, subkey = random.split(key)
|
||||
|
@ -60,7 +60,7 @@
|
||||
"import jax.numpy as jnp\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(1701)\n",
|
||||
"key = random.key(1701)\n",
|
||||
"\n",
|
||||
"x = jnp.linspace(0, 10, 500)\n",
|
||||
"y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))\n",
|
||||
@ -130,7 +130,7 @@
|
||||
"ax[0].set_title('original')\n",
|
||||
"\n",
|
||||
"# Create a noisy version by adding random Gaussian noise\n",
|
||||
"key = random.PRNGKey(1701)\n",
|
||||
"key = random.key(1701)\n",
|
||||
"noisy_image = image + 50 * random.normal(key, image.shape)\n",
|
||||
"ax[1].imshow(noisy_image, cmap='binary_r')\n",
|
||||
"ax[1].set_title('noisy')\n",
|
||||
|
@ -43,7 +43,7 @@ from jax import random
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
key = random.PRNGKey(1701)
|
||||
key = random.key(1701)
|
||||
|
||||
x = jnp.linspace(0, 10, 500)
|
||||
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))
|
||||
@ -84,7 +84,7 @@ ax[0].imshow(image, cmap='binary_r')
|
||||
ax[0].set_title('original')
|
||||
|
||||
# Create a noisy version by adding random Gaussian noise
|
||||
key = random.PRNGKey(1701)
|
||||
key = random.key(1701)
|
||||
noisy_image = image + 50 * random.normal(key, image.shape)
|
||||
ax[1].imshow(noisy_image, cmap='binary_r')
|
||||
ax[1].set_title('noisy')
|
||||
|
@ -97,7 +97,7 @@
|
||||
"num_epochs = 10\n",
|
||||
"batch_size = 128\n",
|
||||
"n_targets = 10\n",
|
||||
"params = init_network_params(layer_sizes, random.PRNGKey(0))"
|
||||
"params = init_network_params(layer_sizes, random.key(0))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -163,7 +163,7 @@
|
||||
],
|
||||
"source": [
|
||||
"# This works on single examples\n",
|
||||
"random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n",
|
||||
"random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
|
||||
"preds = predict(params, random_flattened_image)\n",
|
||||
"print(preds.shape)"
|
||||
]
|
||||
@ -186,7 +186,7 @@
|
||||
],
|
||||
"source": [
|
||||
"# Doesn't work with a batch\n",
|
||||
"random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n",
|
||||
"random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
|
||||
"try:\n",
|
||||
" preds = predict(params, random_flattened_images)\n",
|
||||
"except TypeError:\n",
|
||||
|
@ -79,7 +79,7 @@ step_size = 0.01
|
||||
num_epochs = 10
|
||||
batch_size = 128
|
||||
n_targets = 10
|
||||
params = init_network_params(layer_sizes, random.PRNGKey(0))
|
||||
params = init_network_params(layer_sizes, random.key(0))
|
||||
```
|
||||
|
||||
+++ {"id": "BtoNk_yxWtIw"}
|
||||
@ -117,7 +117,7 @@ Let's check that our prediction function only works on single images.
|
||||
:outputId: ce9d86ed-a830-4832-e04d-10d1abb1fb8a
|
||||
|
||||
# This works on single examples
|
||||
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
|
||||
random_flattened_image = random.normal(random.key(1), (28 * 28,))
|
||||
preds = predict(params, random_flattened_image)
|
||||
print(preds.shape)
|
||||
```
|
||||
@ -127,7 +127,7 @@ print(preds.shape)
|
||||
:outputId: f43bbc9d-bc8f-4168-ee7b-79ee9d33f245
|
||||
|
||||
# Doesn't work with a batch
|
||||
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
|
||||
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
|
||||
try:
|
||||
preds = predict(params, random_flattened_images)
|
||||
except TypeError:
|
||||
|
@ -81,7 +81,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"x = random.normal(key, (10,))\n",
|
||||
"print(x)"
|
||||
]
|
||||
|
@ -59,7 +59,7 @@ We'll be generating random data in the following examples. One big difference be
|
||||
:id: u0nseKZNqOoH
|
||||
:outputId: 03e20e21-376c-41bb-a6bb-57431823691b
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
key = random.key(0)
|
||||
x = random.normal(key, (10,))
|
||||
print(x)
|
||||
```
|
||||
|
@ -483,7 +483,7 @@
|
||||
"\n",
|
||||
"normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(10003)\n",
|
||||
"key = random.key(10003)\n",
|
||||
"\n",
|
||||
"beta_loc = jnp.zeros(num_features, jnp.float32)\n",
|
||||
"beta_log_scale = jnp.zeros(num_features, jnp.float32)\n",
|
||||
|
@ -210,7 +210,7 @@ def normal_sample(key, shape):
|
||||
|
||||
normal_sample = jax.jit(normal_sample, static_argnums=(1,))
|
||||
|
||||
key = random.PRNGKey(10003)
|
||||
key = random.key(10003)
|
||||
|
||||
beta_loc = jnp.zeros(num_features, jnp.float32)
|
||||
beta_log_scale = jnp.zeros(num_features, jnp.float32)
|
||||
|
@ -338,7 +338,7 @@
|
||||
" lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
|
||||
" )\n",
|
||||
" )(x, y)\n",
|
||||
"k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
|
||||
"k1, k2 = jax.random.split(jax.random.key(0))\n",
|
||||
"x = jax.random.normal(k1, (1024, 1024))\n",
|
||||
"y = jax.random.normal(k2, (1024, 1024))\n",
|
||||
"z = matmul(x, y)\n",
|
||||
@ -376,7 +376,7 @@
|
||||
" lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)\n",
|
||||
" ),\n",
|
||||
" )(x, y)\n",
|
||||
"k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
|
||||
"k1, k2 = jax.random.split(jax.random.key(0))\n",
|
||||
"x = jax.random.normal(k1, (1024, 1024))\n",
|
||||
"y = jax.random.normal(k2, (1024, 1024))\n",
|
||||
"z = matmul(x, y, activation=jax.nn.relu)\n",
|
||||
@ -397,7 +397,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
|
||||
"k1, k2 = jax.random.split(jax.random.key(0))\n",
|
||||
"x = jax.random.normal(k1, (4, 1024, 1024))\n",
|
||||
"y = jax.random.normal(k2, (4, 1024, 1024))\n",
|
||||
"z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)\n",
|
||||
|
@ -226,7 +226,7 @@ def matmul(x: jax.Array, y: jax.Array):
|
||||
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
|
||||
)
|
||||
)(x, y)
|
||||
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
|
||||
k1, k2 = jax.random.split(jax.random.key(0))
|
||||
x = jax.random.normal(k1, (1024, 1024))
|
||||
y = jax.random.normal(k2, (1024, 1024))
|
||||
z = matmul(x, y)
|
||||
@ -253,7 +253,7 @@ def matmul(x: jax.Array, y: jax.Array, *, activation):
|
||||
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
|
||||
),
|
||||
)(x, y)
|
||||
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
|
||||
k1, k2 = jax.random.split(jax.random.key(0))
|
||||
x = jax.random.normal(k1, (1024, 1024))
|
||||
y = jax.random.normal(k2, (1024, 1024))
|
||||
z = matmul(x, y, activation=jax.nn.relu)
|
||||
@ -263,7 +263,7 @@ np.testing.assert_allclose(z, jax.nn.relu(x @ y))
|
||||
To conclude, let's highlight a cool feature of Pallas: it composes with `jax.vmap`! To turn this matrix multiplication into a batched version, we just need to `vmap` it.
|
||||
|
||||
```{code-cell} ipython3
|
||||
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
|
||||
k1, k2 = jax.random.split(jax.random.key(0))
|
||||
x = jax.random.normal(k1, (4, 1024, 1024))
|
||||
y = jax.random.normal(k2, (4, 1024, 1024))
|
||||
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
|
||||
|
@ -11,7 +11,7 @@ check out the Tensorboard profiler below.
|
||||
```python
|
||||
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
|
||||
# Run the operations to be profiled
|
||||
key = jax.random.PRNGKey(0)
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
@ -107,7 +107,7 @@ import jax
|
||||
jax.profiler.start_trace("/tmp/tensorboard")
|
||||
|
||||
# Run the operations to be profiled
|
||||
key = jax.random.PRNGKey(0)
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
@ -126,7 +126,7 @@ alternative to `start_trace` and `stop_trace`:
|
||||
import jax
|
||||
|
||||
with jax.profiler.trace("/tmp/tensorboard"):
|
||||
key = jax.random.PRNGKey(0)
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (5000, 5000))
|
||||
y = x @ x
|
||||
y.block_until_ready()
|
||||
|
@ -78,7 +78,7 @@ if __name__ == "__main__":
|
||||
|
||||
@jit
|
||||
def objective(params, t):
|
||||
rng = random.PRNGKey(t)
|
||||
rng = random.key(t)
|
||||
return -batch_elbo(funnel_log_density, rng, params, num_samples)
|
||||
|
||||
# Set up figure.
|
||||
@ -107,7 +107,7 @@ if __name__ == "__main__":
|
||||
# Plot random samples from variational distribution.
|
||||
# Here we clone the rng used in computing the objective
|
||||
# so that we can show exactly the same samples.
|
||||
rngs = random.split(random.PRNGKey(t), num_samples)
|
||||
rngs = random.split(random.key(t), num_samples)
|
||||
samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
|
||||
ax.plot(samples[:, 0], samples[:, 1], 'b.')
|
||||
|
||||
|
@ -182,7 +182,7 @@ def main(_):
|
||||
num_train = train_images.shape[0]
|
||||
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
|
||||
num_batches = num_complete_batches + bool(leftover)
|
||||
key = random.PRNGKey(_SEED.value)
|
||||
key = random.key(_SEED.value)
|
||||
|
||||
def data_stream():
|
||||
rng = npr.RandomState(_SEED.value)
|
||||
|
@ -35,7 +35,7 @@ config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
||||
jax_rng = random.PRNGKey(0)
|
||||
jax_rng = random.key(0)
|
||||
result_shape, params = init_fun(jax_rng, input_shape)
|
||||
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
|
||||
test_case.assertEqual(result.shape, result_shape)
|
||||
|
@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
|
||||
def main(unused_argv):
|
||||
|
||||
numpts = 7
|
||||
key = random.PRNGKey(0)
|
||||
key = random.key(0)
|
||||
eye = jnp.eye(numpts)
|
||||
|
||||
def cov_map(cov_func, xs, xs2=None):
|
||||
|
@ -50,7 +50,7 @@ init_random_params, predict = stax.serial(
|
||||
Dense(10), LogSoftmax)
|
||||
|
||||
if __name__ == "__main__":
|
||||
rng = random.PRNGKey(0)
|
||||
rng = random.key(0)
|
||||
|
||||
step_size = 0.001
|
||||
num_epochs = 10
|
||||
|
@ -87,14 +87,14 @@ if __name__ == "__main__":
|
||||
batch_size = 32
|
||||
nrow, ncol = 10, 10 # sampled image grid size
|
||||
|
||||
test_rng = random.PRNGKey(1) # fixed prng key for evaluation
|
||||
test_rng = random.key(1) # fixed prng key for evaluation
|
||||
imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")
|
||||
|
||||
train_images, _, test_images, _ = datasets.mnist(permute_train=True)
|
||||
num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
|
||||
num_batches = num_complete_batches + bool(leftover)
|
||||
|
||||
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
|
||||
enc_init_rng, dec_init_rng = random.split(random.key(2))
|
||||
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
|
||||
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
|
||||
init_params = init_encoder_params, init_decoder_params
|
||||
@ -131,7 +131,7 @@ if __name__ == "__main__":
|
||||
opt_state = opt_init(init_params)
|
||||
for epoch in range(num_epochs):
|
||||
tic = time.time()
|
||||
opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
|
||||
opt_state = run_epoch(random.key(epoch), opt_state, train_images)
|
||||
test_elbo, sampled_images = evaluate(opt_state, test_images)
|
||||
print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
|
||||
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
|
||||
|
@ -280,7 +280,7 @@ def jit(
|
||||
... def selu(x, alpha=1.67, lmbda=1.05):
|
||||
... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
|
||||
>>>
|
||||
>>> key = jax.random.PRNGKey(0)
|
||||
>>> key = jax.random.key(0)
|
||||
>>> x = jax.random.normal(key, (10,))
|
||||
>>> print(selu(x)) # doctest: +SKIP
|
||||
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
|
||||
|
@ -1277,7 +1277,7 @@ def ensure_compile_time_eval():
|
||||
@jax.jit
|
||||
def jax_fn(x):
|
||||
with jax.ensure_compile_time_eval():
|
||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||
y = random.randint(random.key(0), (1000,1000), 0, 100)
|
||||
y2 = y @ y
|
||||
x2 = jnp.sum(y2) * x
|
||||
return x2
|
||||
@ -1285,7 +1285,7 @@ def ensure_compile_time_eval():
|
||||
A similar behavior can often be achieved simply by 'hoisting' the constant
|
||||
expression out of the corresponding staging API::
|
||||
|
||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||
y = random.randint(random.key(0), (1000,1000), 0, 100)
|
||||
|
||||
@jax.jit
|
||||
def jax_fn(x):
|
||||
|
@ -2101,7 +2101,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
|
||||
Example 2: partial products of an array of matrices
|
||||
|
||||
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
|
||||
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
|
||||
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
|
||||
>>> partial_prods.shape
|
||||
(4, 2, 2)
|
||||
|
@ -62,7 +62,7 @@ def zeros(key: KeyArray,
|
||||
The ``key`` argument is ignored.
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
|
||||
Array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
"""
|
||||
@ -77,7 +77,7 @@ def ones(key: KeyArray,
|
||||
The ``key`` argument is ignored.
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
|
||||
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
|
||||
Array([[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.]], dtype=float32)
|
||||
@ -96,7 +96,7 @@ def constant(value: ArrayLike,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.constant(-7)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
|
||||
Array([[-7., -7., -7.],
|
||||
[-7., -7., -7.]], dtype=float32)
|
||||
"""
|
||||
@ -122,7 +122,7 @@ def uniform(scale: RealNumeric = 1e-2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.uniform(10.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[7.298188 , 8.691938 , 8.7230015],
|
||||
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
||||
"""
|
||||
@ -148,7 +148,7 @@ def normal(stddev: RealNumeric = 1e-2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.normal(5.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
||||
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
||||
"""
|
||||
@ -376,7 +376,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.50350785, 0.8088631 , 0.81566876],
|
||||
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
|
||||
|
||||
@ -414,7 +414,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
|
||||
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
|
||||
|
||||
@ -452,7 +452,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
|
||||
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
|
||||
|
||||
@ -488,7 +488,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
|
||||
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
|
||||
|
||||
@ -524,7 +524,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.he_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
|
||||
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
|
||||
|
||||
@ -562,7 +562,7 @@ def he_normal(in_axis: int | Sequence[int] = -2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.he_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
|
||||
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
|
||||
|
||||
@ -595,7 +595,7 @@ def orthogonal(scale: RealNumeric = 1.0,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.orthogonal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
|
||||
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
|
||||
"""
|
||||
@ -638,7 +638,7 @@ def delta_orthogonal(
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.delta_orthogonal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ]],
|
||||
|
@ -245,7 +245,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
|
||||
"""Folds in data to a PRNG key to form a new PRNG key.
|
||||
|
||||
Args:
|
||||
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
|
||||
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
|
||||
data: a 32bit integer representing data to be folded in to the key.
|
||||
|
||||
Returns:
|
||||
@ -275,7 +275,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
|
||||
"""Splits a PRNG key into `num` new keys by adding a leading axis.
|
||||
|
||||
Args:
|
||||
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
|
||||
key: a PRNG key (from ``key``, ``split``, ``fold_in``).
|
||||
num: optional, a positive integer (or tuple of integers) indicating
|
||||
the number (or shape) of keys to produce. Defaults to 2.
|
||||
|
||||
|
@ -44,7 +44,7 @@ net_init, net_apply = stax.serial(
|
||||
)
|
||||
|
||||
# Initialize parameters, not committing to a batch shape
|
||||
rng = random.PRNGKey(0)
|
||||
rng = random.key(0)
|
||||
in_shape = (-1, 28, 28, 1)
|
||||
out_shape, net_params = net_init(rng, in_shape)
|
||||
|
||||
|
@ -268,7 +268,7 @@ def Dropout(rate, mode='train'):
|
||||
msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
|
||||
"argument. That is, instead of `apply_fun(params, inputs)`, call "
|
||||
"it like `apply_fun(params, inputs, rng)` where `rng` is a "
|
||||
"jax.random.PRNGKey value.")
|
||||
"PRNG key (e.g. from `jax.random.key`).")
|
||||
raise ValueError(msg)
|
||||
if mode == 'train':
|
||||
keep = random.bernoulli(rng, rate, inputs.shape)
|
||||
|
@ -29,7 +29,7 @@ def random_bcoo(key, shape, *, dtype=jnp.float_, indices_dtype=None,
|
||||
"""Generate a random BCOO matrix.
|
||||
|
||||
Args:
|
||||
key : random.PRNGKey to be passed to ``generator`` function.
|
||||
key : PRNG key to be passed to ``generator`` function.
|
||||
shape : tuple specifying the shape of the array to be generated.
|
||||
dtype : dtype of the array to be generated.
|
||||
indices_dtype: dtype of the BCOO indices.
|
||||
|
@ -27,9 +27,9 @@ For example:
|
||||
>>> from jax import random
|
||||
>>> from jax.experimental.sparse import BCOO, sparsify
|
||||
|
||||
>>> mat = random.uniform(random.PRNGKey(1701), (5, 5))
|
||||
>>> mat = random.uniform(random.key(1701), (5, 5))
|
||||
>>> mat = mat.at[mat < 0.5].set(0)
|
||||
>>> vec = random.uniform(random.PRNGKey(42), (5,))
|
||||
>>> vec = random.uniform(random.key(42), (5,))
|
||||
|
||||
>>> def f(mat, vec):
|
||||
... return -(jnp.sin(mat) @ vec)
|
||||
|
@ -22,24 +22,25 @@ Basic usage
|
||||
|
||||
>>> seed = 1701
|
||||
>>> num_steps = 100
|
||||
>>> key = jax.random.PRNGKey(seed)
|
||||
>>> key = jax.random.key(seed)
|
||||
>>> for i in range(num_steps):
|
||||
... key, subkey = jax.random.split(key)
|
||||
... params = compiled_update(subkey, params, next(batches)) # doctest: +SKIP
|
||||
|
||||
PRNG Keys
|
||||
PRNG keys
|
||||
---------
|
||||
|
||||
Unlike the *stateful* pseudorandom number generators (PRNGs) that users of NumPy and
|
||||
SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to
|
||||
be passed as a first argument.
|
||||
The random state is described by two unsigned 32-bit integers that we call a **key**,
|
||||
usually generated by the :py:func:`jax.random.PRNGKey` function::
|
||||
The random state is described by a special array element type that we call a **key**,
|
||||
usually generated by the :py:func:`jax.random.key` function::
|
||||
|
||||
>>> from jax import random
|
||||
>>> key = random.PRNGKey(0)
|
||||
>>> key = random.key(0)
|
||||
>>> key
|
||||
Array([0, 0], dtype=uint32)
|
||||
Array((), dtype=key<fry>) overlaying:
|
||||
[0 0]
|
||||
|
||||
This key can then be used in any of JAX's random number generation routines::
|
||||
|
||||
@ -57,11 +58,47 @@ If you need a new random number, you can use :meth:`jax.random.split` to generat
|
||||
>>> random.uniform(subkey)
|
||||
Array(0.10536897, dtype=float32)
|
||||
|
||||
.. note::
|
||||
|
||||
Typed key arrays, with element types such as ``key<fry>`` above,
|
||||
were introduced in JAX v0.4.16. Before then, keys were
|
||||
conventionally represented in ``uint32`` arrays, whose final
|
||||
dimension represented the key's bit-level representation.
|
||||
|
||||
Both forms of key array can still be created and used with the
|
||||
:mod:`jax.random` module. New-style typed key arrays are made with
|
||||
:py:func:`jax.random.key`. Legacy ``uint32`` key arrays are made
|
||||
with :py:func:`jax.random.PRNGKey`.
|
||||
|
||||
To convert between the two, use :py:func:`jax.random.key_data` and
|
||||
:py:func:`jax.random.wrap_key_data`. The legacy key format may be
|
||||
needed when interfacing with systems outside of JAX (e.g. exporting
|
||||
arrays to a serializable format), or when passing keys to JAX-based
|
||||
libraries that assume the legacy format.
|
||||
|
||||
Otherwise, typed keys are recommended. Caveats of legacy keys
|
||||
relative to typed ones include:
|
||||
|
||||
* They have an extra trailing dimension.
|
||||
|
||||
* They have a numeric dtype (``uint32``), allowing for operations
|
||||
that are typically not meant to be carried out over keys, such as
|
||||
integer arithmetic.
|
||||
|
||||
* They do not carry information about the RNG implementation. When
|
||||
legacy keys are passed to :mod:`jax.random` functions, a global
|
||||
configuration setting determines the RNG implementation (see
|
||||
"Advanced RNG configuration" below).
|
||||
|
||||
To learn more about this upgrade, and the design of key types, see
|
||||
`JEP 9263
|
||||
<https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html>`_.
|
||||
|
||||
Advanced
|
||||
--------
|
||||
|
||||
Design and Context
|
||||
==================
|
||||
Design and background
|
||||
=====================
|
||||
|
||||
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
|
||||
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
|
||||
@ -79,16 +116,19 @@ To summarize, among other requirements, the JAX PRNG aims to:
|
||||
Advanced RNG configuration
|
||||
==========================
|
||||
|
||||
JAX provides several PRNG implementations (controlled by the
|
||||
`jax_default_prng_impl` flag).
|
||||
JAX provides several PRNG implementations. A specific one can be
|
||||
selected with the optional `impl` keyword argument to
|
||||
`jax.random.key`. When no `impl` option is passed to the `key`
|
||||
constructor, the implementation is determined by the global
|
||||
`jax_default_prng_impl` configuration flag.
|
||||
|
||||
- **default**
|
||||
- **default**, `"threefry2x32"`:
|
||||
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
|
||||
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
|
||||
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
|
||||
|
||||
- "rbg" uses ThreeFry for splitting, and XLA RBG for data generation.
|
||||
- "unsafe_rbg" exists only for demonstration purposes, using RBG both for
|
||||
- `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
|
||||
- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
|
||||
splitting (using an untested made up algorithm) and generating.
|
||||
|
||||
The random streams generated by these experimental implementations haven't
|
||||
@ -127,7 +167,7 @@ robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
|
||||
less safe in the sense that the quality of random streams it generates from
|
||||
different keys is less well understood.
|
||||
|
||||
For more about jax_threefry_partitionable, see
|
||||
For more about `jax_threefry_partitionable`, see
|
||||
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
|
||||
"""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user