rocm_jax/examples/spmd_mnist_classifier_fromscratch.py

129 lines
4.9 KiB
Python
Raw Normal View History

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An MNIST example with single-program multiple-data (SPMD) data parallelism.

The aim here is to illustrate how to use JAX's `pmap` to express and execute
SPMD programs for data parallelism along a batch dimension, while also
minimizing dependencies by avoiding the use of higher-level layers and
optimizer libraries.
"""
from functools import partial
import time
import numpy as np
2019-03-08 09:59:03 -08:00
import numpy.random as npr
import jax
from jax import jit, grad, pmap
2019-03-08 09:59:03 -08:00
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
2019-03-08 09:59:03 -08:00
from jax import lax
import jax.numpy as jnp
2019-03-08 09:59:03 -08:00
from examples import datasets
def init_random_params(scale, layer_sizes, rng=None):
    """Build randomly initialized ``(weights, biases)`` pairs for dense layers.

    Args:
      scale: multiplier applied to every standard-normal sample.
      layer_sizes: sequence of layer widths; each consecutive pair ``(m, n)``
        produces one layer with an ``(m, n)`` weight matrix and a length-``n``
        bias vector.
      rng: object exposing ``randn`` (e.g. ``numpy.random.RandomState``). When
        omitted, a fresh ``RandomState(0)`` is created on every call, so
        repeated calls return identical parameters rather than silently
        advancing a single generator shared through the default argument.

    Returns:
      A list of ``(weights, biases)`` tuples, one per consecutive pair of
      entries in ``layer_sizes``.
    """
    if rng is None:
        # Created per call: a RandomState default argument would be built once
        # at definition time and its state would leak between calls.
        rng = npr.RandomState(0)
    return [(scale * rng.randn(m, n), scale * rng.randn(n))
            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
def predict(params, inputs):
    """Run the network forward and return row-wise log-softmax outputs.

    Every ``(w, b)`` pair except the last is applied as an affine map followed
    by ``tanh``. The last pair yields the logits, which are normalized by
    subtracting their log-sum-exp along axis 1.

    Args:
      params: sequence of ``(weights, biases)`` pairs; the final pair is the
        output layer.
      inputs: array fed to the first layer (axis 1 is the feature axis).

    Returns:
      ``logits - logsumexp(logits, axis=1, keepdims=True)``.
    """
    hidden_layers, (out_w, out_b) = params[:-1], params[-1]

    x = inputs
    for weight, bias in hidden_layers:
        x = jnp.tanh(jnp.dot(x, weight) + bias)

    logits = jnp.dot(x, out_w) + out_b
    log_normalizer = logsumexp(logits, axis=1, keepdims=True)
    return logits - log_normalizer
def loss(params, batch):
    """Return the negated batch mean of ``sum(predict(...) * targets, axis=1)``.

    Args:
      params: network parameters, forwarded to ``predict``.
      batch: an ``(inputs, targets)`` pair; ``targets`` are presumably one-hot
        label rows (confirm against the data loader).

    Returns:
      A scalar: minus the mean over examples of the target-weighted
      log-probabilities.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
2019-03-08 09:59:03 -08:00
@jit
def accuracy(params, batch):
    """Return the fraction of examples whose predicted class matches the target.

    Args:
      params: network parameters, forwarded to ``predict``.
      batch: an ``(inputs, targets)`` pair; the class of each row is taken as
        the argmax along axis 1 of both the predictions and the targets.

    Returns:
      The mean of the elementwise ``predicted == target`` comparison.
    """
    inputs, targets = batch
    scores = predict(params, inputs)
    is_correct = jnp.argmax(scores, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(is_correct)
2019-03-08 09:59:03 -08:00
if __name__ == "__main__":
    # Network architecture and training hyperparameters.
    layer_sizes = [784, 1024, 1024, 10]
    param_scale = 0.1
    step_size = 0.001
    num_epochs = 10
    batch_size = 128

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    # A trailing partial batch, if any, counts as one more batch per epoch.
    num_batches = num_complete_batches + bool(leftover)

    # For this manual SPMD example, we get the number of devices (e.g. GPUs or
    # TPU cores) that we're using, and use it to reshape data minibatches.
    num_devices = jax.device_count()

    def data_stream():
        """Yield shuffled ``(images, labels)`` minibatches forever.

        Each minibatch has its leading axis split into
        ``(num_devices, batch_size_per_device)`` so it can be fed to `pmap`.

        Raises:
          ValueError: if a minibatch's length is not divisible by the device
            count.
        """
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                images = train_images[batch_idx]
                labels = train_labels[batch_idx]
                # For this SPMD example, we reshape the data batch dimension
                # into two batch dimensions, one of which is mapped over
                # parallel devices.
                actual_batch_size = images.shape[0]
                batch_size_per_device, ragged = divmod(
                    actual_batch_size, num_devices)
                if ragged:
                    # Report the length that actually failed the check: the
                    # final batch of an epoch may be shorter than `batch_size`.
                    msg = ("batch size must be divisible by device count, "
                           "got {} and {}.")
                    raise ValueError(
                        msg.format(actual_batch_size, num_devices))
                shape_prefix = (num_devices, batch_size_per_device)
                images = images.reshape(shape_prefix + images.shape[1:])
                labels = labels.reshape(shape_prefix + labels.shape[1:])
                yield images, labels

    batches = data_stream()

    @partial(pmap, axis_name='batch')
    def spmd_update(params, batch):
        """Apply one gradient-descent step using gradients summed over devices."""
        grads = grad(loss)(params, batch)
        # We compute the total gradients, summing across the device-mapped
        # axis, using the `lax.psum` SPMD primitive, which does a fast
        # all-reduce-sum.
        grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch'))
                 for dw, db in grads]
        return [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

    def replicate_array(x):
        """Broadcast `x` to have a leading axis of length `num_devices`."""
        return np.broadcast_to(x, (num_devices,) + x.shape)

    # We replicate the parameters so that the constituent arrays have a
    # leading dimension of size equal to the number of devices we're pmapping
    # over.
    init_params = init_random_params(param_scale, layer_sizes)
    replicated_params = tree_map(replicate_array, init_params)

    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            replicated_params = spmd_update(replicated_params, next(batches))
        epoch_time = time.time() - start_time

        # We evaluate using the jitted `accuracy` function (not using pmap) by
        # grabbing just one of the replicated parameter values.
        params = tree_map(lambda x: x[0], replicated_params)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
        print(f"Training set accuracy {train_acc}")
        print(f"Test set accuracy {test_acc}")