From bec24999a82b1eaa4b508058963ac58bad152a5f Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 12 Dec 2018 16:08:59 -0800 Subject: [PATCH] import more examples in examples_test, fix resulting errors --- examples/mnist_classifier.py | 3 +-- examples/mnist_classifier_fromscratch.py | 3 +-- examples/mnist_vae.py | 3 +-- tests/examples_test.py | 3 +++ 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py index 6336ac675..be063a29f 100644 --- a/examples/mnist_classifier.py +++ b/examples/mnist_classifier.py @@ -31,7 +31,7 @@ from jax import jit, grad from jax.experimental import minmax from jax.experimental import stax from jax.experimental.stax import Dense, Relu, LogSoftmax -import datasets +from examples import datasets def loss(params, batch): @@ -94,4 +94,3 @@ if __name__ == "__main__": print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) print("Training set accuracy {}".format(train_acc)) print("Test set accuracy {}".format(test_acc)) - diff --git a/examples/mnist_classifier_fromscratch.py b/examples/mnist_classifier_fromscratch.py index b0dfbb5cd..d9b1b9a98 100644 --- a/examples/mnist_classifier_fromscratch.py +++ b/examples/mnist_classifier_fromscratch.py @@ -29,7 +29,7 @@ from jax.api import jit, grad from jax.config import config from jax.scipy.misc import logsumexp import jax.numpy as np -import datasets +from examples import datasets def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): @@ -93,4 +93,3 @@ if __name__ == "__main__": print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) print("Training set accuracy {}".format(train_acc)) print("Test set accuracy {}".format(test_acc)) - diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index bda0a2bc2..9f0449bcc 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -33,7 +33,7 @@ from jax import jit, grad, lax, random from jax.experimental import minmax from jax.experimental import stax from jax.experimental.stax import Dense, FanOut, Relu, Softplus -import datasets +from examples import datasets def gaussian_kl(mu, sigmasq): @@ -139,4 +139,3 @@ if __name__ == "__main__": test_elbo, sampled_images = evaluate(opt_state, test_images) print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic)) plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray) - diff --git a/tests/examples_test.py b/tests/examples_test.py index e4614fd74..50093f4e4 100644 --- a/tests/examples_test.py +++ b/tests/examples_test.py @@ -27,6 +27,9 @@ import numpy as onp from jax import test_util as jtu sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from examples import mnist_classifier +from examples import mnist_classifier_fromscratch +from examples import mnist_vae from examples import resnet50 sys.path.pop()