2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2018-11-19 07:43:23 -08:00
|
|
|
"""A basic MNIST example using JAX together with the mini-libraries stax, for
|
2019-02-06 11:02:16 -08:00
|
|
|
neural network building, and optimizers, for first-order stochastic optimization.
|
2018-11-17 18:03:33 -08:00
|
|
|
"""
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import time
|
2018-11-19 07:43:23 -08:00
|
|
|
import itertools
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
import numpy.random as npr
|
|
|
|
|
|
|
|
import jax.numpy as np
|
2018-11-29 12:30:34 -08:00
|
|
|
from jax.config import config
|
2019-04-03 12:54:02 +01:00
|
|
|
from jax import jit, grad, random
|
2019-02-06 11:02:16 -08:00
|
|
|
from jax.experimental import optimizers
|
2018-11-19 15:50:10 -08:00
|
|
|
from jax.experimental import stax
|
2018-11-19 17:52:23 -08:00
|
|
|
from jax.experimental.stax import Dense, Relu, LogSoftmax
|
2018-12-12 16:08:59 -08:00
|
|
|
from examples import datasets
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def loss(params, batch):
    """Return the training loss of the network on one batch.

    Args:
      params: network parameters, passed straight through to ``predict``.
      batch: an ``(inputs, targets)`` pair.

    Returns:
      The negated mean, taken over every element, of the elementwise product
      of the network outputs and the targets.
    """
    images, labels = batch
    # The network ends in LogSoftmax, so these are log-probabilities.
    # Presumably `labels` is one-hot, which makes this a (scaled)
    # cross-entropy -- confirm against the dataset loader.
    log_probs = predict(params, images)
    weighted = log_probs * labels
    return -np.mean(weighted)
|
|
|
|
|
|
|
|
def accuracy(params, batch):
    """Return the fraction of examples in ``batch`` classified correctly.

    Args:
      params: network parameters, passed straight through to ``predict``.
      batch: an ``(inputs, targets)`` pair; the class of each row is taken as
        the argmax along axis 1.

    Returns:
      The mean of the per-example "predicted class equals target class" flags.
    """
    inputs, targets = batch
    wanted = np.argmax(targets, axis=1)
    guessed = np.argmax(predict(params, inputs), axis=1)
    hits = guessed == wanted
    return np.mean(hits)
|
|
|
|
|
2018-11-19 07:43:23 -08:00
|
|
|
# Network definition: two 1024-unit ReLU hidden layers followed by a 10-way
# log-softmax output. `stax.serial` returns the pair (parameter initializer,
# forward function).
init_random_params, predict = stax.serial(
    Dense(1024),
    Relu,
    Dense(1024),
    Relu,
    Dense(10),
    LogSoftmax,
)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 12:10:31 -08:00
|
|
|
if __name__ == "__main__":
    # Hyperparameters.
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # Key used to initialize the network parameters.
    rng = random.PRNGKey(0)

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    # A trailing partial batch, if there is one, counts as one more batch.
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        """Yield (images, labels) minibatches forever.

        The training set is reshuffled with a fixed-seed NumPy RandomState at
        the start of every pass; the last batch of a pass may be shorter than
        `batch_size`.
        """
        shuffler = npr.RandomState(0)
        while True:
            order = shuffler.permutation(num_train)
            for start in range(0, num_batches * batch_size, batch_size):
                picks = order[start:start + batch_size]
                yield train_images[picks], train_labels[picks]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(
        step_size, mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        """Take one optimizer step on `batch` and return the new state."""
        grads = grad(loss)(get_params(opt_state), batch)
        return opt_update(i, grads, opt_state)

    # Input shape: unspecified leading dimension, 28 * 28 features per example.
    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    # Global step counter, shared across epochs.
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        # Pull exactly one pass worth of batches from the endless stream.
        for batch in itertools.islice(batches, num_batches):
            opt_state = update(next(itercount), opt_state, batch)
        # NOTE(review): jitted calls may return before the device has finished
        # (JAX asynchronous dispatch), so this wall-clock figure could
        # under-report the real epoch time -- confirm.
        epoch_time = time.time() - start_time

        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))
|