fix stax initialization rng bug, remove temp file

This commit is contained in:
Matthew Johnson 2019-05-23 11:28:15 -07:00
parent 92e5f93a29
commit 8ffb9417e7
3 changed files with 29 additions and 66 deletions

View File

@ -1,56 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An MNIST example for single-program multiple-data (SPMD) spatial parallelism.
The aim here is to illustrate how to use JAX's `pmap` to express and execute
SPMD programs for data parallelism along a spatial dimension (rather than a
batch dimension).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import itertools
import numpy.random as npr
import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Relu, LogSoftmax
from examples import datasets
def loss(params, batch):
inputs, targets = batch
preds = predict(params, inputs)
return -np.mean(preds * targets)
def accuracy(params, batch):
inputs, targets = batch
target_class = np.argmax(targets, axis=1)
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
init_random_params, predict = stax.serial(
SpmdConv(3, (2, 2), axis_name="x"), Relu,
SpmdConv(3, (2, 2), axis_name="x"), Relu,
SpmdDense(10), LogSoftmax)

View File

@ -64,13 +64,13 @@ def randn(stddev=1e-2):
return (stddev * random.normal(rng, shape)).astype('float32')
return init
def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
def glorot(out_axis=0, in_axis=1, scale=onp.sqrt(2)):
"""An initializer function for random Glorot-scaled coefficients."""
def init(rng, shape):
fan_in, fan_out = shape[in_dim], shape[out_dim]
size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
fan_in, fan_out = shape[in_axis], shape[out_axis]
size = onp.prod(onp.delete(shape, [in_axis, out_axis]))
std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
return (std * random.normal(rng, shape)).astype('float32')
return std * random.normal(rng, shape, dtype=np.float32)
return init
zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
@ -89,7 +89,8 @@ def Dense(out_dim, W_init=glorot(), b_init=randn()):
"""Layer constructor function for a dense (fully-connected) layer."""
def init_fun(rng, input_shape):
output_shape = input_shape[:-1] + (out_dim,)
W, b = W_init(rng, (input_shape[-1], out_dim)), b_init(rng, (out_dim,))
k1, k2 = random.split(rng)
W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
@ -113,7 +114,8 @@ def GeneralConv(dimension_numbers, out_chan, filter_shape,
input_shape, kernel_shape, strides, padding, dimension_numbers)
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
W, b = W_init(rng, kernel_shape), b_init(rng, bias_shape)
k1, k2 = random.split(rng)
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
@ -140,7 +142,8 @@ def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
input_shape, kernel_shape, strides, padding, dimension_numbers)
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
W, b = W_init(rng, kernel_shape), b_init(rng, bias_shape)
k1, k2 = random.split(rng)
W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
@ -160,7 +163,8 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
axis = (axis,) if np.isscalar(axis) else axis
def init_fun(rng, input_shape):
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
beta, gamma = _beta_init(rng, shape), _gamma_init(rng, shape)
k1, k2 = random.split(rng)
beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
return input_shape, (beta, gamma)
def apply_fun(params, x, **kwargs):
beta, gamma = params

View File

@ -424,15 +424,30 @@ class ShardedDeviceArray(xla.DeviceArray):
self.ndim, self.size = len(aval.shape), prod(aval.shape)
self._npy_value = None
def _ids(self):
num_bufs = len(self.device_buffers)
assignments = assign_shards_to_replicas(num_bufs, self.shape[0])
_, ids = onp.unique(assignments, return_index=True)
return ids
@property
def _value(self):
if self._npy_value is None:
ids = self._ids()
npy_shards = [buf.to_py() for buf in self.device_buffers]
assignments = assign_shards_to_replicas(len(npy_shards), self.shape[0])
_, ids = onp.unique(assignments, return_index=True)
self._npy_value = onp.stack([npy_shards[i] for i in ids])
return self._npy_value
def __getitem__(self, idx):
if self._npy_value is None and type(idx) is int:
# When we don't have a copy of the data on the host, and we're just trying
# to extract a simple integer-indexed slice of the logical array, we can
# avoid transferring from all the devices and just communicate with one.
ids = self._ids()
return self.device_buffers[ids[idx]].to_py()
else:
return super(ShardedDeviceArray, self).__getitem__(idx)
core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
xla.pytype_aval_mappings[ShardedDeviceArray] = \
xla.pytype_aval_mappings[xla.DeviceArray]