import more examples in examples_test, fix resulting errors

This commit is contained in:
Roy Frostig 2018-12-12 16:08:59 -08:00
parent 91fe4a1bcc
commit bec24999a8
4 changed files with 6 additions and 6 deletions

View File

@ -31,7 +31,7 @@ from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets
from examples import datasets
def loss(params, batch):
@ -94,4 +94,3 @@ if __name__ == "__main__":
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))

View File

@ -29,7 +29,7 @@ from jax.api import jit, grad
from jax.config import config
from jax.scipy.misc import logsumexp
import jax.numpy as np
import datasets
from examples import datasets
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
@ -93,4 +93,3 @@ if __name__ == "__main__":
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))

View File

@ -33,7 +33,7 @@ from jax import jit, grad, lax, random
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
import datasets
from examples import datasets
def gaussian_kl(mu, sigmasq):
@ -139,4 +139,3 @@ if __name__ == "__main__":
test_elbo, sampled_images = evaluate(opt_state, test_images)
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)

View File

@ -27,6 +27,9 @@ import numpy as onp
from jax import test_util as jtu
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from examples import mnist_classifier
from examples import mnist_classifier_fromscratch
from examples import mnist_vae
from examples import resnet50
sys.path.pop()