update README and several docs to typed RNG keys

Roy Frostig 2024-08-11 08:09:47 -07:00
parent 9e86416a32
commit 371935cc10
13 changed files with 22 additions and 22 deletions
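For context on the change: `jax.random.PRNGKey` builds a legacy raw `uint32` key array, while `jax.random.key` builds a typed key, a scalar array with a dedicated key dtype. A minimal sketch, not part of the diff (the exact dtype name depends on the default PRNG implementation):

```python
import jax
import jax.numpy as jnp

legacy = jax.random.PRNGKey(0)   # old style: shape (2,), dtype uint32
typed = jax.random.key(0)        # new style: shape (), with a key dtype (e.g. key<fry>)

print(legacy.shape, legacy.dtype)
print(typed.shape, typed.dtype)
print(jnp.issubdtype(typed.dtype, jax.dtypes.prng_key))   # True only for typed keys

# Both kinds of key are accepted by the jax.random API; the docs now prefer the typed form.
x = jax.random.normal(typed, (3,))
```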

View File

@@ -273,7 +273,7 @@ from jax import random, pmap
 import jax.numpy as jnp
 # Create 8 random 5000 x 6000 matrices, one per GPU
-keys = random.split(random.PRNGKey(0), 8)
+keys = random.split(random.key(0), 8)
 mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
 # Run a local matmul on each device in parallel (no data transfer)
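An aside on what the updated README snippet produces: `random.split` on a typed key returns a key array with one subkey per entry, which is what `pmap` maps over. A small sketch, with the `pmap` call guarded since the README example assumes 8 local devices:

```python
import jax
from jax import random, pmap

keys = random.split(random.key(0), 8)
print(keys.shape, keys.dtype)   # (8,) with a key dtype; one subkey per device

if jax.local_device_count() >= 8:   # the README example assumes 8 devices
    mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
    print(mats.shape)                # (8, 5000, 6000), one shard per device
```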

View File

@@ -839,7 +839,7 @@ def serial_dot_products(state):
 out = out + y * x[0]
 return out
-x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
+x = jax.random.normal(jax.random.key(0), (2, 2))
 f(x).block_until_ready() # compile
 while state:
 f(x).block_until_ready()
@@ -929,7 +929,7 @@ def jit_add_chain(state):
 def g(x, y):
 return lax.add(x, y)
-x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
+x = jax.random.normal(jax.random.key(0), (2, 2))
 while state:
 @jax.jit
 def f(x):

View File

@@ -109,7 +109,7 @@ def sparse_bcoo_todense_compile(state):
 def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False):
 shape = (2000, 2000)
 nse = 10000
-key = jax.random.PRNGKey(1701)
+key = jax.random.key(1701)
 mat = sparse.random_bcoo(
 key,
 nse=nse,
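A hedged, self-contained sketch of the pattern this benchmark times, with the same shapes; it assumes `jax.experimental.sparse.random_bcoo` as called above and uses the BCOO `@` operator for the matvec (the benchmark itself may call a lower-level routine):

```python
import jax
from jax.experimental import sparse

key = jax.random.key(1701)
mat_key, vec_key = jax.random.split(key)

# 2000 x 2000 sparse matrix with 10000 specified elements, stored in BCOO format
mat = sparse.random_bcoo(mat_key, shape=(2000, 2000), nse=10000)
vec = jax.random.normal(vec_key, (2000,))

out = mat @ vec          # sparse-dense matrix-vector product
print(out.shape)         # (2000,)
```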

View File

@@ -38,7 +38,7 @@
 },
 "outputs": [],
 "source": [
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "key, subkey = random.split(key)\n",
 "x = random.normal(key, (5000, 5000))\n",
 "\n",
@@ -189,7 +189,7 @@
 },
 "outputs": [],
 "source": [
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "x = random.normal(key, ())\n",
 "\n",
 "print(grad(f)(x))\n",
@@ -261,7 +261,7 @@
 },
 "outputs": [],
 "source": [
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "x = random.normal(key, (5000, 5000))"
 ]
 },

View File

@@ -27,7 +27,7 @@
 "import jax.numpy as jnp\n",
 "from jax import random\n",
 "\n",
-"key = random.PRNGKey(0)"
+"key = random.key(0)"
 ]
 },
 {
@@ -194,7 +194,7 @@
 },
 "outputs": [],
 "source": [
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "x = random.normal(key, ())\n",
 "\n",
 "print(grad(f)(x))\n",
@@ -246,7 +246,7 @@
 "\n",
 "layer_sizes = [5, 2, 3]\n",
 "\n",
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "key, *keys = random.split(key, len(layer_sizes))\n",
 "params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
 "\n",
@@ -351,7 +351,7 @@
 },
 "outputs": [],
 "source": [
-"key = random.PRNGKey(0)\n",
+"key = random.key(0)\n",
 "x = random.normal(key, (5000, 5000))"
 ]
 },
@@ -754,7 +754,7 @@
 },
 "outputs": [],
 "source": [
-"keys = random.split(random.PRNGKey(0), 8)\n",
+"keys = random.split(random.key(0), 8)\n",
 "mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n",
 "result = pmap(jnp.dot)(mats, mats)\n",
 "print(pmap(jnp.mean)(result))"

View File

@@ -366,7 +366,7 @@
 "\n",
 "# set some initial conditions for each replicate\n",
 "ys = jnp.zeros((N_dev, N, 3))\n",
-"state0 = jr.uniform(jr.PRNGKey(1), \n",
+"state0 = jr.uniform(jr.key(1), \n",
 " minval=-1., maxval=1.,\n",
 " shape=(N_dev, 3))\n",
 "state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n",

View File

@@ -263,7 +263,7 @@
 "from jax import random\n",
 "\n",
 "# create 8 random keys\n",
-"keys = random.split(random.PRNGKey(0), 8)\n",
+"keys = random.split(random.key(0), 8)\n",
 "# create a 5000 x 6000 matrix on each device by mapping over keys\n",
 "mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)\n",
 "# the stack of matrices is represented logically as a single array\n",

View File

@@ -306,7 +306,7 @@ per_core_batch_size=4
 seq_len=512
 emb_dim=512
 x = jax.random.normal(
-jax.random.PRNGKey(0),
+jax.random.key(0),
 shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
 dtype=jnp.bfloat16,
 )

View File

@@ -479,7 +479,7 @@ seq_len = 512
 emb_dim = 512
 assert jax.local_device_count() > 1, "Only 1 GPU, the example work, but it is this really what you want?"
 x = jax.random.normal(
-jax.random.PRNGKey(0),
+jax.random.key(0),
 shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
 dtype=jnp.float16,
 )

View File

@@ -590,7 +590,7 @@ def vmap_mjp(f, x, M):
 outs, = vmap(vjp_fun)(M)
 return outs
-key = random.PRNGKey(0)
+key = random.key(0)
 num_covecs = 128
 U = random.normal(key, (num_covecs,) + y.shape)
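A hedged, self-contained sketch of the vmap-of-vjp trick this snippet relies on: one `vjp` is computed, then `vmap` pushes a whole matrix of covectors through it to form a matrix-Jacobian product. The function `f` is a stand-in:

```python
import jax.numpy as jnp
from jax import random, vjp, vmap

def f(x):                       # stand-in function
    return jnp.sin(x) * x

key = random.key(0)
x = random.normal(key, (50,))
y, vjp_fun = vjp(f, x)          # one backward-pass closure

num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
outs, = vmap(vjp_fun)(U)        # all 128 covector-Jacobian products at once
print(outs.shape)               # (128, 50)
```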
@@ -714,7 +714,7 @@ Here's a check:
 ```{code-cell}
 def check(seed):
-key = random.PRNGKey(seed)
+key = random.key(seed)
 # random coeffs for u and v
 key, subkey = random.split(key)
@@ -768,7 +768,7 @@ Here's a check of the VJP rules:
 ```{code-cell}
 def check(seed):
-key = random.PRNGKey(seed)
+key = random.key(seed)
 # random coeffs for u and v
 key, subkey = random.split(key)

View File

@@ -1509,7 +1509,7 @@
 "layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n",
 "batch_size = 32\n",
 "\n",
-"params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)"
+"params, batch = init(jax.random.key(0), layer_sizes, batch_size)"
 ]
 },
 {

View File

@@ -1055,7 +1055,7 @@ def init(key, layer_sizes, batch_size):
 layer_sizes = [784, 128, 128, 128, 128, 128, 8]
 batch_size = 32
-params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
+params, batch = init(jax.random.key(0), layer_sizes, batch_size)
 ```
 Compare these examples with the purely [automatic partitioning examples in the

View File

@@ -184,7 +184,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
 >>> import jax, jax.numpy as jnp
 >>> initializer = jax.nn.initializers.truncated_normal(5.0)
->>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
+>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
 Array([[ 2.9047365, 5.2338114, 5.29852 ],
 [-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
 """