mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 06:06:07 +00:00
update top-level examples to use new-style typed keys
This commit is contained in:
parent
753a6b2b0c
commit
78fd4f1664
@ -78,7 +78,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
@jit
|
@jit
|
||||||
def objective(params, t):
|
def objective(params, t):
|
||||||
rng = random.PRNGKey(t)
|
rng = random.key(t)
|
||||||
return -batch_elbo(funnel_log_density, rng, params, num_samples)
|
return -batch_elbo(funnel_log_density, rng, params, num_samples)
|
||||||
|
|
||||||
# Set up figure.
|
# Set up figure.
|
||||||
@ -107,7 +107,7 @@ if __name__ == "__main__":
|
|||||||
# Plot random samples from variational distribution.
|
# Plot random samples from variational distribution.
|
||||||
# Here we clone the rng used in computing the objective
|
# Here we clone the rng used in computing the objective
|
||||||
# so that we can show exactly the same samples.
|
# so that we can show exactly the same samples.
|
||||||
rngs = random.split(random.PRNGKey(t), num_samples)
|
rngs = random.split(random.key(t), num_samples)
|
||||||
samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
|
samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
|
||||||
ax.plot(samples[:, 0], samples[:, 1], 'b.')
|
ax.plot(samples[:, 0], samples[:, 1], 'b.')
|
||||||
|
|
||||||
|
@ -182,7 +182,7 @@ def main(_):
|
|||||||
num_train = train_images.shape[0]
|
num_train = train_images.shape[0]
|
||||||
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
|
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
|
||||||
num_batches = num_complete_batches + bool(leftover)
|
num_batches = num_complete_batches + bool(leftover)
|
||||||
key = random.PRNGKey(_SEED.value)
|
key = random.key(_SEED.value)
|
||||||
|
|
||||||
def data_stream():
|
def data_stream():
|
||||||
rng = npr.RandomState(_SEED.value)
|
rng = npr.RandomState(_SEED.value)
|
||||||
|
@ -35,7 +35,7 @@ config.parse_flags_with_absl()
|
|||||||
|
|
||||||
|
|
||||||
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
|
||||||
jax_rng = random.PRNGKey(0)
|
jax_rng = random.key(0)
|
||||||
result_shape, params = init_fun(jax_rng, input_shape)
|
result_shape, params = init_fun(jax_rng, input_shape)
|
||||||
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
|
result = apply_fun(params, test_case.rng.normal(size=input_shape).astype("float32"))
|
||||||
test_case.assertEqual(result.shape, result_shape)
|
test_case.assertEqual(result.shape, result_shape)
|
||||||
|
@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
|
|||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
|
|
||||||
numpts = 7
|
numpts = 7
|
||||||
key = random.PRNGKey(0)
|
key = random.key(0)
|
||||||
eye = jnp.eye(numpts)
|
eye = jnp.eye(numpts)
|
||||||
|
|
||||||
def cov_map(cov_func, xs, xs2=None):
|
def cov_map(cov_func, xs, xs2=None):
|
||||||
|
@ -50,7 +50,7 @@ init_random_params, predict = stax.serial(
|
|||||||
Dense(10), LogSoftmax)
|
Dense(10), LogSoftmax)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
rng = random.PRNGKey(0)
|
rng = random.key(0)
|
||||||
|
|
||||||
step_size = 0.001
|
step_size = 0.001
|
||||||
num_epochs = 10
|
num_epochs = 10
|
||||||
|
@ -87,14 +87,14 @@ if __name__ == "__main__":
|
|||||||
batch_size = 32
|
batch_size = 32
|
||||||
nrow, ncol = 10, 10 # sampled image grid size
|
nrow, ncol = 10, 10 # sampled image grid size
|
||||||
|
|
||||||
test_rng = random.PRNGKey(1) # fixed prng key for evaluation
|
test_rng = random.key(1) # fixed prng key for evaluation
|
||||||
imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")
|
imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")
|
||||||
|
|
||||||
train_images, _, test_images, _ = datasets.mnist(permute_train=True)
|
train_images, _, test_images, _ = datasets.mnist(permute_train=True)
|
||||||
num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
|
num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
|
||||||
num_batches = num_complete_batches + bool(leftover)
|
num_batches = num_complete_batches + bool(leftover)
|
||||||
|
|
||||||
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
|
enc_init_rng, dec_init_rng = random.split(random.key(2))
|
||||||
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
|
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
|
||||||
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
|
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
|
||||||
init_params = init_encoder_params, init_decoder_params
|
init_params = init_encoder_params, init_decoder_params
|
||||||
@ -131,7 +131,7 @@ if __name__ == "__main__":
|
|||||||
opt_state = opt_init(init_params)
|
opt_state = opt_init(init_params)
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
|
opt_state = run_epoch(random.key(epoch), opt_state, train_images)
|
||||||
test_elbo, sampled_images = evaluate(opt_state, test_images)
|
test_elbo, sampled_images = evaluate(opt_state, test_images)
|
||||||
print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
|
print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
|
||||||
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
|
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user