Use jax.random for stax initialization

This commit is contained in:
Jamie Townsend 2019-04-03 12:54:02 +01:00
parent ca9151acc9
commit 1c9b9a57fd
7 changed files with 66 additions and 48 deletions
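In short, stax `init_fun`s now take an explicit `jax.random` PRNG key as their first argument. A minimal sketch of the new calling convention (the layer sizes and input shape below are illustrative, not taken from the diff):

```python
import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu

# serial returns an (init_fun, apply_fun) pair, as before.
net_init, net_apply = stax.serial(Dense(32), Relu, Dense(10))

# Before this commit: out_shape, params = net_init(input_shape)
# After this commit: init_fun takes a PRNG key as its first argument.
rng = random.PRNGKey(0)
out_shape, net_params = net_init(rng, (-1, 64))

# apply_fun is unchanged: params first, then inputs.
inputs = np.zeros((128, 64))
predictions = net_apply(net_params, inputs)
```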

View File

@ -560,6 +560,7 @@ Here's an example:
```python
import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax
@ -573,8 +574,9 @@ net_init, net_apply = stax.serial(
)
# Initialize parameters, not committing to a batch shape
rng = random.PRNGKey(0)
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(in_shape)
out_shape, net_params = net_init(rng, in_shape)
# Apply network to dummy inputs
inputs = np.zeros((128, 28, 28, 1))
@ -689,7 +691,7 @@ specialized on shapes and dtypes, but not specialized all the way to concrete
values, the Python code under a `jit` decorator must be applicable to abstract
values. If we try to evaluate `x > 0` on an abstract `x`, the result is an
abstract value representing the set `{True, False}`, and so a Python branch like
`if x > 0` will raise an error: it doesn't know which way to go!
`if x > 0` will raise an error: it doesn't know which way to go!
See [What's supported](#whats-supported) for more
information about `jit` requirements.
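A small illustrative sketch of the failure mode described above (not part of this commit; the function `f` is hypothetical):

```python
from jax import jit

@jit
def f(x):
    # Under `jit`, `x` is an abstract tracer: `x > 0` has no concrete
    # True/False value, so this Python branch fails at trace time.
    if x > 0:
        return x
    return -x

# f(3.0)  # raises an error when traced; see "What's supported" in the README
```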

View File

@ -27,7 +27,7 @@ import numpy.random as npr
import jax.numpy as np
from jax.config import config
from jax import jit, grad
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
@ -51,6 +51,8 @@ init_random_params, predict = stax.serial(
Dense(10), LogSoftmax)
if __name__ == "__main__":
rng = random.PRNGKey(0)
step_size = 0.001
num_epochs = 10
batch_size = 128
@ -77,7 +79,7 @@ if __name__ == "__main__":
params = optimizers.get_params(opt_state)
return opt_update(i, grad(loss)(params, batch), opt_state)
_, init_params = init_random_params((-1, 28 * 28))
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

View File

@ -97,8 +97,9 @@ if __name__ == "__main__":
num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
num_batches = num_complete_batches + bool(leftover)
_, init_encoder_params = encoder_init((batch_size, 28 * 28))
_, init_decoder_params = decoder_init((batch_size, 10))
enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params
opt_init, opt_update = optimizers.momentum(step_size, mass=0.9)

View File

@ -27,7 +27,7 @@ from six.moves import xrange
import jax.numpy as np
from jax.config import config
from jax import jit, grad
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
@ -87,6 +87,8 @@ def ResNet50(num_classes):
if __name__ == "__main__":
rng_key = random.PRNGKey(0)
batch_size = 8
num_classes = 1001
input_shape = (224, 224, 3, batch_size)
@ -94,7 +96,7 @@ if __name__ == "__main__":
num_steps = 10
init_fun, predict_fun = ResNet50(num_classes)
_, init_params = init_fun(input_shape)
_, init_params = init_fun(rng_key, input_shape)
def loss(params, batch):
inputs, targets = batch

View File

@ -26,7 +26,6 @@ import itertools
import operator as op
import numpy as onp
import numpy.random as npr
from six.moves import reduce
from jax import lax
@ -59,37 +58,38 @@ def fastvar(x, axis, keepdims):
# Initializers
def randn(stddev=1e-2, rng=npr):
def randn(stddev=1e-2):
"""An initializer function for random normal coefficients."""
def init(shape):
return rng.normal(size=shape, scale=stddev).astype('float32')
def init(rng, shape):
return stddev * random.normal(rng, shape)
return init
def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2), rng=npr):
def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
"""An initializer function for random Glorot-scaled coefficients."""
def init(shape):
def init(rng, shape):
fan_in, fan_out = shape[in_dim], shape[out_dim]
size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
return rng.normal(size=shape, scale=std).astype('float32')
return std * random.normal(rng, shape)
return init
zeros = functools.partial(np.zeros, dtype='float32')
ones = functools.partial(np.ones, dtype='float32')
zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
ones = lambda rng, shape: np.ones(shape, dtype='float32')
# Layers
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
# init_fun: takes an input shape and returns an (output_shape, params) pair,
# init_fun: takes an rng key and an input shape and returns an
# (output_shape, params) pair,
# apply_fun: takes params, inputs, and an rng key and applies the layer.
def Dense(out_dim, W_init=glorot(), b_init=randn()):
"""Layer constructor function for a dense (fully-connected) layer."""
def init_fun(input_shape):
def init_fun(rng, input_shape):
output_shape = input_shape[:-1] + (out_dim,)
W, b = W_init((input_shape[-1], out_dim)), b_init((out_dim,))
W, b = W_init(rng, (input_shape[-1], out_dim)), b_init(rng, (out_dim,))
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
@ -104,7 +104,7 @@ def GeneralConv(dimension_numbers, out_chan, filter_shape,
one = (1,) * len(filter_shape)
strides = strides or one
W_init = W_init or glorot(rhs_spec.index('O'), rhs_spec.index('I'))
def init_fun(input_shape):
def init_fun(rng, input_shape):
filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
input_shape[lhs_spec.index('C')] if c == 'I' else
@ -113,7 +113,7 @@ def GeneralConv(dimension_numbers, out_chan, filter_shape,
input_shape, kernel_shape, strides, padding, dimension_numbers)
bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
W, b = W_init(kernel_shape), b_init(bias_shape)
W, b = W_init(rng, kernel_shape), b_init(rng, bias_shape)
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
@ -126,12 +126,12 @@ Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
beta_init=zeros, gamma_init=ones):
"""Layer construction function for a batch normalization layer."""
_beta_init = lambda shape: beta_init(shape) if center else ()
_gamma_init = lambda shape: gamma_init(shape) if scale else ()
_beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
_gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
axis = (axis,) if np.isscalar(axis) else axis
def init_fun(input_shape):
def init_fun(rng, input_shape):
shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
beta, gamma = _beta_init(shape), _gamma_init(shape)
beta, gamma = _beta_init(rng, shape), _gamma_init(rng, shape)
return input_shape, (beta, gamma)
def apply_fun(params, x, **kwargs):
beta, gamma = params
@ -150,7 +150,7 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
def _elemwise_no_params(fun, **fun_kwargs):
init_fun = lambda input_shape: (input_shape, ())
init_fun = lambda rng, input_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
return init_fun, apply_fun
Tanh = _elemwise_no_params(np.tanh)
@ -168,7 +168,7 @@ def _pooling_layer(reducer, init_val, rescaler=None):
rescale = rescaler(window_shape, strides, padding) if rescaler else None
dims = (1,) + window_shape + (1,) # NHWC
strides = (1,) + strides + (1,)
def init_fun(input_shape):
def init_fun(rng, input_shape):
out_shape = lax.reduce_window_shape_tuple(input_shape, dims, strides, padding)
return out_shape, ()
def apply_fun(params, inputs, **kwargs):
@ -191,7 +191,7 @@ AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)
def Flatten():
"""Layer construction function for flattening all but the leading dim."""
def init_fun(input_shape):
def init_fun(rng, input_shape):
output_shape = input_shape[0], reduce(op.mul, input_shape[1:], 1)
return output_shape, ()
def apply_fun(params, inputs, **kwargs):
@ -202,7 +202,7 @@ Flatten = Flatten()
def Identity():
"""Layer construction function for an identity layer."""
init_fun = lambda input_shape: (input_shape, ())
init_fun = lambda rng, input_shape: (input_shape, ())
apply_fun = lambda params, inputs, **kwargs: inputs
return init_fun, apply_fun
Identity = Identity()
@ -210,14 +210,14 @@ Identity = Identity()
def FanOut(num):
"""Layer construction function for a fan-out layer."""
init_fun = lambda input_shape: ([input_shape] * num, ())
init_fun = lambda rng, input_shape: ([input_shape] * num, ())
apply_fun = lambda params, inputs, **kwargs: [inputs] * num
return init_fun, apply_fun
def FanInSum():
"""Layer construction function for a fan-in sum layer."""
init_fun = lambda input_shape: (input_shape[0], ())
init_fun = lambda rng, input_shape: (input_shape[0], ())
apply_fun = lambda params, inputs, **kwargs: sum(inputs)
return init_fun, apply_fun
FanInSum = FanInSum()
@ -225,7 +225,7 @@ FanInSum = FanInSum()
def FanInConcat(axis=-1):
"""Layer construction function for a fan-in concatenation layer."""
def init_fun(input_shape):
def init_fun(rng, input_shape):
ax = axis % len(input_shape[0])
concat_size = sum(shape[ax] for shape in input_shape)
out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]
@ -237,7 +237,7 @@ def FanInConcat(axis=-1):
def Dropout(rate, mode='train'):
"""Layer construction function for a dropout layer with given rate."""
def init_fun(input_shape):
def init_fun(rng, input_shape):
return input_shape, ()
def apply_fun(params, inputs, **kwargs):
rng = kwargs.get('rng', None)
@ -270,10 +270,11 @@ def serial(*layers):
"""
nlayers = len(layers)
init_funs, apply_funs = zip(*layers)
def init_fun(input_shape):
def init_fun(rng, input_shape):
params = []
for init_fun in init_funs:
input_shape, param = init_fun(input_shape)
rng, layer_rng = random.split(rng)
input_shape, param = init_fun(layer_rng, input_shape)
params.append(param)
return input_shape, params
def apply_fun(params, inputs, **kwargs):
@ -302,8 +303,10 @@ def parallel(*layers):
"""
nlayers = len(layers)
init_funs, apply_funs = zip(*layers)
def init_fun(input_shape):
return zip(*[init(shape) for init, shape in zip(init_funs, input_shape)])
def init_fun(rng, input_shape):
rngs = random.split(rng, len(init_funs))
return zip(*[init(rng, shape) for init, rng, shape
in zip(init_funs, rngs, input_shape)])
def apply_fun(params, inputs, **kwargs):
rng = kwargs.pop('rng', None)
rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
@ -323,8 +326,8 @@ def shape_dependent(make_layer):
layer as returned by `make_layer` but with its construction delayed until
input shapes are known.
"""
def init_fun(input_shape):
return make_layer(input_shape)[0](input_shape)
def init_fun(rng, input_shape):
return make_layer(input_shape)[0](rng, input_shape)
def apply_fun(params, inputs, **kwargs):
return make_layer(inputs.shape)[1](params, inputs, **kwargs)
return init_fun, apply_fun
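To illustrate the updated layer contract documented in the comments above, here is a hedged sketch of a hypothetical user-defined layer (`MyDense` is not part of stax or this commit): its `init_fun` takes an rng key and an input shape, its `apply_fun` takes params and inputs, and it composes with `stax.serial`, which splits the key across layers.

```python
import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Relu

def MyDense(out_dim):
  """A hypothetical user-defined dense layer following the new contract."""
  def init_fun(rng, input_shape):
    # Split the key so W and b get independent random draws.
    w_key, b_key = random.split(rng)
    W = 1e-2 * random.normal(w_key, (input_shape[-1], out_dim))
    b = 1e-2 * random.normal(b_key, (out_dim,))
    return input_shape[:-1] + (out_dim,), (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return np.dot(inputs, W) + b
  return init_fun, apply_fun

# Composes with the built-in combinators; serial threads a fresh key to
# each layer's init_fun.
init_fun, apply_fun = stax.serial(MyDense(32), Relu, MyDense(10))
out_shape, params = init_fun(random.PRNGKey(0), (-1, 64))
```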

View File

@ -116,6 +116,7 @@
"from jax import vmap # for auto-vectorizing functions\n",
"from functools import partial # for use with vmap\n",
"from jax import jit # for compiling functions for speedup\n",
"from jax import random # Stax initialization uses jax.random
"from jax.experimental import stax # neural network library\n",
"from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers\n",
"import matplotlib.pyplot as plt # visualization"
@ -134,7 +135,8 @@
" Dense(1)\n",
")\n",
"in_shape = (-1, 1,)\n",
"out_shape, net_params = net_init(in_shape)"
"rng = random.PRNGKey(0)
"out_shape, net_params = net_init(rng, in_shape)"
]
},
{

View File

@ -40,9 +40,10 @@ def random_inputs(rng, input_shape):
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
result_shape, params = init_fun(input_shape)
inputs = random_inputs(onp.random.RandomState(0), input_shape)
rng_key = random.PRNGKey(0)
rng_key, init_key = random.split(rng_key)
result_shape, params = init_fun(init_key, input_shape)
inputs = random_inputs(onp.random.RandomState(0), input_shape)
result = apply_fun(params, inputs, rng=rng_key)
test_case.assertEqual(result.shape, result_shape)
@ -53,14 +54,16 @@ class StaxTest(jtu.JaxTestCase):
{"testcase_name": "_shape={}".format(shape), "shape": shape}
for shape in [(2, 3), (5,)]))
def testRandnInitShape(self, shape):
out = stax.randn()(shape)
key = random.PRNGKey(0)
out = stax.randn()(key, shape)
self.assertEqual(out.shape, shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(shape), "shape": shape}
for shape in [(2, 3), (2, 3, 4)]))
def testGlorotInitShape(self, shape):
out = stax.glorot()(shape)
key = random.PRNGKey(0)
out = stax.glorot()(key, shape)
self.assertEqual(out.shape, shape)
@parameterized.named_parameters(jtu.cases_from_list(
@ -164,11 +167,12 @@ class StaxTest(jtu.JaxTestCase):
_CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)
def testIssue182(self):
key = random.PRNGKey(0)
init_fun, apply_fun = stax.Softmax
input_shape = (10, 3)
inputs = onp.arange(30.).astype("float32").reshape(input_shape)
out_shape, params = init_fun(input_shape)
out_shape, params = init_fun(key, input_shape)
out = apply_fun(params, inputs)
assert out_shape == out.shape
@ -176,11 +180,12 @@ class StaxTest(jtu.JaxTestCase):
def testBatchNormShapeNHWC(self):
key = random.PRNGKey(0)
init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
input_shape = (4, 5, 6, 7)
inputs = random_inputs(onp.random.RandomState(0), input_shape)
out_shape, params = init_fun(input_shape)
out_shape, params = init_fun(key, input_shape)
out = apply_fun(params, inputs)
self.assertEqual(out_shape, input_shape)
@ -190,12 +195,13 @@ class StaxTest(jtu.JaxTestCase):
self.assertEqual(out_shape, out.shape)
def testBatchNormShapeNCHW(self):
key = random.PRNGKey(0)
# Regression test for https://github.com/google/jax/issues/461
init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
input_shape = (4, 5, 6, 7)
inputs = random_inputs(onp.random.RandomState(0), input_shape)
out_shape, params = init_fun(input_shape)
out_shape, params = init_fun(key, input_shape)
out = apply_fun(params, inputs)
self.assertEqual(out_shape, input_shape)