remove absl from examples, fix import statements

This commit is contained in:
Matthew Johnson 2018-11-21 12:10:31 -08:00
parent 7cf6babc78
commit 51fc713089
5 changed files with 7 additions and 64 deletions

View File

@ -1,34 +0,0 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from absl import app
import IPython
import numpy as onp
import jax
import jax.numpy as np
from jax import lax
from jax import random
from jax import jit, grad, vmap, jacfwd, jacrev, hessian
def main(unused_argv):
IPython.embed(user_ns=dict(globals(), **locals()))
if __name__ == "__main__":
app.run(main)

View File

@ -23,15 +23,14 @@ from __future__ import print_function
import time
import itertools
from absl import app
import numpy.random as npr
import jax.numpy as np
from jax import jit, grad
from jax.experimental import minmax
from jax.examples import datasets
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets
def loss(params, batch):
@ -50,7 +49,7 @@ init_random_params, predict = stax.serial(
Dense(1024), Relu,
Dense(10), LogSoftmax)
def main(unused_argv):
if __name__ == "__main__":
step_size = 0.001
num_epochs = 10
batch_size = 32
@ -93,7 +92,3 @@ def main(unused_argv):
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
app.run(main)

View File

@ -23,13 +23,12 @@ from __future__ import print_function
import time
from absl import app
import numpy.random as npr
from jax.api import jit, grad
from jax.examples import datasets
from jax.scipy.misc import logsumexp
import jax.numpy as np
import datasets
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
@ -54,7 +53,7 @@ def accuracy(params, batch):
return np.mean(predicted_class == target_class)
def main(unused_argv):
if __name__ == "__main__":
layer_sizes = [784, 1024, 1024, 10] # TODO(mattjj): revise to standard arch
param_scale = 0.1
step_size = 0.001
@ -93,7 +92,3 @@ def main(unused_argv):
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
app.run(main)

View File

@ -25,15 +25,14 @@ from __future__ import print_function
import os
import time
from absl import app
import matplotlib.pyplot as plt
import jax.numpy as np
from jax import jit, grad, lax, random
from jax.examples import datasets
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
import datasets
def gaussian_kl(mu, sigmasq):
@ -84,7 +83,7 @@ decoder_init, decode = stax.serial(
)
def main(unused_argv):
if __name__ == "__main__":
step_size = 0.001
num_epochs = 100
batch_size = 32
@ -138,7 +137,3 @@ def main(unused_argv):
test_elbo, images = evaluate(opt_state, test_images)
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
plt.imsave(imfile.format(epoch), images, cmap=plt.cm.gray)
if __name__ == "__main__":
app.run(main)

View File

@ -21,8 +21,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import numpy.random as npr
import jax.numpy as np
@ -85,9 +83,7 @@ def ResNet50(num_classes):
AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
def main(argv):
del argv # Unused.
if __name__ == "__main__":
batch_size = 8
num_classes = 1001
input_shape = (224, 224, 3, batch_size)
@ -128,7 +124,3 @@ def main(argv):
for i in xrange(num_steps):
opt_state = update(i, opt_state, next(batches))
trained_params = minmax.get_params(opt_state)
if __name__ == '__main__':
app.run(main)