address comments

James Bradbury 2019-09-03 17:51:29 -07:00
parent cc49d8b325
commit bf28c44ada
5 changed files with 119 additions and 101 deletions

View File

@@ -32,13 +32,14 @@ from jax import lax
from jax import random
import jax.numpy as np
from jax.nn import *
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
leaky_relu, selu, gelu, normalize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros
# aliases for backwards compatibility
glorot = initializers.glorot_normal
randn = initializers.normal
zeros = initializers.zeros
ones = initializers.ones
glorot = glorot_normal
randn = normal
logsoftmax = log_softmax
# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
@@ -50,7 +51,7 @@ ones = initializers.ones
# apply_fun: takes params, inputs, and an rng key and applies the layer.
def Dense(out_dim, W_init=initializers.glorot_normal(), b_init=initializers.normal()):
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
"""Layer constructor function for a dense (fully-connected) layer."""
def init_fun(rng, input_shape):
output_shape = input_shape[:-1] + (out_dim,)
@@ -65,13 +66,12 @@ def Dense(out_dim, W_init=initializers.glorot_normal(), b_init=initializers.norm
def GeneralConv(dimension_numbers, out_chan, filter_shape,
strides=None, padding='VALID', W_init=None,
b_init=initializers.normal(1e-6)):
b_init=normal(1e-6)):
"""Layer construction function for a general convolution layer."""
lhs_spec, rhs_spec, out_spec = dimension_numbers
one = (1,) * len(filter_shape)
strides = strides or one
W_init = W_init or initializers.glorot_normal(rhs_spec.index('I'),
rhs_spec.index('O'))
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
def init_fun(rng, input_shape):
filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
@@ -94,12 +94,12 @@ Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
strides=None, padding='VALID', W_init=None,
b_init=initializers.normal(1e-6)):
b_init=normal(1e-6)):
"""Layer construction function for a general transposed-convolution layer."""
lhs_spec, rhs_spec, out_spec = dimension_numbers
one = (1,) * len(filter_shape)
strides = strides or one
W_init = W_init or glorot(rhs_spec.index('O'), rhs_spec.index('I'))
W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
def init_fun(rng, input_shape):
filter_shape_iter = iter(filter_shape)
kernel_shape = [out_chan if c == 'O' else
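The comment block above spells out the stax convention that each layer constructor returns an (init_fun, apply_fun) pair. A minimal usage sketch, not part of this diff, assuming jax.experimental.stax with its stock serial, Dense, Relu and LogSoftmax combinators:

import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

# serial composes layers into a single (init_fun, apply_fun) pair
init_fun, apply_fun = stax.serial(Dense(512), Relu, Dense(10), LogSoftmax)

rng = random.PRNGKey(0)
# init_fun: rng key and input shape -> (output shape, params)
out_shape, params = init_fun(rng, (-1, 784))
# apply_fun: params and inputs -> outputs (an rng kwarg is only needed by stochastic layers)
logits = apply_fun(params, np.ones((32, 784)))
print(out_shape, logits.shape)   # (-1, 10) (32, 10)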

View File

@@ -1,4 +1,4 @@
# Copyright 2018 Google LLC
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,78 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common neural network activations and other functions."""
from __future__ import absolute_import
from __future__ import division
import numpy as onp
from jax import lax
from jax import random
from jax.scipy.special import expit
import jax.numpy as np
from jax import jarrett
"""Common functions for neural network libraries."""
from . import initializers
# activations
def relu(x): return np.maximum(x, 0)
def softplus(x): return np.logaddexp(x, 0)
def soft_sign(x): return x / (np.abs(x) + 1)
def sigmoid(x): return expit(x)
def swish(x): return x * sigmoid(x)
def log_sigmoid(x): return -softplus(-x)
def elu(x, alpha=1.0):
return np.where(x > 0, x, alpha * np.expm1(x))
def leaky_relu(x, negative_slope=1e-2):
return np.where(x >= 0, x, negative_slope * x)
def hard_tanh(x):
return np.where(x > 1, 1, np.where(x < -1, -1, x))
def celu(x, alpha=1.0):
"""Continuously-differentiable exponential linear unit activation"""
return np.where(x > 0, x, alpha * np.expm1(x / alpha))
def selu(x):
"""Scaled exponential linear unit activation"""
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
return scale * elu(x, alpha)
@jarrett
def gelu(x):
"""Gaussian error linear unit activation"""
return x * (lax.erf(x / np.sqrt(2)) + 1) / 2
def glu(x, axis=-1):
"""Gated linear unit activation"""
size = x.shape[axis]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = np.split(x, 2, axis)  # split the linear half from the gate half along `axis`
return x1 * sigmoid(x2)
# other functions
def log_softmax(x, axis=-1):
shifted = x - x.max(axis, keepdims=True)
return shifted - np.log(np.sum(np.exp(shifted), axis, keepdims=True))
def softmax(x, axis=-1):
unnormalized = np.exp(x - x.max(axis, keepdims=True))
return unnormalized / unnormalized.sum(axis, keepdims=True)
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
"""Normalize an array by subtracting mean and dividing by sqrt(var)."""
if mean is None:
mean = np.mean(x, axis, keepdims=True)
if variance is None:
# this definition is traditionally seen as less accurate than np.var's
# mean((x - mean(x))**2) but may be faster and even, given typical
# activation distributions and low-precision arithmetic, more accurate
# when used in neural network normalization layers
variance = np.mean(x**2, axis, keepdims=True) - mean**2
return (x - mean) * lax.rsqrt(variance + epsilon)
from .functions import *
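With the function definitions moved out and re-exported via from .functions import *, downstream code can keep importing activations from jax.nn while initializers stay under jax.nn.initializers. A rough sketch of the resulting import surface (hedged; it only uses names defined in the files shown in this commit):

import jax.numpy as np
from jax import random
from jax.nn import relu, softmax                  # re-exported from jax.nn.functions
from jax.nn.initializers import glorot_normal

x = np.linspace(-3., 3., 7)
print(relu(x))                                    # negative entries clamped to 0
print(softmax(x).sum())                           # ~1.0

W = glorot_normal()(random.PRNGKey(0), (784, 512))
print(W.shape)                                    # (784, 512)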

jax/nn/functions.py Normal file
View File

@@ -0,0 +1,87 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared neural network activations and other functions."""
from __future__ import absolute_import
from __future__ import division
import numpy as onp
from jax import lax
from jax import random
from jax.scipy.special import expit
import jax.numpy as np
from jax import jarrett
# activations
def relu(x): return np.maximum(x, 0)
def softplus(x): return np.logaddexp(x, 0)
def soft_sign(x): return x / (np.abs(x) + 1)
def sigmoid(x): return expit(x)
def swish(x): return x * sigmoid(x)
def log_sigmoid(x): return -softplus(-x)
def elu(x, alpha=1.0):
return np.where(x > 0, x, alpha * np.expm1(x))
def leaky_relu(x, negative_slope=1e-2):
return np.where(x >= 0, x, negative_slope * x)
def hard_tanh(x):
return np.where(x > 1, 1, np.where(x < -1, -1, x))
def celu(x, alpha=1.0):
"""Continuously-differentiable exponential linear unit activation"""
return np.where(x > 0, x, alpha * np.expm1(x / alpha))
def selu(x):
"""Scaled exponential linear unit activation"""
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
return scale * elu(x, alpha)
@jarrett
def gelu(x):
"""Gaussian error linear unit activation"""
return x * (lax.erf(x / np.sqrt(2)) + 1) / 2
def glu(x, axis=-1):
"""Gated linear unit activation"""
size = x.shape[axis]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = np.split(x, 2, axis)  # split the linear half from the gate half along `axis`
return x1 * sigmoid(x2)
# other functions
def log_softmax(x, axis=-1):
shifted = x - x.max(axis, keepdims=True)
return shifted - np.log(np.sum(np.exp(shifted), axis, keepdims=True))
def softmax(x, axis=-1):
unnormalized = np.exp(x - x.max(axis, keepdims=True))
return unnormalized / unnormalized.sum(axis, keepdims=True)
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
"""Normalize an array by subtracting mean and dividing by sqrt(var)."""
if mean is None:
mean = np.mean(x, axis, keepdims=True)
if variance is None:
# this definition is traditionally seen as less accurate than np.var's
# mean((x - mean(x))**2) but may be faster and even, given typical
# activation distributions and low-precision arithmetic, more accurate
# when used in neural network normalization layers
variance = np.mean(x**2, axis, keepdims=True) - mean**2
return (x - mean) * lax.rsqrt(variance + epsilon)
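A small numerical check, offered as a sketch rather than part of the commit, of two properties these implementations rely on: log_softmax agrees with log(softmax(x)) after the max-shift, and normalize, using the E[x**2] - E[x]**2 form of the variance discussed in the comment above, yields roughly zero-mean, unit-variance rows:

import jax.numpy as np
from jax import nn

x = np.array([[1.0, 2.0, 3.0],
              [0.5, -0.5, 2.0]])

print(nn.softmax(x).sum(axis=-1))                                        # ~[1., 1.]
print(np.allclose(np.log(nn.softmax(x)), nn.log_softmax(x), atol=1e-6))  # True

y = nn.normalize(x, axis=-1)
print(y.mean(axis=-1))   # ~[0., 0.]
print(y.var(axis=-1))    # ~[1., 1.], up to the epsilon=1e-5 regularizer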

View File

@@ -1,4 +1,4 @@
# Copyright 2018 Google LLC
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common neural network layer initializers, consistent with definitions
"""
Common neural network layer initializers, consistent with definitions
used in Keras and Sonnet.
"""
@@ -48,18 +49,20 @@ def _compute_fans(shape, in_axis=-2, out_axis=-1):
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1):
def init(key, shape, dtype=np.float32):
variance = scale
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in": variance /= fan_in
elif mode == "fan_out": variance /= fan_out
elif mode == "fan_avg": variance /= (fan_in + fan_out) / 2
else: raise ValueError("invalid mode for variance scaling initializer")
if mode == "fan_in": denominator = fan_in
elif mode == "fan_out": denominator = fan_out
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
else:
raise ValueError(
"invalid mode for variance scaling initializer: {}".format(mode))
variance = np.array(scale / denominator, dtype=dtype)
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
stddev = onp.sqrt(variance) / .87962566103423978
stddev = np.sqrt(variance) / np.array(.87962566103423978, dtype)
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
elif distribution == "normal":
return random.normal(key, shape, dtype) * onp.sqrt(variance)
return random.normal(key, shape, dtype) * np.sqrt(variance)
elif distribution == "uniform":
return random.uniform(key, shape, dtype, -1) * onp.sqrt(3 * variance)
else:
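As the branch above shows, variance_scaling picks a denominator from the fan-in/fan-out statistics of the shape and then scales a truncated-normal, normal, or uniform draw to that variance. A hedged usage sketch, assuming glorot_normal is defined elsewhere in jax.nn.initializers as variance_scaling(1.0, "fan_avg", "truncated_normal"):

import jax.numpy as np
from jax import random
from jax.nn.initializers import variance_scaling, glorot_normal

key = random.PRNGKey(0)
shape = (256, 128)                        # (fan_in, fan_out) for a dense kernel

init = variance_scaling(1.0, "fan_avg", "truncated_normal")
W = init(key, shape)
print(W.std())                            # ~sqrt(2 / (256 + 128)) ~= 0.072

# same construction and same key -> identical draw
print(np.allclose(W, glorot_normal()(key, shape)))   # True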

View File

@@ -384,13 +384,12 @@ def _normal(key, shape, dtype):
def truncated_normal(key, lower, upper, shape=(), dtype=onp.float64):
"""Sample truncated standard normal random values with given shape and float
dtype.
"""Sample truncated standard normal random values with given shape and dtype.
Args:
key: a PRNGKey used as the random key.
lower: a lower bound for truncation.
upper: an upper bound for truncation.
lower: a floating-point lower bound for truncation.
upper: a floating-point upper bound for truncation.
shape: a tuple of nonnegative integers representing the shape.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
@@ -404,9 +403,9 @@ def truncated_normal(key, lower, upper, shape=(), dtype=onp.float64):
@partial(jit, static_argnums=(3, 4))
def _truncated_normal(key, lower, upper, shape, dtype):
_check_shape("truncated_normal", shape)
sqrt2 = onp.sqrt(2)
a = lax.erf(lower / sqrt2)
b = lax.erf(upper / sqrt2)
sqrt2 = onp.array(onp.sqrt(2), dtype)
a = lax.erf(lax.convert_element_type(lower, dtype) / sqrt2)
b = lax.erf(lax.convert_element_type(upper, dtype) / sqrt2)
u = uniform(key, shape, dtype)
return sqrt2 * lax.erf_inv(a + u * (b - a))
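The implementation above maps uniform draws through erf/erf_inv, so the samples follow a standard normal truncated to (lower, upper). A quick sketch of the documented API; the roughly 0.88 sample standard deviation for the (-2, 2) interval matches the 0.87962566103423978 constant used by the truncated_normal branch in the initializers file:

from jax import random

key = random.PRNGKey(42)
samples = random.truncated_normal(key, -2.0, 2.0, shape=(100000,))

print(samples.min(), samples.max())   # strictly inside (-2, 2)
print(samples.mean())                 # ~0.0
print(samples.std())                  # ~0.8796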