mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
import more examples in examples_test, fix resulting errors
This commit is contained in:
parent
91fe4a1bcc
commit
bec24999a8
@ -31,7 +31,7 @@ from jax import jit, grad
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
||||
import datasets
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
@ -94,4 +94,3 @@ if __name__ == "__main__":
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
@ -29,7 +29,7 @@ from jax.api import jit, grad
|
||||
from jax.config import config
|
||||
from jax.scipy.misc import logsumexp
|
||||
import jax.numpy as np
|
||||
import datasets
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
|
||||
@ -93,4 +93,3 @@ if __name__ == "__main__":
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
@ -33,7 +33,7 @@ from jax import jit, grad, lax, random
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
|
||||
import datasets
|
||||
from examples import datasets
|
||||
|
||||
|
||||
def gaussian_kl(mu, sigmasq):
|
||||
@ -139,4 +139,3 @@ if __name__ == "__main__":
|
||||
test_elbo, sampled_images = evaluate(opt_state, test_images)
|
||||
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
|
||||
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
|
||||
|
||||
|
@ -27,6 +27,9 @@ import numpy as onp
|
||||
from jax import test_util as jtu
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from examples import mnist_classifier
|
||||
from examples import mnist_classifier_fromscratch
|
||||
from examples import mnist_vae
|
||||
from examples import resnet50
|
||||
sys.path.pop()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user