Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
fix stax initialization rng bug, remove temp file
commit 8ffb9417e7
parent 92e5f93a29
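The bug being fixed, sketched outside the diff with toy initializers (w_init/b_init are illustrative stand-ins, not the stax code): passing the same PRNG key to both a layer's weight and bias initializers makes their draws identical rather than independent. Splitting the key first, as each init_fun in the hunks below now does, restores independence.

import jax.numpy as jnp
from jax import random

# Illustrative stand-ins for a layer's W_init / b_init (hypothetical, not stax's).
w_init = lambda key, shape: 1e-2 * random.normal(key, shape)
b_init = lambda key, shape: 1e-2 * random.normal(key, shape)

rng = random.PRNGKey(0)

# Buggy pattern: reusing one key means W and b come from the same random draw.
W_bad = w_init(rng, (3,))
b_bad = b_init(rng, (3,))
print(jnp.allclose(W_bad, b_bad))  # True: the "independent" parameters coincide

# Fixed pattern (what the diff applies in each init_fun): split the key first.
k1, k2 = random.split(rng)
W = w_init(k1, (3,))
b = b_init(k2, (3,))
print(jnp.allclose(W, b))  # False: independent draws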
@@ -1,56 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""An MNIST example for single-program multiple-data (SPMD) spatial parallelism.
-
-The aim here is to illustrate how to use JAX's `pmap` to express and execute
-SPMD programs for data parallelism along a spatial dimension (rather than a
-batch dimension).
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-import itertools
-
-import numpy.random as npr
-
-import jax.numpy as np
-from jax.config import config
-from jax import jit, grad, random
-from jax.experimental import optimizers
-from jax.experimental import stax
-from jax.experimental.stax import Relu, LogSoftmax
-from examples import datasets
-
-
-def loss(params, batch):
-  inputs, targets = batch
-  preds = predict(params, inputs)
-  return -np.mean(preds * targets)
-
-def accuracy(params, batch):
-  inputs, targets = batch
-  target_class = np.argmax(targets, axis=1)
-  predicted_class = np.argmax(predict(params, inputs), axis=1)
-  return np.mean(predicted_class == target_class)
-
-
-init_random_params, predict = stax.serial(
-    SpmdConv(3, (2, 2), axis_name="x"), Relu,
-    SpmdConv(3, (2, 2), axis_name="x"), Relu,
-    SpmdDense(10), LogSoftmax)
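The removed file is the temporary SPMD example named in the commit message. Its docstring describes the goal; a minimal, self-contained sketch of that pmap pattern follows (the removed SpmdConv/SpmdDense layers are not reproduced here, and this toy collective is an illustration, not the example's code):

import jax
import jax.numpy as jnp

def global_sum(x):
  # Each replica reduces its local shard, then psum combines across the
  # named spatial axis, giving every replica the full-array result.
  return jax.lax.psum(jnp.sum(x), axis_name="x")

n_dev = jax.local_device_count()
shards = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)   # one spatial slice per device
out = jax.pmap(global_sum, axis_name="x")(shards)
print(out)   # every entry equals the sum over all shards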
@@ -64,13 +64,13 @@ def randn(stddev=1e-2):
     return (stddev * random.normal(rng, shape)).astype('float32')
   return init
 
-def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
+def glorot(out_axis=0, in_axis=1, scale=onp.sqrt(2)):
   """An initializer function for random Glorot-scaled coefficients."""
   def init(rng, shape):
-    fan_in, fan_out = shape[in_dim], shape[out_dim]
-    size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
+    fan_in, fan_out = shape[in_axis], shape[out_axis]
+    size = onp.prod(onp.delete(shape, [in_axis, out_axis]))
     std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
-    return (std * random.normal(rng, shape)).astype('float32')
+    return std * random.normal(rng, shape, dtype=np.float32)
   return init
 
 zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
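For intuition on the renamed axis arguments, here is a worked example of the Glorot scale with an illustrative HWIO conv-kernel shape (the shape is not taken from the diff):

import numpy as onp

shape = (3, 3, 16, 32)        # (height, width, in_chan, out_chan), illustrative
in_axis, out_axis = 2, 3
scale = onp.sqrt(2)

fan_in, fan_out = shape[in_axis], shape[out_axis]         # 16, 32
size = onp.prod(onp.delete(shape, [in_axis, out_axis]))   # 3 * 3 = 9
std = scale / onp.sqrt((fan_in + fan_out) / 2. * size)    # sqrt(2) / sqrt(216) ~= 0.096
print(std)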
@@ -89,7 +89,8 @@ def Dense(out_dim, W_init=glorot(), b_init=randn()):
   """Layer constructor function for a dense (fully-connected) layer."""
   def init_fun(rng, input_shape):
     output_shape = input_shape[:-1] + (out_dim,)
-    W, b = W_init(rng, (input_shape[-1], out_dim)), b_init(rng, (out_dim,))
+    k1, k2 = random.split(rng)
+    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
     return output_shape, (W, b)
   def apply_fun(params, inputs, **kwargs):
     W, b = params
@@ -113,7 +114,8 @@ def GeneralConv(dimension_numbers, out_chan, filter_shape,
         input_shape, kernel_shape, strides, padding, dimension_numbers)
     bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
     bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
-    W, b = W_init(rng, kernel_shape), b_init(rng, bias_shape)
+    k1, k2 = random.split(rng)
+    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
     return output_shape, (W, b)
   def apply_fun(params, inputs, **kwargs):
     W, b = params
@@ -140,7 +142,8 @@ def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
         input_shape, kernel_shape, strides, padding, dimension_numbers)
     bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
     bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
-    W, b = W_init(rng, kernel_shape), b_init(rng, bias_shape)
+    k1, k2 = random.split(rng)
+    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
     return output_shape, (W, b)
   def apply_fun(params, inputs, **kwargs):
     W, b = params
@@ -160,7 +163,8 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
   axis = (axis,) if np.isscalar(axis) else axis
   def init_fun(rng, input_shape):
     shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
-    beta, gamma = _beta_init(rng, shape), _gamma_init(rng, shape)
+    k1, k2 = random.split(rng)
+    beta, gamma = _beta_init(k1, shape), _gamma_init(k2, shape)
     return input_shape, (beta, gamma)
   def apply_fun(params, x, **kwargs):
     beta, gamma = params
@@ -424,15 +424,30 @@ class ShardedDeviceArray(xla.DeviceArray):
     self.ndim, self.size = len(aval.shape), prod(aval.shape)
     self._npy_value = None
 
+  def _ids(self):
+    num_bufs = len(self.device_buffers)
+    assignments = assign_shards_to_replicas(num_bufs, self.shape[0])
+    _, ids = onp.unique(assignments, return_index=True)
+    return ids
+
   @property
   def _value(self):
     if self._npy_value is None:
+      ids = self._ids()
       npy_shards = [buf.to_py() for buf in self.device_buffers]
-      assignments = assign_shards_to_replicas(len(npy_shards), self.shape[0])
-      _, ids = onp.unique(assignments, return_index=True)
       self._npy_value = onp.stack([npy_shards[i] for i in ids])
     return self._npy_value
 
+  def __getitem__(self, idx):
+    if self._npy_value is None and type(idx) is int:
+      # When we don't have a copy of the data on the host, and we're just trying
+      # to extract a simple integer-indexed slice of the logical array, we can
+      # avoid transferring from all the devices and just communicate with one.
+      ids = self._ids()
+      return self.device_buffers[ids[idx]].to_py()
+    else:
+      return super(ShardedDeviceArray, self).__getitem__(idx)
+
 core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
 xla.pytype_aval_mappings[ShardedDeviceArray] = \
     xla.pytype_aval_mappings[xla.DeviceArray]
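A usage note for the new __getitem__ fast path, as a sketch assuming the ShardedDeviceArray-era API where a pmap output keeps one leading-axis slice per device: integer indexing now transfers only the shard on the corresponding device instead of first assembling the whole array on the host.

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
xs = jnp.arange(n_dev * 3.0).reshape(n_dev, 3)

ys = jax.pmap(lambda x: 2 * x)(xs)   # result stays sharded across devices

# Hits the integer fast path above: only one device buffer is transferred.
print(ys[0])

# Other indexing (slices, full printing) still materializes the host copy.
print(ys[:, 0])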