mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
update README and several docs to typed RNG keys
This commit is contained in:
parent
9e86416a32
commit
371935cc10
@ -273,7 +273,7 @@ from jax import random, pmap
|
||||
import jax.numpy as jnp
|
||||
|
||||
# Create 8 random 5000 x 6000 matrices, one per GPU
|
||||
keys = random.split(random.PRNGKey(0), 8)
|
||||
keys = random.split(random.key(0), 8)
|
||||
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
|
||||
|
||||
# Run a local matmul on each device in parallel (no data transfer)
|
||||
|
@ -839,7 +839,7 @@ def serial_dot_products(state):
|
||||
out = out + y * x[0]
|
||||
return out
|
||||
|
||||
x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
|
||||
x = jax.random.normal(jax.random.key(0), (2, 2))
|
||||
f(x).block_until_ready() # compile
|
||||
while state:
|
||||
f(x).block_until_ready()
|
||||
@ -929,7 +929,7 @@ def jit_add_chain(state):
|
||||
def g(x, y):
|
||||
return lax.add(x, y)
|
||||
|
||||
x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
|
||||
x = jax.random.normal(jax.random.key(0), (2, 2))
|
||||
while state:
|
||||
@jax.jit
|
||||
def f(x):
|
||||
|
@ -109,7 +109,7 @@ def sparse_bcoo_todense_compile(state):
|
||||
def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False):
|
||||
shape = (2000, 2000)
|
||||
nse = 10000
|
||||
key = jax.random.PRNGKey(1701)
|
||||
key = jax.random.key(1701)
|
||||
mat = sparse.random_bcoo(
|
||||
key,
|
||||
nse=nse,
|
||||
|
@ -38,7 +38,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"key, subkey = random.split(key)\n",
|
||||
"x = random.normal(key, (5000, 5000))\n",
|
||||
"\n",
|
||||
@ -189,7 +189,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"x = random.normal(key, ())\n",
|
||||
"\n",
|
||||
"print(grad(f)(x))\n",
|
||||
@ -261,7 +261,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"x = random.normal(key, (5000, 5000))"
|
||||
]
|
||||
},
|
||||
|
@ -27,7 +27,7 @@
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax import random\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(0)"
|
||||
"key = random.key(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -194,7 +194,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"x = random.normal(key, ())\n",
|
||||
"\n",
|
||||
"print(grad(f)(x))\n",
|
||||
@ -246,7 +246,7 @@
|
||||
"\n",
|
||||
"layer_sizes = [5, 2, 3]\n",
|
||||
"\n",
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"key, *keys = random.split(key, len(layer_sizes))\n",
|
||||
"params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
|
||||
"\n",
|
||||
@ -351,7 +351,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"key = random.PRNGKey(0)\n",
|
||||
"key = random.key(0)\n",
|
||||
"x = random.normal(key, (5000, 5000))"
|
||||
]
|
||||
},
|
||||
@ -754,7 +754,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"keys = random.split(random.PRNGKey(0), 8)\n",
|
||||
"keys = random.split(random.key(0), 8)\n",
|
||||
"mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n",
|
||||
"result = pmap(jnp.dot)(mats, mats)\n",
|
||||
"print(pmap(jnp.mean)(result))"
|
||||
|
@ -366,7 +366,7 @@
|
||||
"\n",
|
||||
"# set some initial conditions for each replicate\n",
|
||||
"ys = jnp.zeros((N_dev, N, 3))\n",
|
||||
"state0 = jr.uniform(jr.PRNGKey(1), \n",
|
||||
"state0 = jr.uniform(jr.key(1), \n",
|
||||
" minval=-1., maxval=1.,\n",
|
||||
" shape=(N_dev, 3))\n",
|
||||
"state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n",
|
||||
|
@ -263,7 +263,7 @@
|
||||
"from jax import random\n",
|
||||
"\n",
|
||||
"# create 8 random keys\n",
|
||||
"keys = random.split(random.PRNGKey(0), 8)\n",
|
||||
"keys = random.split(random.key(0), 8)\n",
|
||||
"# create a 5000 x 6000 matrix on each device by mapping over keys\n",
|
||||
"mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)\n",
|
||||
"# the stack of matrices is represented logically as a single array\n",
|
||||
|
@ -306,7 +306,7 @@ per_core_batch_size=4
|
||||
seq_len=512
|
||||
emb_dim=512
|
||||
x = jax.random.normal(
|
||||
jax.random.PRNGKey(0),
|
||||
jax.random.key(0),
|
||||
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
|
@ -479,7 +479,7 @@ seq_len = 512
|
||||
emb_dim = 512
|
||||
assert jax.local_device_count() > 1, "Only 1 GPU, the example work, but it is this really what you want?"
|
||||
x = jax.random.normal(
|
||||
jax.random.PRNGKey(0),
|
||||
jax.random.key(0),
|
||||
shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
|
||||
dtype=jnp.float16,
|
||||
)
|
||||
|
@ -590,7 +590,7 @@ def vmap_mjp(f, x, M):
|
||||
outs, = vmap(vjp_fun)(M)
|
||||
return outs
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
key = random.key(0)
|
||||
num_covecs = 128
|
||||
U = random.normal(key, (num_covecs,) + y.shape)
|
||||
|
||||
@ -714,7 +714,7 @@ Here's a check:
|
||||
|
||||
```{code-cell}
|
||||
def check(seed):
|
||||
key = random.PRNGKey(seed)
|
||||
key = random.key(seed)
|
||||
|
||||
# random coeffs for u and v
|
||||
key, subkey = random.split(key)
|
||||
@ -768,7 +768,7 @@ Here's a check of the VJP rules:
|
||||
|
||||
```{code-cell}
|
||||
def check(seed):
|
||||
key = random.PRNGKey(seed)
|
||||
key = random.key(seed)
|
||||
|
||||
# random coeffs for u and v
|
||||
key, subkey = random.split(key)
|
||||
|
@ -1509,7 +1509,7 @@
|
||||
"layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n",
|
||||
"batch_size = 32\n",
|
||||
"\n",
|
||||
"params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)"
|
||||
"params, batch = init(jax.random.key(0), layer_sizes, batch_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1055,7 +1055,7 @@ def init(key, layer_sizes, batch_size):
|
||||
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
|
||||
batch_size = 32
|
||||
|
||||
params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
|
||||
params, batch = init(jax.random.key(0), layer_sizes, batch_size)
|
||||
```
|
||||
|
||||
Compare these examples with the purely [automatic partitioning examples in the
|
||||
|
@ -184,7 +184,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.truncated_normal(5.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
>>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
Array([[ 2.9047365, 5.2338114, 5.29852 ],
|
||||
[-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user