source sync fixups

This commit is contained in:
Roy Frostig 2018-11-21 20:26:59 -08:00
parent f1bb77dafb
commit 7a313bf622
4 changed files with 1 additions and 24 deletions

View File

@ -23,7 +23,6 @@ from __future__ import print_function
import time
import itertools
from absl import app
import numpy.random as npr
import jax.numpy as np
@ -93,7 +92,3 @@ if __name__ == "__main__":
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
app.run(main)

View File

@ -23,7 +23,6 @@ from __future__ import print_function
import time
from absl import app
import numpy.random as npr
from jax.api import jit, grad
@ -93,7 +92,3 @@ if __name__ == "__main__":
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
app.run(main)

View File

@ -25,7 +25,6 @@ from __future__ import print_function
import os
import time
from absl import app
import matplotlib.pyplot as plt
import jax.numpy as np
@ -138,7 +137,3 @@ if __name__ == "__main__":
test_elbo, sampled_images = evaluate(opt_state, test_images)
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
if __name__ == "__main__":
app.run(main)

View File

@ -21,8 +21,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import numpy.random as npr
import jax.numpy as np
@ -85,9 +83,7 @@ def ResNet50(num_classes):
AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
def main(argv):
del argv # Unused.
if __name__ == "__main__":
batch_size = 8
num_classes = 1001
input_shape = (224, 224, 3, batch_size)
@ -128,7 +124,3 @@ def main(argv):
for i in xrange(num_steps):
opt_state = update(i, opt_state, next(batches))
trained_params = minmax.get_params(opt_state)
if __name__ == '__main__':
app.run(main)