mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
update top-level examples to use new-style typed keys
This commit is contained in:
parent
753a6b2b0c
commit
78fd4f1664
@ -78,7 +78,7 @@ if __name__ == "__main__":
|
||||
|
||||
@jit
|
||||
def objective(params, t):
|
||||
rng = random.PRNGKey(t)
|
||||
rng = random.key(t)
|
||||
return -batch_elbo(funnel_log_density, rng, params, num_samples)
|
||||
|
||||
# Set up figure.
|
||||
@ -107,7 +107,7 @@ if __name__ == "__main__":
|
||||
# Plot random samples from variational distribution.
|
||||
# Here we clone the rng used in computing the objective
|
||||
# so that we can show exactly the same samples.
|
||||
rngs = random.split(random.PRNGKey(t), num_samples)
|
||||
rngs = random.split(random.key(t), num_samples)
|
||||
samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
|
||||
ax.plot(samples[:, 0], samples[:, 1], 'b.')
|
||||
|
||||
|
@ -182,7 +182,7 @@ def main(_):
|
||||
num_train = train_images.shape[0]
|
||||
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
|
||||
num_batches = num_complete_batches + bool(leftover)
|
||||
key = random.PRNGKey(_SEED.value)
|
||||
key = random.key(_SEED.value)
|
||||
|
||||
def data_stream():
|
||||
rng = npr.RandomState(_SEED.value)
|
||||
|
@ -35,7 +35,7 @@ config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
||||
jax_rng = random.PRNGKey(0)
|
||||
jax_rng = random.key(0)
|
||||
result_shape, params = init_fun(jax_rng, input_shape)
|
||||
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
|
||||
test_case.assertEqual(result.shape, result_shape)
|
||||
|
@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
|
||||
def main(unused_argv):
|
||||
|
||||
numpts = 7
|
||||
key = random.PRNGKey(0)
|
||||
key = random.key(0)
|
||||
eye = jnp.eye(numpts)
|
||||
|
||||
def cov_map(cov_func, xs, xs2=None):
|
||||
|
@ -50,7 +50,7 @@ init_random_params, predict = stax.serial(
|
||||
Dense(10), LogSoftmax)
|
||||
|
||||
if __name__ == "__main__":
|
||||
rng = random.PRNGKey(0)
|
||||
rng = random.key(0)
|
||||
|
||||
step_size = 0.001
|
||||
num_epochs = 10
|
||||
|
@ -87,14 +87,14 @@ if __name__ == "__main__":
|
||||
batch_size = 32
|
||||
nrow, ncol = 10, 10 # sampled image grid size
|
||||
|
||||
test_rng = random.PRNGKey(1) # fixed prng key for evaluation
|
||||
test_rng = random.key(1) # fixed prng key for evaluation
|
||||
imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")
|
||||
|
||||
train_images, _, test_images, _ = datasets.mnist(permute_train=True)
|
||||
num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
|
||||
num_batches = num_complete_batches + bool(leftover)
|
||||
|
||||
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
|
||||
enc_init_rng, dec_init_rng = random.split(random.key(2))
|
||||
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
|
||||
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
|
||||
init_params = init_encoder_params, init_decoder_params
|
||||
@ -131,7 +131,7 @@ if __name__ == "__main__":
|
||||
opt_state = opt_init(init_params)
|
||||
for epoch in range(num_epochs):
|
||||
tic = time.time()
|
||||
opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
|
||||
opt_state = run_epoch(random.key(epoch), opt_state, train_images)
|
||||
test_elbo, sampled_images = evaluate(opt_state, test_images)
|
||||
print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
|
||||
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
|
||||
|
Loading…
x
Reference in New Issue
Block a user