diff --git a/README.md b/README.md
index b19d7b9ff..8f50aa42a 100644
--- a/README.md
+++ b/README.md
@@ -273,7 +273,7 @@ from jax import random, pmap
 import jax.numpy as jnp

 # Create 8 random 5000 x 6000 matrices, one per GPU
-keys = random.split(random.PRNGKey(0), 8)
+keys = random.split(random.key(0), 8)
 mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

 # Run a local matmul on each device in parallel (no data transfer)
diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py
index c68dab85d..710ffb6d7 100644
--- a/benchmarks/api_benchmark.py
+++ b/benchmarks/api_benchmark.py
@@ -839,7 +839,7 @@ def serial_dot_products(state):
       out = out + y * x[0]
     return out

-  x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
+  x = jax.random.normal(jax.random.key(0), (2, 2))
   f(x).block_until_ready()  # compile
   while state:
     f(x).block_until_ready()
@@ -929,7 +929,7 @@ def jit_add_chain(state):
   def g(x, y):
     return lax.add(x, y)

-  x = jax.random.normal(jax.random.PRNGKey(0), (2, 2))
+  x = jax.random.normal(jax.random.key(0), (2, 2))
   while state:
     @jax.jit
     def f(x):
diff --git a/benchmarks/sparse_benchmark.py b/benchmarks/sparse_benchmark.py
index 65550b9cf..d6328881d 100644
--- a/benchmarks/sparse_benchmark.py
+++ b/benchmarks/sparse_benchmark.py
@@ -109,7 +109,7 @@ def sparse_bcoo_todense_compile(state):
 def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False):
   shape = (2000, 2000)
   nse = 10000
-  key = jax.random.PRNGKey(1701)
+  key = jax.random.key(1701)
   mat = sparse.random_bcoo(
       key,
       nse=nse,
diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb
index 1278bd01c..279aef3e9 100644
--- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb
+++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb
@@ -38,7 +38,7 @@
    },
    "outputs": [],
    "source": [
-    "key = random.PRNGKey(0)\n",
+    "key = random.key(0)\n",
     "key, subkey = random.split(key)\n",
     "x = random.normal(key, (5000, 5000))\n",
     "\n",
@@ -189,7 +189,7 @@
    },
    "outputs": [],
    "source": [
-    "key = random.PRNGKey(0)\n",
+    "key = random.key(0)\n",
     "x = random.normal(key, ())\n",
     "\n",
     "print(grad(f)(x))\n",
@@ -261,7 +261,7 @@
    },
    "outputs": [],
    "source": [
-    "key = random.PRNGKey(0)\n",
+    "key = random.key(0)\n",
     "x = random.normal(key, (5000, 5000))"
    ]
   },
diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb
index 6a6993f44..9acb1971c 100644
--- a/cloud_tpu_colabs/JAX_demo.ipynb
+++ b/cloud_tpu_colabs/JAX_demo.ipynb
@@ -27,7 +27,7 @@
     "import jax.numpy as jnp\n",
     "from jax import random\n",
     "\n",
-    "key = random.PRNGKey(0)"
+    "key = random.key(0)"
    ]
   },
   {
@@ -194,7 +194,7 @@
    },
    "outputs": [],
    "source": [
-    "key = random.PRNGKey(0)\n",
+    "key = random.key(0)\n",
     "x = random.normal(key, ())\n",
     "\n",
     "print(grad(f)(x))\n",
@@ -246,7 +246,7 @@
     "\n",
     "layer_sizes = [5, 2, 3]\n",
     "\n",
-    "key = random.PRNGKey(0)\n",
+    "key = random.key(0)\n",
     "key, *keys = random.split(key, len(layer_sizes))\n",
     "params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
     "\n",
@@ -351,7 +351,7 @@
    },
    "outputs": [],
    "source": [
-    "key = random.PRNGKey(0)\n",
+    "key = random.key(0)\n",
     "x = random.normal(key, (5000, 5000))"
    ]
   },
@@ -754,7 +754,7 @@
    },
    "outputs": [],
    "source": [
-    "keys = random.split(random.PRNGKey(0), 8)\n",
+    "keys = random.split(random.key(0), 8)\n",
     "mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n",
     "result = pmap(jnp.dot)(mats, mats)\n",
     "print(pmap(jnp.mean)(result))"
diff --git a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb
index 1777d3d1e..84abf8658 100644
--- a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb
+++ b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb
@@ -366,7 +366,7 @@
     "\n",
     "# set some initial conditions for each replicate\n",
     "ys = jnp.zeros((N_dev, N, 3))\n",
-    "state0 = jr.uniform(jr.PRNGKey(1), \n",
+    "state0 = jr.uniform(jr.key(1), \n",
     "                    minval=-1., maxval=1.,\n",
     "                    shape=(N_dev, 3))\n",
     "state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n",
diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb
index 4f4ba8c16..981f0a9e8 100644
--- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb
+++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb
@@ -263,7 +263,7 @@
     "from jax import random\n",
     "\n",
     "# create 8 random keys\n",
-    "keys = random.split(random.PRNGKey(0), 8)\n",
+    "keys = random.split(random.key(0), 8)\n",
     "# create a 5000 x 6000 matrix on each device by mapping over keys\n",
     "mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)\n",
     "# the stack of matrices is represented logically as a single array\n",
diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md
index 28a817154..8490bd489 100644
--- a/docs/Custom_Operation_for_GPUs.md
+++ b/docs/Custom_Operation_for_GPUs.md
@@ -306,7 +306,7 @@ per_core_batch_size=4
 seq_len=512
 emb_dim=512
 x = jax.random.normal(
-    jax.random.PRNGKey(0),
+    jax.random.key(0),
     shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
     dtype=jnp.bfloat16,
 )
diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py
index 4c0b4b6f7..31a00c490 100644
--- a/docs/Custom_Operation_for_GPUs.py
+++ b/docs/Custom_Operation_for_GPUs.py
@@ -479,7 +479,7 @@ seq_len = 512
 emb_dim = 512
 assert jax.local_device_count() > 1, "Only 1 GPU, the example work, but it is this really what you want?"
 x = jax.random.normal(
-    jax.random.PRNGKey(0),
+    jax.random.key(0),
     shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
     dtype=jnp.float16,
 )
diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md
index da95f96d8..d58a45d1d 100644
--- a/docs/_tutorials/advanced-autodiff.md
+++ b/docs/_tutorials/advanced-autodiff.md
@@ -590,7 +590,7 @@ def vmap_mjp(f, x, M):
     outs, = vmap(vjp_fun)(M)
     return outs

-key = random.PRNGKey(0)
+key = random.key(0)
 num_covecs = 128
 U = random.normal(key, (num_covecs,) + y.shape)

@@ -714,7 +714,7 @@ Here's a check:

 ```{code-cell}
 def check(seed):
-  key = random.PRNGKey(seed)
+  key = random.key(seed)

   # random coeffs for u and v
   key, subkey = random.split(key)
@@ -768,7 +768,7 @@ Here's a check of the VJP rules:

 ```{code-cell}
 def check(seed):
-  key = random.PRNGKey(seed)
+  key = random.key(seed)

   # random coeffs for u and v
   key, subkey = random.split(key)
diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb
index 919690d23..6fbe67f4d 100644
--- a/docs/notebooks/shard_map.ipynb
+++ b/docs/notebooks/shard_map.ipynb
@@ -1509,7 +1509,7 @@
     "layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n",
     "batch_size = 32\n",
     "\n",
-    "params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)"
+    "params, batch = init(jax.random.key(0), layer_sizes, batch_size)"
    ]
   },
   {
diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md
index 6f9dfbb65..389b62c8a 100644
--- a/docs/notebooks/shard_map.md
+++ b/docs/notebooks/shard_map.md
@@ -1055,7 +1055,7 @@ def init(key, layer_sizes, batch_size):
 layer_sizes = [784, 128, 128, 128, 128, 128, 8]
 batch_size = 32

-params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
+params, batch = init(jax.random.key(0), layer_sizes, batch_size)
 ```

 Compare these examples with the purely [automatic partitioning examples in the
diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py
index 7d228e4be..eb1bb1609 100644
--- a/jax/_src/nn/initializers.py
+++ b/jax/_src/nn/initializers.py
@@ -184,7 +184,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2,

   >>> import jax, jax.numpy as jnp
   >>> initializer = jax.nn.initializers.truncated_normal(5.0)
-  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  >>> initializer(jax.random.key(42), (2, 3), jnp.float32)  # doctest: +SKIP
   Array([[ 2.9047365,  5.2338114,  5.29852  ],
          [-3.836303 , -4.192359 ,  0.6022964]], dtype=float32)
   """
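Note (not part of the patch): the change above mechanically swaps the legacy `jax.random.PRNGKey` constructor for `jax.random.key`, which returns a scalar "typed" key array instead of a raw `uint32` array of shape `(2,)`. The minimal sketch below, assuming the default PRNG implementation, illustrates why the surrounding calls (`random.split`, `random.normal`, etc.) are otherwise untouched; `jax.random.key_data` is only needed when the raw key bits are required.

```python
# Minimal sketch (not from the patch): old-style vs. new-style JAX PRNG keys.
import jax

legacy_key = jax.random.PRNGKey(0)  # raw uint32 array of shape (2,)
typed_key = jax.random.key(0)       # scalar typed key array, shape ()

# Both drive the same generator, so downstream calls are unchanged.
x_old = jax.random.normal(legacy_key, (3,))
x_new = jax.random.normal(typed_key, (3,))
assert (x_old == x_new).all()

# The raw key bits can still be recovered from a typed key when needed.
assert (jax.random.key_data(typed_key) == legacy_key).all()
```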