add spmd mnist example

Matthew Johnson 2019-03-08 09:59:03 -08:00
parent 258caf8294
commit 15da530b03
4 changed files with 129 additions and 74 deletions

View File

@@ -59,7 +59,7 @@ def accuracy(params, batch):
 if __name__ == "__main__":
-  layer_sizes = [784, 1024, 1024, 10]  # TODO(mattjj): revise to standard arch
+  layer_sizes = [784, 1024, 1024, 10]
   param_scale = 0.1
   step_size = 0.001
   num_epochs = 10

View File

@@ -0,0 +1,125 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic MNIST example using Numpy and JAX.
The primary aim here is simplicity and minimal dependencies.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import time

import numpy.random as npr

from jax import jit, grad, pmap, replicate, unreplicate
from jax.config import config
from jax.scipy.special import logsumexp
from jax.lib import xla_bridge
from jax import lax
import jax.numpy as np

from examples import datasets

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    outputs = np.dot(activations, w) + b
    activations = np.tanh(outputs)

  final_w, final_b = params[-1]
  logits = np.dot(activations, final_w) + final_b
  return logits - logsumexp(logits, axis=1, keepdims=True)

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -np.mean(preds * targets)

@jit
def accuracy(params, batch):
  inputs, targets = batch
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(predict(params, inputs), axis=1)
  return np.mean(predicted_class == target_class)

if __name__ == "__main__":
  layer_sizes = [784, 1024, 1024, 10]
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 128

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  # For this manual SPMD example, we get the number of devices (e.g. GPUs or
  # TPUs) that we're using, and use it to reshape data minibatches.
  num_devices = xla_bridge.device_count()

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        images, labels = train_images[batch_idx], train_labels[batch_idx]
        # For this SPMD example, we reshape the data batch dimension into two
        # batch dimensions, one of which is mapped over parallel devices.
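        # (For example, with a hypothetical 8 devices and the batch_size of 128
        # set above, an images array of shape (128, 784) is reshaped to
        # (8, 16, 784), so each device sees a (16, 784) slice under pmap.)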
        batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
        if ragged:
          msg = "batch size must be divisible by device count, got {} and {}."
          raise ValueError(msg.format(batch_size, num_devices))
        shape_prefix = (num_devices, batch_size_per_device)
        images = images.reshape(shape_prefix + images.shape[1:])
        labels = labels.reshape(shape_prefix + labels.shape[1:])
        yield images, labels
  batches = data_stream()

  @partial(pmap, axis_name='batch')
  def spmd_update(params, batch):
    grads = grad(loss)(params, batch)
    # We compute the total gradients, summing across the device-mapped axis,
    # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
    grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
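    # After the psum, every device holds the same summed gradients, so the
    # parameter update below is identical on all replicas and the replicated
    # params stay in sync without any extra broadcast.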
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  # We replicate parameters out across devices. (Check the implementation of
  # replicate; analogous to device_put, it's a simple wrapper around pmap.)
  params = replicate(init_random_params(param_scale, layer_sizes))
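  # Every leaf of params now has a leading axis of size num_devices, matching
  # the leading axis of the batches from data_stream, which is what pmap maps
  # over below.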

  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      params = spmd_update(params, next(batches))
    epoch_time = time.time() - start_time

    # We evaluate using the jitted `accuracy` function (not using pmap) by
    # grabbing just one of the replicated parameter values.
    train_acc = accuracy(unreplicate(params), (train_images, train_labels))
    test_acc = accuracy(unreplicate(params), (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

View File

@@ -43,7 +43,7 @@ from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
                         tree_transpose, leaf)
 from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
                    WrapHashably, prod)
-from .lib.xla_bridge import canonicalize_dtype
+from .lib.xla_bridge import canonicalize_dtype, device_count
 from .abstract_arrays import ShapedArray
 from .interpreters import partial_eval as pe
 from .interpreters import xla
@@ -678,6 +678,8 @@ tree_to_pval_tuples = partial(process_pytree, pe.pack_pvals)
 device_put = jit(lambda x: x)
 device_get_array = lambda x: x.copy() if type(x) is xla.DeviceArray else x
 device_get = partial(tree_map, device_get_array)
+replicate = lambda x: pmap(lambda _: x)(onp.arange(device_count()))
+unreplicate = lambda x: tree_map(op.itemgetter(0), x)

 def _argnums_partial(f, dyn_argnums, args):
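
Given the two helpers added in this hunk, a minimal usage sketch (not part of the commit; the array shape is illustrative and `n` is whatever `device_count()` reports):

    import numpy as onp
    from jax import replicate, unreplicate
    from jax.lib.xla_bridge import device_count

    n = device_count()
    x = onp.ones((2, 3))

    xs = replicate(x)    # stacks x across devices: xs.shape == (n, 2, 3)
    y = unreplicate(xs)  # grabs replica 0:         y.shape == (2, 3)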

View File

@@ -1,72 +0,0 @@
from functools import partial
import operator as op

import numpy as onp

import jax.numpy as np
from jax import pmap, serial_pmap, grad
from jax import lax
from jax.tree_util import tree_map
from jax.lib.xla_bridge import device_count


step_size = 0.01
rng = onp.random.RandomState(0)

def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return np.sum((predictions - targets)**2)

def update(params, batch):
  grads = grad(loss)(params, batch)
  new_params = [(W - step_size * dW, b - step_size * db)
                for (W, b), (dW, db) in zip(params, grads)]
  return new_params


# initialize parameters
layer_sizes = [2, 4, 3]  # input size 2, output size 3
scale = 0.01
params = [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

# set up fake data
inputs = rng.randn(10, 2)   # batch size 10, feature size 2
targets = rng.randn(10, 3)  # batch size 10, output size 3
batch = (inputs, targets)

# standard functions
print(loss(params, batch))
print(update(params, batch)[0][0])

# reshape / replicate data
num_devices = device_count()
spmd_params = tree_map(partial(lax.broadcast, sizes=(num_devices,)), params)
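# (lax.broadcast with sizes=(num_devices,) prepends a device axis to each leaf:
# the first-layer (2, 4) weight matrix becomes (num_devices, 2, 4). The new
# `replicate` helper added in this commit produces the same leading-device-axis
# layout, just via pmap instead of lax.broadcast.)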
spmd_inputs = inputs.reshape((num_devices, -1, 2))
spmd_targets = targets.reshape((num_devices, -1, 3))
spmd_batch = (spmd_inputs, spmd_targets)


@partial(pmap, axis_name='i')
def spmd_loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  batch_loss = np.sum((predictions - targets)**2)
  return lax.psum(batch_loss, 'i')

print(spmd_loss(spmd_params, spmd_batch))


@partial(pmap, axis_name='i')
def spmd_update(params, batch):
  grads = grad(loss)(params, batch)  # loss, not spmd_loss
  grads = tree_map(partial(lax.psum, axis_name='i'), grads)
  new_params = [(W - step_size * dW, b - step_size * db)
                for (W, b), (dW, db) in zip(params, grads)]
  return new_params

print(spmd_update(spmd_params, spmd_batch)[0][0])