mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
source sync fixups
This commit is contained in:
parent
f1bb77dafb
commit
7a313bf622
@ -23,7 +23,6 @@ from __future__ import print_function
|
||||
import time
|
||||
import itertools
|
||||
|
||||
from absl import app
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
@ -93,7 +92,3 @@ if __name__ == "__main__":
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -23,7 +23,6 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
import numpy.random as npr
|
||||
|
||||
from jax.api import jit, grad
|
||||
@ -93,7 +92,3 @@ if __name__ == "__main__":
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -25,7 +25,6 @@ from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import jax.numpy as np
|
||||
@ -138,7 +137,3 @@ if __name__ == "__main__":
|
||||
test_elbo, sampled_images = evaluate(opt_state, test_images)
|
||||
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
|
||||
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -21,8 +21,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
@ -85,9 +83,7 @@ def ResNet50(num_classes):
|
||||
AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_size = 8
|
||||
num_classes = 1001
|
||||
input_shape = (224, 224, 3, batch_size)
|
||||
@ -128,7 +124,3 @@ def main(argv):
|
||||
for i in xrange(num_steps):
|
||||
opt_state = update(i, opt_state, next(batches))
|
||||
trained_params = minmax.get_params(opt_state)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
|
Loading…
x
Reference in New Issue
Block a user