Use jax.random for stax initialization

commit 1c9b9a57fd
parent ca9151acc9
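The change threads an explicit `jax.random.PRNGKey` through every stax `init_fun` in place of numpy's global RNG. As orientation before the hunks below, here is a minimal sketch of the post-commit calling convention; it is not taken verbatim from the repository, and the layer sizes are arbitrary:

```python
import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

# Build an (init_fun, apply_fun) pair as before.
net_init, net_apply = stax.serial(Dense(1024), Relu, Dense(10), LogSoftmax)

# New convention: init_fun takes an explicit PRNG key as its first argument.
rng = random.PRNGKey(0)
in_shape = (-1, 784)  # batch dimension left unspecified
out_shape, net_params = net_init(rng, in_shape)

# apply_fun is unchanged: parameters first, then inputs.
inputs = np.zeros((128, 784))
logits = net_apply(net_params, inputs)
```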
@@ -560,6 +560,7 @@ Here’s an example:

 ```python
 import jax.numpy as np
+from jax import random
 from jax.experimental import stax
 from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax

@@ -573,8 +574,9 @@ net_init, net_apply = stax.serial(
 )

 # Initialize parameters, not committing to a batch shape
+rng = random.PRNGKey(0)
 in_shape = (-1, 28, 28, 1)
-out_shape, net_params = net_init(in_shape)
+out_shape, net_params = net_init(rng, in_shape)

 # Apply network to dummy inputs
 inputs = np.zeros((128, 28, 28, 1))

@@ -689,7 +691,7 @@ specialized on shapes and dtypes, but not specialized all the way to concrete
 values, the Python code under a `jit` decorator must be applicable to abstract
 values. If we try to evaluate `x > 0` on an abstract `x`, the result is an
 abstract value representing the set `{True, False}`, and so a Python branch like
-`if x > 0` will raise an error: it doesn’t know which way to go!
+`if x > 0` will raise an error: it doesn’t know which way to go!
 See [What’s supported](#whats-supported) for more
 information about `jit` requirements.
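The last README hunk only touches the paragraph on `jit` and abstract values. As an illustration of the behavior that paragraph describes (a sketch, not part of this commit), a Python branch on a traced value fails, while expressing the same choice as data flow with `np.where` works under `jit`:

```python
import jax.numpy as np
from jax import jit

@jit
def abs_bad(x):
  if x > 0:  # Python control flow on an abstract/traced value
    return x
  return -x

# Calling abs_bad(1.0) raises an error under jit: `x > 0` is abstract,
# so Python cannot decide which branch to take.

@jit
def abs_ok(x):
  return np.where(x > 0, x, -x)  # branch expressed as data flow

print(abs_ok(-3.0))  # 3.0
```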
@@ -27,7 +27,7 @@ import numpy.random as npr

 import jax.numpy as np
 from jax.config import config
-from jax import jit, grad
+from jax import jit, grad, random
 from jax.experimental import optimizers
 from jax.experimental import stax
 from jax.experimental.stax import Dense, Relu, LogSoftmax

@@ -51,6 +51,8 @@ init_random_params, predict = stax.serial(
     Dense(10), LogSoftmax)

 if __name__ == "__main__":
+  rng = random.PRNGKey(0)
+
   step_size = 0.001
   num_epochs = 10
   batch_size = 128

@@ -77,7 +79,7 @@ if __name__ == "__main__":
     params = optimizers.get_params(opt_state)
     return opt_update(i, grad(loss)(params, batch), opt_state)

-  _, init_params = init_random_params((-1, 28 * 28))
+  _, init_params = init_random_params(rng, (-1, 28 * 28))
   opt_state = opt_init(init_params)
   itercount = itertools.count()
@@ -97,8 +97,9 @@ if __name__ == "__main__":
   num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
   num_batches = num_complete_batches + bool(leftover)

-  _, init_encoder_params = encoder_init((batch_size, 28 * 28))
-  _, init_decoder_params = decoder_init((batch_size, 10))
+  enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
+  _, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
+  _, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
   init_params = init_encoder_params, init_decoder_params

   opt_init, opt_update = optimizers.momentum(step_size, mass=0.9)
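The VAE example above derives two keys with `random.split` instead of reusing one. A small sketch of why that matters (the shapes here are made up): the same key with the same shape reproduces the same draw, while split keys give independent draws.

```python
from jax import random

key = random.PRNGKey(2)
enc_key, dec_key = random.split(key)

a = random.normal(key, (3,))
b = random.normal(key, (3,))       # identical to `a`: same key, same shape

c = random.normal(enc_key, (3,))   # independent draws: distinct keys
d = random.normal(dec_key, (3,))
```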
@@ -27,7 +27,7 @@ from six.moves import xrange

 import jax.numpy as np
 from jax.config import config
-from jax import jit, grad
+from jax import jit, grad, random
 from jax.experimental import optimizers
 from jax.experimental import stax
 from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,

@@ -87,6 +87,8 @@ def ResNet50(num_classes):


 if __name__ == "__main__":
+  rng_key = random.PRNGKey(0)
+
   batch_size = 8
   num_classes = 1001
   input_shape = (224, 224, 3, batch_size)

@@ -94,7 +96,7 @@ if __name__ == "__main__":
   num_steps = 10

   init_fun, predict_fun = ResNet50(num_classes)
-  _, init_params = init_fun(input_shape)
+  _, init_params = init_fun(rng_key, input_shape)

   def loss(params, batch):
     inputs, targets = batch
@@ -26,7 +26,6 @@ import itertools
 import operator as op

 import numpy as onp
-import numpy.random as npr
 from six.moves import reduce

 from jax import lax
@@ -59,37 +58,38 @@ def fastvar(x, axis, keepdims):

 # Initializers

-def randn(stddev=1e-2, rng=npr):
+def randn(stddev=1e-2):
   """An initializer function for random normal coefficients."""
-  def init(shape):
-    return rng.normal(size=shape, scale=stddev).astype('float32')
+  def init(rng, shape):
+    return stddev * random.normal(rng, shape)
   return init

-def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2), rng=npr):
+def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
   """An initializer function for random Glorot-scaled coefficients."""
-  def init(shape):
+  def init(rng, shape):
     fan_in, fan_out = shape[in_dim], shape[out_dim]
     size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
     std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
-    return rng.normal(size=shape, scale=std).astype('float32')
+    return std * random.normal(rng, shape)
   return init

-zeros = functools.partial(np.zeros, dtype='float32')
-ones = functools.partial(np.ones, dtype='float32')
+zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
+ones = lambda rng, shape: np.ones(shape, dtype='float32')


 # Layers

 # Each layer constructor function returns an (init_fun, apply_fun) pair, where
-#   init_fun: takes an input shape and returns an (output_shape, params) pair,
+#   init_fun: takes an rng key and an input shape and returns an
+#     (output_shape, params) pair,
 #   apply_fun: takes params, inputs, and an rng key and applies the layer.


 def Dense(out_dim, W_init=glorot(), b_init=randn()):
   """Layer constructor function for a dense (fully-connected) layer."""
-  def init_fun(input_shape):
+  def init_fun(rng, input_shape):
     output_shape = input_shape[:-1] + (out_dim,)
-    W, b = W_init((input_shape[-1], out_dim)), b_init((out_dim,))
+    W, b = W_init(rng, (input_shape[-1], out_dim)), b_init(rng, (out_dim,))
     return output_shape, (W, b)
   def apply_fun(params, inputs, **kwargs):
     W, b = params
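The comment block above now states the contract that an `init_fun` takes an rng key and an input shape. A minimal sketch of a user-defined layer written against that contract (the name `MyDense` and the per-parameter key splitting are illustrative; the stax `Dense` in this diff passes the same key to both initializers):

```python
import jax.numpy as np
from jax import random
from jax.experimental import stax

def MyDense(out_dim, W_init=stax.glorot(), b_init=stax.randn()):
  """A user-defined dense layer following the new (init_fun, apply_fun) contract."""
  def init_fun(rng, input_shape):
    k_W, k_b = random.split(rng)  # one key per parameter
    output_shape = input_shape[:-1] + (out_dim,)
    W = W_init(k_W, (input_shape[-1], out_dim))
    b = b_init(k_b, (out_dim,))
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return np.dot(inputs, W) + b
  return init_fun, apply_fun
```

Such a layer composes with `stax.serial` exactly like the built-in ones.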
@@ -104,7 +104,7 @@ def GeneralConv(dimension_numbers, out_chan, filter_shape,
   one = (1,) * len(filter_shape)
   strides = strides or one
   W_init = W_init or glorot(rhs_spec.index('O'), rhs_spec.index('I'))
-  def init_fun(input_shape):
+  def init_fun(rng, input_shape):
     filter_shape_iter = iter(filter_shape)
     kernel_shape = [out_chan if c == 'O' else
                     input_shape[lhs_spec.index('C')] if c == 'I' else

@@ -113,7 +113,7 @@ def GeneralConv(dimension_numbers, out_chan, filter_shape,
         input_shape, kernel_shape, strides, padding, dimension_numbers)
     bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
     bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
-    W, b = W_init(kernel_shape), b_init(bias_shape)
+    W, b = W_init(rng, kernel_shape), b_init(rng, bias_shape)
     return output_shape, (W, b)
   def apply_fun(params, inputs, **kwargs):
     W, b = params
@@ -126,12 +126,12 @@ Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
 def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
               beta_init=zeros, gamma_init=ones):
   """Layer construction function for a batch normalization layer."""
-  _beta_init = lambda shape: beta_init(shape) if center else ()
-  _gamma_init = lambda shape: gamma_init(shape) if scale else ()
+  _beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
+  _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
   axis = (axis,) if np.isscalar(axis) else axis
-  def init_fun(input_shape):
+  def init_fun(rng, input_shape):
     shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
-    beta, gamma = _beta_init(shape), _gamma_init(shape)
+    beta, gamma = _beta_init(rng, shape), _gamma_init(rng, shape)
     return input_shape, (beta, gamma)
   def apply_fun(params, x, **kwargs):
     beta, gamma = params

@@ -150,7 +150,7 @@ def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,


 def _elemwise_no_params(fun, **fun_kwargs):
-  init_fun = lambda input_shape: (input_shape, ())
+  init_fun = lambda rng, input_shape: (input_shape, ())
   apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs)
   return init_fun, apply_fun
 Tanh = _elemwise_no_params(np.tanh)
@@ -168,7 +168,7 @@ def _pooling_layer(reducer, init_val, rescaler=None):
     rescale = rescaler(window_shape, strides, padding) if rescaler else None
     dims = (1,) + window_shape + (1,)  # NHWC
     strides = (1,) + strides + (1,)
-    def init_fun(input_shape):
+    def init_fun(rng, input_shape):
       out_shape = lax.reduce_window_shape_tuple(input_shape, dims, strides, padding)
       return out_shape, ()
     def apply_fun(params, inputs, **kwargs):

@@ -191,7 +191,7 @@ AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)

 def Flatten():
   """Layer construction function for flattening all but the leading dim."""
-  def init_fun(input_shape):
+  def init_fun(rng, input_shape):
     output_shape = input_shape[0], reduce(op.mul, input_shape[1:], 1)
     return output_shape, ()
   def apply_fun(params, inputs, **kwargs):

@@ -202,7 +202,7 @@ Flatten = Flatten()

 def Identity():
   """Layer construction function for an identity layer."""
-  init_fun = lambda input_shape: (input_shape, ())
+  init_fun = lambda rng, input_shape: (input_shape, ())
   apply_fun = lambda params, inputs, **kwargs: inputs
   return init_fun, apply_fun
 Identity = Identity()
@@ -210,14 +210,14 @@ Identity = Identity()

 def FanOut(num):
   """Layer construction function for a fan-out layer."""
-  init_fun = lambda input_shape: ([input_shape] * num, ())
+  init_fun = lambda rng, input_shape: ([input_shape] * num, ())
   apply_fun = lambda params, inputs, **kwargs: [inputs] * num
   return init_fun, apply_fun


 def FanInSum():
   """Layer construction function for a fan-in sum layer."""
-  init_fun = lambda input_shape: (input_shape[0], ())
+  init_fun = lambda rng, input_shape: (input_shape[0], ())
   apply_fun = lambda params, inputs, **kwargs: sum(inputs)
   return init_fun, apply_fun
 FanInSum = FanInSum()

@@ -225,7 +225,7 @@ FanInSum = FanInSum()

 def FanInConcat(axis=-1):
   """Layer construction function for a fan-in concatenation layer."""
-  def init_fun(input_shape):
+  def init_fun(rng, input_shape):
     ax = axis % len(input_shape[0])
     concat_size = sum(shape[ax] for shape in input_shape)
     out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:]

@@ -237,7 +237,7 @@ def FanInConcat(axis=-1):

 def Dropout(rate, mode='train'):
   """Layer construction function for a dropout layer with given rate."""
-  def init_fun(input_shape):
+  def init_fun(rng, input_shape):
     return input_shape, ()
   def apply_fun(params, inputs, **kwargs):
     rng = kwargs.get('rng', None)
@@ -270,10 +270,11 @@ def serial(*layers):
   """
   nlayers = len(layers)
   init_funs, apply_funs = zip(*layers)
-  def init_fun(input_shape):
+  def init_fun(rng, input_shape):
     params = []
     for init_fun in init_funs:
-      input_shape, param = init_fun(input_shape)
+      rng, layer_rng = random.split(rng)
+      input_shape, param = init_fun(layer_rng, input_shape)
       params.append(param)
     return input_shape, params
   def apply_fun(params, inputs, **kwargs):

@@ -302,8 +303,10 @@ def parallel(*layers):
   """
   nlayers = len(layers)
   init_funs, apply_funs = zip(*layers)
-  def init_fun(input_shape):
-    return zip(*[init(shape) for init, shape in zip(init_funs, input_shape)])
+  def init_fun(rng, input_shape):
+    rngs = random.split(rng, len(init_funs))
+    return zip(*[init(rng, shape) for init, rng, shape
+                 in zip(init_funs, rngs, input_shape)])
   def apply_fun(params, inputs, **kwargs):
     rng = kwargs.pop('rng', None)
     rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
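With the change to `serial`, each layer's `init_fun` receives its own key derived via `random.split`, so a network's initialization is a deterministic function of the single key passed in. A short sketch of the reproducibility this buys (layer sizes arbitrary, not from the repository):

```python
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu

net_init, net_apply = stax.serial(Dense(4), Relu, Dense(2))

_, params_a = net_init(random.PRNGKey(0), (-1, 8))
_, params_b = net_init(random.PRNGKey(0), (-1, 8))  # identical to params_a

_, params_c = net_init(random.PRNGKey(1), (-1, 8))  # different, but reproducible
```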
@@ -323,8 +326,8 @@ def shape_dependent(make_layer):
   layer as returned by `make_layer` but with its construction delayed until
   input shapes are known.
   """
-  def init_fun(input_shape):
-    return make_layer(input_shape)[0](input_shape)
+  def init_fun(rng, input_shape):
+    return make_layer(input_shape)[0](rng, input_shape)
   def apply_fun(params, inputs, **kwargs):
     return make_layer(inputs.shape)[1](params, inputs, **kwargs)
   return init_fun, apply_fun
@@ -116,6 +116,7 @@
    "from jax import vmap # for auto-vectorizing functions\n",
    "from functools import partial # for use with vmap\n",
    "from jax import jit # for compiling functions for speedup\n",
+    "from jax import random # Stax initialization uses jax.random\n",
    "from jax.experimental import stax # neural network library\n",
    "from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers\n",
    "import matplotlib.pyplot as plt # visualization"

@@ -134,7 +135,8 @@
    "    Dense(1)\n",
    ")\n",
    "in_shape = (-1, 1,)\n",
-    "out_shape, net_params = net_init(in_shape)"
+    "rng = random.PRNGKey(0)\n",
+    "out_shape, net_params = net_init(rng, in_shape)"
   ]
  },
  {
@@ -40,9 +40,10 @@ def random_inputs(rng, input_shape):


 def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
-  result_shape, params = init_fun(input_shape)
-  inputs = random_inputs(onp.random.RandomState(0), input_shape)
+  rng_key = random.PRNGKey(0)
+  rng_key, init_key = random.split(rng_key)
+  result_shape, params = init_fun(init_key, input_shape)
+  inputs = random_inputs(onp.random.RandomState(0), input_shape)
   result = apply_fun(params, inputs, rng=rng_key)
   test_case.assertEqual(result.shape, result_shape)
@@ -53,14 +54,16 @@ class StaxTest(jtu.JaxTestCase):
       {"testcase_name": "_shape={}".format(shape), "shape": shape}
       for shape in [(2, 3), (5,)]))
   def testRandnInitShape(self, shape):
-    out = stax.randn()(shape)
+    key = random.PRNGKey(0)
+    out = stax.randn()(key, shape)
     self.assertEqual(out.shape, shape)

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}".format(shape), "shape": shape}
       for shape in [(2, 3), (2, 3, 4)]))
   def testGlorotInitShape(self, shape):
-    out = stax.glorot()(shape)
+    key = random.PRNGKey(0)
+    out = stax.glorot()(key, shape)
     self.assertEqual(out.shape, shape)

   @parameterized.named_parameters(jtu.cases_from_list(
@@ -164,11 +167,12 @@ class StaxTest(jtu.JaxTestCase):
     _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

   def testIssue182(self):
+    key = random.PRNGKey(0)
     init_fun, apply_fun = stax.Softmax
     input_shape = (10, 3)
     inputs = onp.arange(30.).astype("float32").reshape(input_shape)

-    out_shape, params = init_fun(input_shape)
+    out_shape, params = init_fun(key, input_shape)
     out = apply_fun(params, inputs)

     assert out_shape == out.shape

@@ -176,11 +180,12 @@ class StaxTest(jtu.JaxTestCase):


   def testBatchNormShapeNHWC(self):
+    key = random.PRNGKey(0)
     init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
     input_shape = (4, 5, 6, 7)
     inputs = random_inputs(onp.random.RandomState(0), input_shape)

-    out_shape, params = init_fun(input_shape)
+    out_shape, params = init_fun(key, input_shape)
     out = apply_fun(params, inputs)

     self.assertEqual(out_shape, input_shape)

@@ -190,12 +195,13 @@ class StaxTest(jtu.JaxTestCase):
     self.assertEqual(out_shape, out.shape)

   def testBatchNormShapeNCHW(self):
+    key = random.PRNGKey(0)
     # Regression test for https://github.com/google/jax/issues/461
     init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
     input_shape = (4, 5, 6, 7)
     inputs = random_inputs(onp.random.RandomState(0), input_shape)

-    out_shape, params = init_fun(input_shape)
+    out_shape, params = init_fun(key, input_shape)
     out = apply_fun(params, inputs)

     self.assertEqual(out_shape, input_shape)