tweak docstrings in mnist examples

Matthew Johnson 2019-05-23 09:07:44 -07:00
parent 5d902298b9
commit 92e5f93a29
5 changed files with 202 additions and 4 deletions

examples/fluidsim.py Normal file

@@ -0,0 +1,38 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np  # assumed import: the np.* calls below need one, but the scraped diff shows none


def project(vx, vy):
  # Project the velocity field to be approximately mass-conserving, using a
  # few iterations of a Jacobi-style relaxation on the pressure field.
  h = 1. / vx.shape[0]
  div = -0.5 * h * (np.roll(vx, -1, axis=0) - np.roll(vx, 1, axis=0) +
                    np.roll(vy, -1, axis=1) - np.roll(vy, 1, axis=1))

  p = np.zeros(vx.shape)
  for k in range(10):
    p = (div + np.roll(p, 1, axis=0) + np.roll(p, -1, axis=0)
             + np.roll(p, 1, axis=1) + np.roll(p, -1, axis=1)) / 4.

  vx = vx - 0.5 * (np.roll(p, -1, axis=0) - np.roll(p, 1, axis=0)) / h
  vy = vy - 0.5 * (np.roll(p, -1, axis=1) - np.roll(p, 1, axis=1)) / h
  return vx, vy

def advect(f, vx, vy):
  # Move field f according to velocities vx, vy using semi-Lagrangian
  # advection: look up where each cell's value came from one step back.
  rows, cols = f.shape
  cell_ys, cell_xs = np.meshgrid(np.arange(rows), np.arange(cols))
  return linear_interpolate(f, cell_xs - vx, cell_ys - vy)

def linear_interpolate(
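
The diff view cuts the file off mid-definition here. For orientation only, a hypothetical completion (not the file's actual code) of a bilinear interpolation with periodic wrapping, consistent with the np.roll stencils above, could look like:

def linear_interpolate(f, x, y):
  # Hypothetical sketch: bilinearly interpolate f at fractional coordinates
  # (x, y), wrapping indices around the edges to match the periodic boundary
  # conditions implied by np.roll in project and advect.
  rows, cols = f.shape
  x0, y0 = np.floor(x).astype(int), np.floor(y).astype(int)
  wx, wy = x - x0, y - y0
  x0, x1 = x0 % cols, (x0 + 1) % cols
  y0, y1 = y0 % rows, (y0 + 1) % rows
  return ((1 - wx) * (1 - wy) * f[y0, x0] + wx * (1 - wy) * f[y0, x1] +
          (1 - wx) * wy * f[y1, x0] + wx * wy * f[y1, x1])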

examples/mnist_bench.py Normal file

@@ -0,0 +1,98 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic MNIST example using JAX together with the mini-libraries stax, for
neural network building, and optimizers, for first-order stochastic optimization.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import itertools

import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
from examples import datasets


def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -np.mean(preds * targets)

def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
  rng = random.PRNGKey(0)

  step_size = 0.001
  num_epochs = 3
  batch_size = 128
  momentum_mass = 0.9

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)

  @jit
  def update(i, opt_state, batch):
    params = optimizers.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  _, init_params = init_random_params(rng, (-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  print("\nStarting training...")
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time

    params = optimizers.get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))


@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""A basic MNIST example using JAX together with the mini-libraries stax, for
-neural network building, and optimizers, for first-order stochastic optimization.
+"""A basic MNIST example using JAX with the mini-libraries stax and optimizers.
+
+The mini-library jax.experimental.stax is for neural network building, and
+the mini-library jax.experimental.optimizers is for first-order stochastic
+optimization.
 """
 from __future__ import absolute_import


@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""A basic MNIST example using Numpy and JAX.
-
-The primary aim here is simplicity and minimal dependencies.
+"""An MNIST example with single-program multiple-data (SPMD) data parallelism.
+
+The aim here is to illustrate how to use JAX's `pmap` to express and execute
+SPMD programs for data parallelism along a batch dimension, while also
+minimizing dependencies by avoiding the use of higher-level layers and
+optimizers libraries.
 """
 from __future__ import absolute_import
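
The pattern this new docstring describes is worth a concrete illustration. The following is a hedged sketch, not the file's code: it assumes the list-of-(weights, bias)-pairs parameter structure used in these examples, substitutes a toy stand-in for the predict/loss pair, and uses pmap with the lax.psum collective to sum per-device gradients along the mapped batch axis.

from functools import partial

import jax.numpy as np
from jax import pmap, grad, lax

step_size = 0.001

def loss(params, batch):
  # Toy differentiable loss standing in for the predict/loss pairs above;
  # params is a list of (weights, bias) pairs, batch is (inputs, targets).
  inputs, targets = batch
  for w, b in params:
    inputs = np.tanh(np.dot(inputs, w) + b)
  return -np.mean(inputs * targets)

@partial(pmap, axis_name="batch")
def spmd_update(params, batch):
  # Each device differentiates the loss on its own shard of the batch, then
  # lax.psum sums the per-device gradients across the mapped "batch" axis.
  grads = grad(loss)(params, batch)
  grads = [(lax.psum(dw, "batch"), lax.psum(db, "batch")) for dw, db in grads]
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

To invoke it, params would be replicated along a leading device axis and the batch reshaped so its leading axis equals jax.device_count(), which is exactly the data parallelism along a batch dimension that the docstring names.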


@@ -0,0 +1,56 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An MNIST example for single-program multiple-data (SPMD) spatial parallelism.
The aim here is to illustrate how to use JAX's `pmap` to express and execute
SPMD programs for data parallelism along a spatial dimension (rather than a
batch dimension).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import itertools

import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Relu, LogSoftmax
from examples import datasets


def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -np.mean(preds * targets)

def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    SpmdConv(3, (2, 2), axis_name="x"), Relu,
    SpmdConv(3, (2, 2), axis_name="x"), Relu,
    SpmdDense(10), LogSoftmax)
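
SpmdConv and SpmdDense are not names from jax.experimental.stax, and they are not imported above, so they presumably come from the portion of this 56-line file that the view truncates. As a rough, hypothetical illustration of the spatial sharding the docstring describes (the helper name and shapes below are assumed, not taken from the file), each image can be split into horizontal strips, one per device, so that pmap maps over a spatial axis rather than the batch axis:

import jax.numpy as np

def spatial_shard(images, num_devices):
  # Hypothetical helper: split each image into horizontal strips, one strip
  # per device, yielding shape (num_devices, batch, height // num_devices,
  # width) so pmap can map over the leading spatial-shard axis.
  batch, height, width = images.shape
  assert height % num_devices == 0
  strips = images.reshape(batch, num_devices, height // num_devices, width)
  return np.swapaxes(strips, 0, 1)

Under such a sharding, a layer like SpmdConv would have to exchange boundary rows with neighboring devices (for example via the lax.ppermute collective) before convolving, which is plausibly what its axis_name="x" argument is for.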