From 15da530b033819b2bc2ef762f71a7cb6e56c7867 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 8 Mar 2019 09:59:03 -0800
Subject: [PATCH] add spmd mnist example

---
 examples/mnist_classifier_fromscratch.py      |   2 +-
 examples/spmd_mnist_classifier_fromscratch.py | 125 ++++++++++++++++++
 jax/api.py                                    |   4 +-
 spmd_toy.py                                   |  72 ----------
 4 files changed, 129 insertions(+), 74 deletions(-)
 create mode 100644 examples/spmd_mnist_classifier_fromscratch.py
 delete mode 100644 spmd_toy.py

diff --git a/examples/mnist_classifier_fromscratch.py b/examples/mnist_classifier_fromscratch.py
index f79861a58..57bee8401 100644
--- a/examples/mnist_classifier_fromscratch.py
+++ b/examples/mnist_classifier_fromscratch.py
@@ -59,7 +59,7 @@ def accuracy(params, batch):
 
 
 if __name__ == "__main__":
-  layer_sizes = [784, 1024, 1024, 10]  # TODO(mattjj): revise to standard arch
+  layer_sizes = [784, 1024, 1024, 10]
   param_scale = 0.1
   step_size = 0.001
   num_epochs = 10
diff --git a/examples/spmd_mnist_classifier_fromscratch.py b/examples/spmd_mnist_classifier_fromscratch.py
new file mode 100644
index 000000000..0559f0567
--- /dev/null
+++ b/examples/spmd_mnist_classifier_fromscratch.py
@@ -0,0 +1,125 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An MNIST example using SPMD data parallelism via JAX's pmap.
+
+The primary aim here is simplicity and minimal dependencies.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import partial
+import time
+
+import numpy.random as npr
+
+from jax import jit, grad, pmap, replicate, unreplicate
+from jax.config import config
+from jax.scipy.special import logsumexp
+from jax.lib import xla_bridge
+from jax import lax
+import jax.numpy as np
+from examples import datasets
+
+
+def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
+  return [(scale * rng.randn(m, n), scale * rng.randn(n))
+          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
+
+def predict(params, inputs):
+  activations = inputs
+  for w, b in params[:-1]:
+    outputs = np.dot(activations, w) + b
+    activations = np.tanh(outputs)
+
+  final_w, final_b = params[-1]
+  logits = np.dot(activations, final_w) + final_b
+  return logits - logsumexp(logits, axis=1, keepdims=True)
+
+def loss(params, batch):
+  inputs, targets = batch
+  preds = predict(params, inputs)
+  return -np.mean(preds * targets)
+
+@jit
+def accuracy(params, batch):
+  inputs, targets = batch
+  target_class = np.argmax(targets, axis=1)
+  predicted_class = np.argmax(predict(params, inputs), axis=1)
+  return np.mean(predicted_class == target_class)
+
+
+if __name__ == "__main__":
+  layer_sizes = [784, 1024, 1024, 10]
+  param_scale = 0.1
+  step_size = 0.001
+  num_epochs = 10
+  batch_size = 128
+
+  train_images, train_labels, test_images, test_labels = datasets.mnist()
+  num_train = train_images.shape[0]
+  num_complete_batches, leftover = divmod(num_train, batch_size)
+  num_batches = num_complete_batches + bool(leftover)
+
+
+  # For this manual SPMD example, we get the number of devices (e.g. GPUs or
+  # TPUs) that we're using, and use it to reshape data minibatches.
+  num_devices = xla_bridge.device_count()
+  def data_stream():
+    rng = npr.RandomState(0)
+    while True:
+      perm = rng.permutation(num_train)
+      for i in range(num_batches):
+        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
+        images, labels = train_images[batch_idx], train_labels[batch_idx]
+        # For this SPMD example, we reshape the data batch dimension into two
+        # batch dimensions, one of which is mapped over parallel devices.
+        batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
+        if ragged:
+          msg = "batch size must be divisible by device count, got {} and {}."
+          raise ValueError(msg.format(batch_size, num_devices))
+        shape_prefix = (num_devices, batch_size_per_device)
+        images = images.reshape(shape_prefix + images.shape[1:])
+        labels = labels.reshape(shape_prefix + labels.shape[1:])
+        yield images, labels
+  batches = data_stream()
+
+  @partial(pmap, axis_name='batch')
+  def spmd_update(params, batch):
+    grads = grad(loss)(params, batch)
+    # We compute the total gradients, summing across the device-mapped axis,
+    # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
+    grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
+    return [(w - step_size * dw, b - step_size * db)
+            for (w, b), (dw, db) in zip(params, grads)]
+
+  # We replicate parameters out across devices. (Check the implementation of
+  # replicate; analogous to device_put, it's a simple wrapper around pmap.)
+  params = replicate(init_random_params(param_scale, layer_sizes))
+
+  for epoch in range(num_epochs):
+    start_time = time.time()
+    for _ in range(num_batches):
+      params = spmd_update(params, next(batches))
+    epoch_time = time.time() - start_time
+
+    # We evaluate using the jitted `accuracy` function (not using pmap) by
+    # grabbing just one of the replicated parameter values.
+    train_acc = accuracy(unreplicate(params), (train_images, train_labels))
+    test_acc = accuracy(unreplicate(params), (test_images, test_labels))
+    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
+    print("Training set accuracy {}".format(train_acc))
+    print("Test set accuracy {}".format(test_acc))
diff --git a/jax/api.py b/jax/api.py
index 944336c94..2e920bfc9 100644
--- a/jax/api.py
+++ b/jax/api.py
@@ -43,7 +43,7 @@ from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
                         tree_transpose, leaf)
 from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
                    WrapHashably, prod)
-from .lib.xla_bridge import canonicalize_dtype
+from .lib.xla_bridge import canonicalize_dtype, device_count
 from .abstract_arrays import ShapedArray
 from .interpreters import partial_eval as pe
 from .interpreters import xla
@@ -678,6 +678,8 @@ tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)
 device_put = jit(lambda x: x)
 device_get_array = lambda x: x.copy() if type(x) is xla.DeviceArray else x
 device_get = partial(tree_map, device_get_array)
+replicate = lambda x: pmap(lambda _: x)(onp.arange(device_count()))
+unreplicate = lambda x: tree_map(op.itemgetter(0), x)
 
 
 def _argnums_partial(f, dyn_argnums, args):
diff --git a/spmd_toy.py b/spmd_toy.py
deleted file mode 100644
index 21936e6c2..000000000
--- a/spmd_toy.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from functools import partial
-import operator as op
-
-import numpy as onp
-
-import jax.numpy as np
-from jax import pmap, serial_pmap, grad
-from jax import lax
-from jax.tree_util import tree_map
-from jax.lib.xla_bridge import device_count
-
-
-step_size = 0.01
-rng = onp.random.RandomState(0)
-
-def predict(params, inputs):
-  for W, b in params:
-    outputs = np.dot(inputs, W) + b
-    inputs = np.tanh(outputs)
-  return outputs
-
-def loss(params, batch):
-  inputs, targets = batch
-  predictions = predict(params, inputs)
-  return np.sum((predictions - targets)**2)
-
-def update(params, batch):
-  grads = grad(loss)(params, batch)
-  new_params = [(W - step_size * dW, b - step_size * db)
-                for (W, b), (dW, db) in zip(params, grads)]
-  return new_params
-
-# initialize parameters
-layer_sizes = [2, 4, 3]  # input size 2, output size 3
-scale = 0.01
-params = [(scale * rng.randn(m, n), scale * rng.randn(n))
-          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
-
-# set up fake data
-inputs = rng.randn(10, 2)   # batch size 10, feature size 2
-targets = rng.randn(10, 3)  # batch size 10, output size 3
-batch = (inputs, targets)
-
-
-# standard functions
-print(loss(params, batch))
-print(update(params, batch)[0][0])
-
-
-# reshape / replicate data
-num_devices = device_count()
-spmd_params = tree_map(partial(lax.broadcast, sizes=(num_devices,)), params)
-spmd_inputs = inputs.reshape((num_devices, -1, 2))
-spmd_targets = targets.reshape((num_devices, -1, 3))
-spmd_batch = (spmd_inputs, spmd_targets)
-
-@partial(pmap, axis_name='i')
-def spmd_loss(params, batch):
-  inputs, targets = batch
-  predictions = predict(params, inputs)
-  batch_loss = np.sum((predictions - targets)**2)
-  return lax.psum(batch_loss, 'i')
-print(spmd_loss(spmd_params, spmd_batch))
-
-@partial(pmap, axis_name='i')
-def spmd_update(params, batch):
-  grads = grad(loss)(params, batch)  # loss, not spmd_loss
-  grads = tree_map(partial(lax.psum, axis_name='i'), grads)
-  new_params = [(W - step_size * dW, b - step_size * db)
-                for (W, b), (dW, db) in zip(params, grads)]
-  return new_params
-print(spmd_update(spmd_params, spmd_batch)[0][0])
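-- 

The standalone sketch below is not part of the patch; it is a minimal illustration of the shard -> replicate -> pmap/psum -> unreplicate round trip that the new example and the jax/api.py helpers above rely on. It assumes the `replicate` and `unreplicate` helpers added to jax/api.py are importable from `jax` (as the example script assumes), and it uses a toy least-squares loss in place of the MNIST model; on a single-device installation the mapped axis simply has size 1.

# A minimal sketch, not part of the patch: shard a batch, replicate params,
# compute an all-reduced gradient with pmap + lax.psum, then pull one copy
# back to the host with unreplicate.
from functools import partial

import numpy.random as npr

import jax.numpy as np
from jax import grad, pmap, replicate, unreplicate
from jax import lax
from jax.lib import xla_bridge


def loss(w, batch):
  xs, ys = batch
  return np.sum((np.dot(xs, w) - ys) ** 2)

@partial(pmap, axis_name='batch')
def spmd_grad(w, batch):
  # Per-device gradient on the local shard, then an all-reduce-sum across the
  # mapped 'batch' axis, so every device ends up with the full-batch gradient.
  return lax.psum(grad(loss)(w, batch), 'batch')

num_devices = xla_bridge.device_count()
rng = npr.RandomState(0)
w = rng.randn(3)
xs = rng.randn(num_devices * 4, 3)  # leading dim divisible by device count
ys = rng.randn(num_devices * 4)

# Reshape the batch dimension into (num_devices, batch_per_device, ...) and
# replicate the parameters, so pmap maps both over the device axis.
shard = lambda x: x.reshape((num_devices, -1) + x.shape[1:])
grads = spmd_grad(replicate(w), (shard(xs), shard(ys)))

# Every replica holds the same summed gradient; unreplicate grabs one copy.
print(unreplicate(grads))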