update top-level examples to use new-style typed keys

This commit is contained in:
Roy Frostig 2023-08-17 17:33:43 -07:00
parent 753a6b2b0c
commit 78fd4f1664
6 changed files with 9 additions and 9 deletions

View File

@ -78,7 +78,7 @@ if __name__ == "__main__":
@jit
def objective(params, t):
rng = random.PRNGKey(t)
rng = random.key(t)
return -batch_elbo(funnel_log_density, rng, params, num_samples)
# Set up figure.
@ -107,7 +107,7 @@ if __name__ == "__main__":
# Plot random samples from variational distribution.
# Here we clone the rng used in computing the objective
# so that we can show exactly the same samples.
rngs = random.split(random.PRNGKey(t), num_samples)
rngs = random.split(random.key(t), num_samples)
samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
ax.plot(samples[:, 0], samples[:, 1], 'b.')

View File

@ -182,7 +182,7 @@ def main(_):
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
num_batches = num_complete_batches + bool(leftover)
key = random.PRNGKey(_SEED.value)
key = random.key(_SEED.value)
def data_stream():
rng = npr.RandomState(_SEED.value)

View File

@ -35,7 +35,7 @@ config.parse_flags_with_absl()
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
jax_rng = random.PRNGKey(0)
jax_rng = random.key(0)
result_shape, params = init_fun(jax_rng, input_shape)
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
test_case.assertEqual(result.shape, result_shape)

View File

@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
def main(unused_argv):
numpts = 7
key = random.PRNGKey(0)
key = random.key(0)
eye = jnp.eye(numpts)
def cov_map(cov_func, xs, xs2=None):

View File

@ -50,7 +50,7 @@ init_random_params, predict = stax.serial(
Dense(10), LogSoftmax)
if __name__ == "__main__":
rng = random.PRNGKey(0)
rng = random.key(0)
step_size = 0.001
num_epochs = 10

View File

@ -87,14 +87,14 @@ if __name__ == "__main__":
batch_size = 32
nrow, ncol = 10, 10 # sampled image grid size
test_rng = random.PRNGKey(1) # fixed prng key for evaluation
test_rng = random.key(1) # fixed prng key for evaluation
imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")
train_images, _, test_images, _ = datasets.mnist(permute_train=True)
num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
num_batches = num_complete_batches + bool(leftover)
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
enc_init_rng, dec_init_rng = random.split(random.key(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params
@ -131,7 +131,7 @@ if __name__ == "__main__":
opt_state = opt_init(init_params)
for epoch in range(num_epochs):
tic = time.time()
opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
opt_state = run_epoch(random.key(epoch), opt_state, train_images)
test_elbo, sampled_images = evaluate(opt_state, test_images)
print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)