diff --git a/examples/advi.py b/examples/advi.py index 35ee94a58..68092b2cf 100644 --- a/examples/advi.py +++ b/examples/advi.py @@ -78,7 +78,7 @@ if __name__ == "__main__": @jit def objective(params, t): - rng = random.PRNGKey(t) + rng = random.key(t) return -batch_elbo(funnel_log_density, rng, params, num_samples) # Set up figure. @@ -107,7 +107,7 @@ if __name__ == "__main__": # Plot random samples from variational distribution. # Here we clone the rng used in computing the objective # so that we can show exactly the same samples. - rngs = random.split(random.PRNGKey(t), num_samples) + rngs = random.split(random.key(t), num_samples) samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params) ax.plot(samples[:, 0], samples[:, 1], 'b.') diff --git a/examples/differentially_private_sgd.py b/examples/differentially_private_sgd.py index 4777554b1..ca368098d 100644 --- a/examples/differentially_private_sgd.py +++ b/examples/differentially_private_sgd.py @@ -182,7 +182,7 @@ def main(_): num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value) num_batches = num_complete_batches + bool(leftover) - key = random.PRNGKey(_SEED.value) + key = random.key(_SEED.value) def data_stream(): rng = npr.RandomState(_SEED.value) diff --git a/examples/examples_test.py b/examples/examples_test.py index e2ca51d78..b8b4d11e2 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -35,7 +35,7 @@ config.parse_flags_with_absl() def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): - jax_rng = random.PRNGKey(0) + jax_rng = random.key(0) result_shape, params = init_fun(jax_rng, input_shape) result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32")) test_case.assertEqual(result.shape, result_shape) diff --git a/examples/gaussian_process_regression.py b/examples/gaussian_process_regression.py index 070943b72..c42a024d4 100644 --- a/examples/gaussian_process_regression.py +++ b/examples/gaussian_process_regression.py @@ -30,7 +30,7 @@ import matplotlib.pyplot as plt def main(unused_argv): numpts = 7 - key = random.PRNGKey(0) + key = random.key(0) eye = jnp.eye(numpts) def cov_map(cov_func, xs, xs2=None): diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py index a0fe8b996..a7730ab2b 100644 --- a/examples/mnist_classifier.py +++ b/examples/mnist_classifier.py @@ -50,7 +50,7 @@ init_random_params, predict = stax.serial( Dense(10), LogSoftmax) if __name__ == "__main__": - rng = random.PRNGKey(0) + rng = random.key(0) step_size = 0.001 num_epochs = 10 diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index 141be978f..df207afd8 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -87,14 +87,14 @@ if __name__ == "__main__": batch_size = 32 nrow, ncol = 10, 10 # sampled image grid size - test_rng = random.PRNGKey(1) # fixed prng key for evaluation + test_rng = random.key(1) # fixed prng key for evaluation imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png") train_images, _, test_images, _ = datasets.mnist(permute_train=True) num_complete_batches, leftover = divmod(train_images.shape[0], batch_size) num_batches = num_complete_batches + bool(leftover) - enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2)) + enc_init_rng, dec_init_rng = random.split(random.key(2)) _, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28)) _, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10)) init_params = init_encoder_params, init_decoder_params @@ -131,7 +131,7 @@ if __name__ == "__main__": opt_state = opt_init(init_params) for epoch in range(num_epochs): tic = time.time() - opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images) + opt_state = run_epoch(random.key(epoch), opt_state, train_images) test_elbo, sampled_images = evaluate(opt_state, test_images) print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)") plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)