From 92e5f93a290ec8146ab6a667abe54adbf39d61ff Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Thu, 23 May 2019 09:07:44 -0700
Subject: [PATCH] tweak docstrings in mnist examples

---
 examples/fluidsim.py                          | 38 +++++++
 examples/mnist_bench.py                       | 98 +++++++++++++++++++
 examples/mnist_classifier.py                  |  7 +-
 examples/spmd_mnist_classifier_fromscratch.py |  7 +-
 examples/spmd_spatially_sharded_conv_net.py   | 56 +++++++++++
 5 files changed, 202 insertions(+), 4 deletions(-)
 create mode 100644 examples/fluidsim.py
 create mode 100644 examples/mnist_bench.py
 create mode 100644 examples/spmd_spatially_sharded_conv_net.py

diff --git a/examples/fluidsim.py b/examples/fluidsim.py
new file mode 100644
index 000000000..72cc8a165
--- /dev/null
+++ b/examples/fluidsim.py
@@ -0,0 +1,38 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+def project(vx, vy):
+  h = 1. / vx.shape[0]
+  div = -0.5 * h * (np.roll(vx, -1, axis=0) - np.roll(vx, 1, axis=0)
+                    + np.roll(vy, -1, axis=1) - np.roll(vy, 1, axis=1))
+
+  p = np.zeros(vx.shape)
+  for k in range(10):
+    p = (div + np.roll(p, 1, axis=0) + np.roll(p, -1, axis=0)
+         + np.roll(p, 1, axis=1) + np.roll(p, -1, axis=1)) / 4.
+
+  vx = vx - 0.5 * (np.roll(p, -1, axis=0) - np.roll(p, 1, axis=0)) / h
+  vy = vy - 0.5 * (np.roll(p, -1, axis=1) - np.roll(p, 1, axis=1)) / h
+  return vx, vy
+
+def advect(f, vx, vy):
+  rows, cols = f.shape
+  cell_ys, cell_xs = np.meshgrid(np.arange(rows), np.arange(cols))
+  return linear_interpolate(f, cell_xs - vx, cell_ys - vy)
+
+def linear_interpolate(
diff --git a/examples/mnist_bench.py b/examples/mnist_bench.py
new file mode 100644
index 000000000..6e957fa02
--- /dev/null
+++ b/examples/mnist_bench.py
@@ -0,0 +1,98 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A basic MNIST example using JAX together with the mini-libraries stax, for
+neural network building, and optimizers, for first-order stochastic optimization.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+import itertools
+
+import numpy.random as npr
+
+import jax.numpy as np
+from jax.config import config
+from jax import jit, grad, random
+from jax.experimental import optimizers
+from jax.experimental import stax
+from jax.experimental.stax import Dense, Relu, LogSoftmax
+from examples import datasets
+
+
+def loss(params, batch):
+  inputs, targets = batch
+  preds = predict(params, inputs)
+  return -np.mean(preds * targets)
+
+def accuracy(params, batch):
+  inputs, targets = batch
+  target_class = np.argmax(targets, axis=1)
+  predicted_class = np.argmax(predict(params, inputs), axis=1)
+  return np.mean(predicted_class == target_class)
+
+init_random_params, predict = stax.serial(
+    Dense(1024), Relu,
+    Dense(1024), Relu,
+    Dense(10), LogSoftmax)
+
+if __name__ == "__main__":
+  rng = random.PRNGKey(0)
+
+  step_size = 0.001
+  num_epochs = 3
+  batch_size = 128
+  momentum_mass = 0.9
+
+  train_images, train_labels, test_images, test_labels = datasets.mnist()
+  num_train = train_images.shape[0]
+  num_complete_batches, leftover = divmod(num_train, batch_size)
+  num_batches = num_complete_batches + bool(leftover)
+
+  def data_stream():
+    rng = npr.RandomState(0)
+    while True:
+      perm = rng.permutation(num_train)
+      for i in range(num_batches):
+        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
+        yield train_images[batch_idx], train_labels[batch_idx]
+  batches = data_stream()
+
+  opt_init, opt_update = optimizers.momentum(step_size, mass=momentum_mass)
+
+  @jit
+  def update(i, opt_state, batch):
+    params = optimizers.get_params(opt_state)
+    return opt_update(i, grad(loss)(params, batch), opt_state)
+
+  _, init_params = init_random_params(rng, (-1, 28 * 28))
+  opt_state = opt_init(init_params)
+  itercount = itertools.count()
+
+  print("\nStarting training...")
+  for epoch in range(num_epochs):
+    start_time = time.time()
+    for _ in range(num_batches):
+      opt_state = update(next(itercount), opt_state, next(batches))
+    epoch_time = time.time() - start_time
+
+    params = optimizers.get_params(opt_state)
+    train_acc = accuracy(params, (train_images, train_labels))
+    test_acc = accuracy(params, (test_images, test_labels))
+    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
+    print("Training set accuracy {}".format(train_acc))
+    print("Test set accuracy {}".format(test_acc))
diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py
index 16aa8b45a..1f4679039 100644
--- a/examples/mnist_classifier.py
+++ b/examples/mnist_classifier.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""A basic MNIST example using JAX together with the mini-libraries stax, for
-neural network building, and optimizers, for first-order stochastic optimization.
+"""A basic MNIST example using JAX with the mini-libraries stax and optimizers.
+
+The mini-library jax.experimental.stax is for neural network building, and
+the mini-library jax.experimental.optimizers is for first-order stochastic
+optimization.
""" from __future__ import absolute_import diff --git a/examples/spmd_mnist_classifier_fromscratch.py b/examples/spmd_mnist_classifier_fromscratch.py index 9afee1ae9..c85d6aaf4 100644 --- a/examples/spmd_mnist_classifier_fromscratch.py +++ b/examples/spmd_mnist_classifier_fromscratch.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A basic MNIST example using Numpy and JAX. +"""An MNIST example with single-program multiple-data (SPMD) data parallelism. -The primary aim here is simplicity and minimal dependencies. +The aim here is to illustrate how to use JAX's `pmap` to express and execute +SPMD programs for data parallelism along a batch dimension, while also +minimizing dependencies by avoiding the use of higher-level layers and +optimizers libraries. """ from __future__ import absolute_import diff --git a/examples/spmd_spatially_sharded_conv_net.py b/examples/spmd_spatially_sharded_conv_net.py new file mode 100644 index 000000000..cf59cd684 --- /dev/null +++ b/examples/spmd_spatially_sharded_conv_net.py @@ -0,0 +1,56 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""An MNIST example for single-program multiple-data (SPMD) spatial parallelism. + +The aim here is to illustrate how to use JAX's `pmap` to express and execute +SPMD programs for data parallelism along a spatial dimension (rather than a +batch dimension). +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import itertools + +import numpy.random as npr + +import jax.numpy as np +from jax.config import config +from jax import jit, grad, random +from jax.experimental import optimizers +from jax.experimental import stax +from jax.experimental.stax import Relu, LogSoftmax +from examples import datasets + + +def loss(params, batch): + inputs, targets = batch + preds = predict(params, inputs) + return -np.mean(preds * targets) + +def accuracy(params, batch): + inputs, targets = batch + target_class = np.argmax(targets, axis=1) + predicted_class = np.argmax(predict(params, inputs), axis=1) + return np.mean(predicted_class == target_class) + + +init_random_params, predict = stax.serial( + SpmdConv(3, (2, 2), axis_name="x"), Relu, + SpmdConv(3, (2, 2), axis_name="x"), Relu, + SpmdDense(10), LogSoftmax)