mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
remove absl from examples, fix import statements
This commit is contained in:
parent
7cf6babc78
commit
51fc713089
@ -1,34 +0,0 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from absl import app
|
||||
|
||||
import IPython
|
||||
import numpy as onp
|
||||
|
||||
import jax
|
||||
import jax.numpy as np
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax import jit, grad, vmap, jacfwd, jacrev, hessian
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
IPython.embed(user_ns=dict(globals(), **locals()))
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -23,15 +23,14 @@ from __future__ import print_function
|
||||
import time
|
||||
import itertools
|
||||
|
||||
from absl import app
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad
|
||||
from jax.experimental import minmax
|
||||
from jax.examples import datasets
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
||||
import datasets
|
||||
|
||||
|
||||
def loss(params, batch):
|
||||
@ -50,7 +49,7 @@ init_random_params, predict = stax.serial(
|
||||
Dense(1024), Relu,
|
||||
Dense(10), LogSoftmax)
|
||||
|
||||
def main(unused_argv):
|
||||
if __name__ == "__main__":
|
||||
step_size = 0.001
|
||||
num_epochs = 10
|
||||
batch_size = 32
|
||||
@ -93,7 +92,3 @@ def main(unused_argv):
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -23,13 +23,12 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
import numpy.random as npr
|
||||
|
||||
from jax.api import jit, grad
|
||||
from jax.examples import datasets
|
||||
from jax.scipy.misc import logsumexp
|
||||
import jax.numpy as np
|
||||
import datasets
|
||||
|
||||
|
||||
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
|
||||
@ -54,7 +53,7 @@ def accuracy(params, batch):
|
||||
return np.mean(predicted_class == target_class)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
if __name__ == "__main__":
|
||||
layer_sizes = [784, 1024, 1024, 10] # TODO(mattjj): revise to standard arch
|
||||
param_scale = 0.1
|
||||
step_size = 0.001
|
||||
@ -93,7 +92,3 @@ def main(unused_argv):
|
||||
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
|
||||
print("Training set accuracy {}".format(train_acc))
|
||||
print("Test set accuracy {}".format(test_acc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -25,15 +25,14 @@ from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import jax.numpy as np
|
||||
from jax import jit, grad, lax, random
|
||||
from jax.examples import datasets
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
|
||||
import datasets
|
||||
|
||||
|
||||
def gaussian_kl(mu, sigmasq):
|
||||
@ -84,7 +83,7 @@ decoder_init, decode = stax.serial(
|
||||
)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
if __name__ == "__main__":
|
||||
step_size = 0.001
|
||||
num_epochs = 100
|
||||
batch_size = 32
|
||||
@ -138,7 +137,3 @@ def main(unused_argv):
|
||||
test_elbo, images = evaluate(opt_state, test_images)
|
||||
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
|
||||
plt.imsave(imfile.format(epoch), images, cmap=plt.cm.gray)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
@ -21,8 +21,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
@ -85,9 +83,7 @@ def ResNet50(num_classes):
|
||||
AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_size = 8
|
||||
num_classes = 1001
|
||||
input_shape = (224, 224, 3, batch_size)
|
||||
@ -128,7 +124,3 @@ def main(argv):
|
||||
for i in xrange(num_steps):
|
||||
opt_state = update(i, opt_state, next(batches))
|
||||
trained_params = minmax.get_params(opt_state)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
|
Loading…
x
Reference in New Issue
Block a user