# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A mock-up showing a ResNet50 network with training on synthetic data.
|
|
|
|
|
|
|
|
This file uses the stax neural network definition library and the minmax
|
|
|
|
optimization library.
|
|
|
|
"""
|
|
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax
from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
                                   FanOut, Flatten, GeneralConv, Identity,
                                   MaxPool, Relu, LogSoftmax)

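
# Every stax layer, and every network built from the stax combinators used
# below, is an (init_fun, apply_fun) pair. A minimal sketch with hypothetical
# shapes, not part of this script:
#
#   init_fun, apply_fun = Dense(10)
#   out_shape, params = init_fun((batch_size, 784))  # -> ((batch_size, 10), params)
#   out = apply_fun(params, inputs)
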

# ResNet blocks compose other layers

def ConvBlock(kernel_size, filters, strides=(2, 2)):
  ks = kernel_size
  filters1, filters2, filters3 = filters
  Main = stax.serial(
      Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
      Conv(filters3, (1, 1)), BatchNorm())
  Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
  return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)
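
# For instance (a sketch assuming a hypothetical 56x56 NHWC input, not part of
# the training run below), the default strides=(2, 2) halve the spatial
# dimensions on both branches while projecting to filters3 channels, so
# FanInSum can add the branches elementwise:
#
#   init_fun, apply_fun = ConvBlock(3, [64, 64, 256])
#   out_shape, params = init_fun((1, 56, 56, 64))  # out_shape == (1, 28, 28, 256)
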

def IdentityBlock(kernel_size, filters):
  ks = kernel_size
  filters1, filters2 = filters
  def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)), BatchNorm())
  Main = stax.shape_dependent(make_main)
  return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
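
# stax.shape_dependent defers constructing Main until the input shape is
# known; that is what lets the last Conv in make_main set its output channels
# to input_shape[3], so the residual sum with the Identity branch is
# shape-compatible.
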

# ResNet architectures compose layers and ResNet blocks

def ResNet50(num_classes):
  return stax.serial(
      GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
      BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [64, 64, 256], strides=(1, 1)),
      IdentityBlock(3, [64, 64]),
      IdentityBlock(3, [64, 64]),
      ConvBlock(3, [128, 128, 512]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      ConvBlock(3, [256, 256, 1024]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      ConvBlock(3, [512, 512, 2048]),
      IdentityBlock(3, [512, 512]),
      IdentityBlock(3, [512, 512]),
      AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
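
# Sanity check on the name (my arithmetic, not from the source): the stage
# pattern above is 3-4-6-3 residual blocks, i.e. 16 blocks of 3 convolutions
# each, plus the initial convolution and the final dense layer: 16*3 + 2 = 50
# weighted layers.
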

if __name__ == "__main__":
  batch_size = 8
  num_classes = 1001
  input_shape = (224, 224, 3, batch_size)  # HWCN, matching the GeneralConv input spec
  step_size = 0.1
  num_steps = 10

  init_fun, predict_fun = ResNet50(num_classes)
  _, init_params = init_fun(input_shape)

  def loss(params, batch):
    inputs, targets = batch
    logits = predict_fun(params, inputs)
    # The network ends in LogSoftmax, so logits are log-probabilities; the
    # negated sum against one-hot targets is the cross-entropy loss.
    return -np.sum(logits * targets)

  def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=-1)
    predicted_class = np.argmax(predict_fun(params, inputs), axis=-1)
    return np.mean(predicted_class == target_class)

  def synth_batches():
    rng = npr.RandomState(0)
    while True:
      images = rng.rand(*input_shape).astype('float32')
      labels = rng.randint(num_classes, size=(batch_size, 1))
      # broadcasting (batch_size, 1) against (num_classes,) yields boolean
      # one-hot rows
      onehot_labels = labels == np.arange(num_classes)
      yield images, onehot_labels

  opt_init, opt_update = minmax.momentum(step_size, mass=0.9)
  batches = synth_batches()

  @jit
  def update(i, opt_state, batch):
    params = minmax.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)
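
  # A general note on jit (standard JAX behavior, not specific to this file):
  # the first call to update traces and compiles the whole step, including the
  # grad(loss) backward pass; later calls with the same shapes reuse the
  # compiled computation.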

  opt_state = opt_init(init_params)
  for i in range(num_steps):
    opt_state = update(i, opt_state, next(batches))
  trained_params = minmax.get_params(opt_state)