Commit: bf28c44ada ("address comments")
Parent: cc49d8b325
@@ -32,13 +32,14 @@ from jax import lax
 from jax import random
 import jax.numpy as np
 
-from jax.nn import *
+from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
+                    leaky_relu, selu, gelu, normalize)
+from jax.nn.initializers import glorot_normal, normal, ones, zeros
 
 # aliases for backwards compatibility
-glorot = initializers.glorot_normal
-randn = initializers.normal
-zeros = initializers.zeros
-ones = initializers.ones
+glorot = glorot_normal
+randn = normal
 logsoftmax = log_softmax
 
 # Following the convention used in Keras and tf.layers, we use CamelCase for the
 # names of layer constructors, like Conv and Relu, while using snake_case for
@@ -50,7 +51,7 @@ ones = initializers.ones
 # apply_fun: takes params, inputs, and an rng key and applies the layer.
 
 
-def Dense(out_dim, W_init=initializers.glorot_normal(), b_init=initializers.normal()):
+def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
   """Layer constructor function for a dense (fully-connected) layer."""
   def init_fun(rng, input_shape):
     output_shape = input_shape[:-1] + (out_dim,)
@@ -65,13 +66,12 @@ def Dense(out_dim, W_init=initializers.glorot_normal(), b_init=initializers.norm
 
 def GeneralConv(dimension_numbers, out_chan, filter_shape,
                 strides=None, padding='VALID', W_init=None,
-                b_init=initializers.normal(1e-6)):
+                b_init=normal(1e-6)):
   """Layer construction function for a general convolution layer."""
   lhs_spec, rhs_spec, out_spec = dimension_numbers
   one = (1,) * len(filter_shape)
   strides = strides or one
-  W_init = W_init or initializers.glorot_normal(rhs_spec.index('I'),
-                                                rhs_spec.index('O'))
+  W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
   def init_fun(rng, input_shape):
     filter_shape_iter = iter(filter_shape)
     kernel_shape = [out_chan if c == 'O' else
@@ -94,12 +94,12 @@ Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
 
 def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                          strides=None, padding='VALID', W_init=None,
-                         b_init=initializers.normal(1e-6)):
+                         b_init=normal(1e-6)):
   """Layer construction function for a general transposed-convolution layer."""
   lhs_spec, rhs_spec, out_spec = dimension_numbers
   one = (1,) * len(filter_shape)
   strides = strides or one
-  W_init = W_init or glorot(rhs_spec.index('O'), rhs_spec.index('I'))
+  W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))
   def init_fun(rng, input_shape):
     filter_shape_iter = iter(filter_shape)
     kernel_shape = [out_chan if c == 'O' else
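For context, a minimal usage sketch of the (init_fun, apply_fun) pattern these hunks touch. It is illustrative only and assumes the changed file is jax.experimental.stax (the file name is not shown in this diff); the shapes are made up.

import jax.numpy as np
from jax import random
from jax.experimental import stax

# Dense(out_dim) returns an (init_fun, apply_fun) pair; after this change the
# default W_init/b_init come from jax.nn.initializers via the imports above.
init_fun, apply_fun = stax.Dense(128)
rng = random.PRNGKey(0)
out_shape, params = init_fun(rng, (-1, 784))   # out_shape == (-1, 128), params == (W, b)
x = np.ones((32, 784))
y = apply_fun(params, x)                       # shape (32, 128)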
@@ -1,4 +1,4 @@
-# Copyright 2018 Google LLC
+# Copyright 2019 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,78 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Common neural network activations and other functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-
-import numpy as onp
-
-from jax import lax
-from jax import random
-from jax.scipy.special import expit
-import jax.numpy as np
-from jax import jarrett
+"""Common functions for neural network libraries."""
+
+from . import initializers
-
-# activations
-
-def relu(x): return np.maximum(x, 0)
-def softplus(x): return np.logaddexp(x, 0)
-def soft_sign(x): return x / (np.abs(x) + 1)
-def sigmoid(x): return expit(x)
-def swish(x): return x * sigmoid(x)
-def log_sigmoid(x): return -softplus(-x)
-
-def elu(x, alpha=1.0):
-  return np.where(x > 0, x, alpha * np.expm1(x))
-
-def leaky_relu(x, negative_slope=1e-2):
-  return np.where(x >= 0, x, negative_slope * x)
-
-def hard_tanh(x):
-  return np.where(x > 1, 1, np.where(x < -1, -1, x))
-
-def celu(x, alpha=1.0):
-  """Continuously-differentiable exponential linear unit activation"""
-  return np.where(x > 0, x, alpha * np.expm1(x / alpha))
-
-def selu(x):
-  """Scaled exponential linear unit activation"""
-  alpha = 1.6732632423543772848170429916717
-  scale = 1.0507009873554804934193349852946
-  return scale * leaky_relu(x, alpha)
-
-@jarrett
-def gelu(x):
-  """Gaussian error linear unit activation"""
-  return x * (lax.erf(x / np.sqrt(2)) + 1) / 2
-
-def glu(x, axis=-1):
-  """Gated linear unit activation"""
-  size = x.shape[axis]
-  assert size % 2 == 0, "axis size must be divisible by 2"
-  return x[..., :size] * sigmoid(x[..., size:])
-
-# other functions
-
-def log_softmax(x, axis=-1):
-  shifted = x - x.max(axis, keepdims=True)
-  return shifted - np.log(np.sum(np.exp(shifted), axis, keepdims=True))
-
-def softmax(x, axis=-1):
-  unnormalized = np.exp(x - x.max(axis, keepdims=True))
-  return unnormalized / unnormalized.sum(axis, keepdims=True)
-
-def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
-  """Normalize an array by subtracting mean and dividing by sqrt(var)."""
-  if mean is None:
-    mean = np.mean(x, axis, keepdims=True)
-  if variance is None:
-    # this definition is traditionally seen as less accurate than np.var's
-    # mean((x - mean(x))**2) but may be faster and even, given typical
-    # activation distributions and low-precision arithmetic, more accurate
-    # when used in neural network normalization layers
-    variance = np.mean(x**2, axis, keepdims=True) - mean**2
-  return (x - mean) * lax.rsqrt(variance + epsilon)
+from .functions import *
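The hunk above moves the function bodies out of the package __init__ and re-exports them from the new module. A small sketch of what that preserves, assuming the package involved is jax.nn; the star re-export keeps the same names reachable from the package root.

# Equivalent ways to reach the same object after the reorganization:
from jax.nn import relu                        # via `from .functions import *`
from jax.nn.functions import relu as relu_fn   # defined in the new module below
assert relu is relu_fn
from jax.nn.initializers import glorot_normal  # subpackage pulled in by `from . import initializers`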
jax/nn/functions.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared neural network activations and other functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+
+import numpy as onp
+
+from jax import lax
+from jax import random
+from jax.scipy.special import expit
+import jax.numpy as np
+from jax import jarrett
+
+# activations
+
+def relu(x): return np.maximum(x, 0)
+def softplus(x): return np.logaddexp(x, 0)
+def soft_sign(x): return x / (np.abs(x) + 1)
+def sigmoid(x): return expit(x)
+def swish(x): return x * sigmoid(x)
+def log_sigmoid(x): return -softplus(-x)
+
+def elu(x, alpha=1.0):
+  return np.where(x > 0, x, alpha * np.expm1(x))
+
+def leaky_relu(x, negative_slope=1e-2):
+  return np.where(x >= 0, x, negative_slope * x)
+
+def hard_tanh(x):
+  return np.where(x > 1, 1, np.where(x < -1, -1, x))
+
+def celu(x, alpha=1.0):
+  """Continuously-differentiable exponential linear unit activation"""
+  return np.where(x > 0, x, alpha * np.expm1(x / alpha))
+
+def selu(x):
+  """Scaled exponential linear unit activation"""
+  alpha = 1.6732632423543772848170429916717
+  scale = 1.0507009873554804934193349852946
+  return scale * leaky_relu(x, alpha)
+
+@jarrett
+def gelu(x):
+  """Gaussian error linear unit activation"""
+  return x * (lax.erf(x / np.sqrt(2)) + 1) / 2
+
+def glu(x, axis=-1):
+  """Gated linear unit activation"""
+  size = x.shape[axis]
+  assert size % 2 == 0, "axis size must be divisible by 2"
+  return x[..., :size] * sigmoid(x[..., size:])
+
+# other functions
+
+def log_softmax(x, axis=-1):
+  shifted = x - x.max(axis, keepdims=True)
+  return shifted - np.log(np.sum(np.exp(shifted), axis, keepdims=True))
+
+def softmax(x, axis=-1):
+  unnormalized = np.exp(x - x.max(axis, keepdims=True))
+  return unnormalized / unnormalized.sum(axis, keepdims=True)
+
+def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
+  """Normalize an array by subtracting mean and dividing by sqrt(var)."""
+  if mean is None:
+    mean = np.mean(x, axis, keepdims=True)
+  if variance is None:
+    # this definition is traditionally seen as less accurate than np.var's
+    # mean((x - mean(x))**2) but may be faster and even, given typical
+    # activation distributions and low-precision arithmetic, more accurate
+    # when used in neural network normalization layers
+    variance = np.mean(x**2, axis, keepdims=True) - mean**2
+  return (x - mean) * lax.rsqrt(variance + epsilon)
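A short, illustrative sketch of calling the functions defined in this new file; the inputs and rounded outputs are examples of my own, not part of the commit.

import jax.numpy as np
from jax import nn

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
nn.relu(x)               # [0., 0., 0., 0.5, 2.]
nn.softplus(x)           # log(1 + exp(x)), a smooth relu
nn.log_softmax(x)        # x - logsumexp(x) over the last axis
nn.softmax(x).sum()      # ~1.0: softmax normalizes to a probability vector
# normalize standardizes along an axis to roughly zero mean and unit variance,
# using the identity Var[x] = E[x**2] - E[x]**2 noted in the comments above.
nn.normalize(np.arange(6.).reshape(2, 3), axis=-1)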
@@ -1,4 +1,4 @@
-# Copyright 2018 Google LLC
+# Copyright 2019 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Common neural network layer initializers, consistent with definitions
+"""
+Common neural network layer initializers, consistent with definitions
 used in Keras and Sonnet.
 """
 
@@ -48,18 +49,20 @@ def _compute_fans(shape, in_axis=-2, out_axis=-1):
 
 def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1):
   def init(key, shape, dtype=np.float32):
-    variance = scale
     fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-    if mode == "fan_in": variance /= fan_in
-    elif mode == "fan_out": variance /= fan_out
-    elif mode == "fan_avg": variance /= (fan_in + fan_out) / 2
-    else: raise ValueError("invalid mode for variance scaling initializer")
+    if mode == "fan_in": denominator = fan_in
+    elif mode == "fan_out": denominator = fan_out
+    elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
+    else:
+      raise ValueError(
+        "invalid mode for variance scaling initializer: {}".format(mode))
+    variance = np.array(scale / denominator, dtype=dtype)
     if distribution == "truncated_normal":
       # constant is stddev of standard normal truncated to (-2, 2)
-      stddev = onp.sqrt(variance) / .87962566103423978
+      stddev = np.sqrt(variance) / np.array(.87962566103423978, dtype)
       return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
     elif distribution == "normal":
-      return random.normal(key, shape, dtype) * onp.sqrt(variance)
+      return random.normal(key, shape, dtype) * np.sqrt(variance)
     elif distribution == "uniform":
       return random.uniform(key, shape, dtype, -1) * onp.sqrt(3 * variance)
     else:
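To make the arithmetic in the rewritten variance_scaling concrete, here is a hedged sketch using glorot_normal, which is assumed here (as in Keras) to be variance_scaling with scale=1.0, mode="fan_avg", distribution="truncated_normal"; the kernel shape is illustrative.

from jax import random
from jax.nn.initializers import glorot_normal, variance_scaling

key = random.PRNGKey(0)
shape = (784, 128)          # fan_in=784, fan_out=128 for a dense kernel
# fan_avg = (784 + 128) / 2 = 456, so variance = 1.0 / 456; samples are drawn
# from a standard normal truncated to (-2, 2) and rescaled so their standard
# deviation is sqrt(variance) (hence the division by 0.8796... above).
W = glorot_normal()(key, shape)
W_explicit = variance_scaling(1.0, "fan_avg", "truncated_normal")(key, shape)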
@@ -384,13 +384,12 @@ def _normal(key, shape, dtype):
 
 
 def truncated_normal(key, lower, upper, shape=(), dtype=onp.float64):
-  """Sample truncated standard normal random values with given shape and float
-  dtype.
+  """Sample truncated standard normal random values with given shape and dtype.
 
   Args:
     key: a PRNGKey used as the random key.
-    lower: a lower bound for truncation.
-    upper: an upper bound for truncation.
+    lower: a floating-point lower bound for truncation.
+    upper: a floating-point upper bound for truncation.
     shape: a tuple of nonnegative integers representing the shape.
     dtype: optional, a float dtype for the returned values (default float64 if
       jax_enable_x64 is true, otherwise float32).
@@ -404,9 +403,9 @@ def truncated_normal(key, lower, upper, shape=(), dtype=onp.float64):
 @partial(jit, static_argnums=(3, 4))
 def _truncated_normal(key, lower, upper, shape, dtype):
   _check_shape("truncated_normal", shape)
-  sqrt2 = onp.sqrt(2)
-  a = lax.erf(lower / sqrt2)
-  b = lax.erf(upper / sqrt2)
+  sqrt2 = onp.array(onp.sqrt(2), dtype)
+  a = lax.erf(lax.convert_element_type(lower, dtype) / sqrt2)
+  b = lax.erf(lax.convert_element_type(upper, dtype) / sqrt2)
   u = uniform(key, shape, dtype)
   return sqrt2 * lax.erf_inv(a + u * (b - a))
 
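The _truncated_normal change above keeps the inverse-CDF construction but performs the arithmetic in the requested dtype. Since the standard normal CDF is Phi(x) = (1 + erf(x/sqrt(2)))/2, mapping u ~ Uniform[0, 1) affinely between erf(lower/sqrt(2)) and erf(upper/sqrt(2)) and applying sqrt(2)*erfinv yields standard normal samples truncated to [lower, upper]. A small check, with sample count and bounds that are illustrative only:

import numpy as onp
from jax import random

key = random.PRNGKey(0)
samples = random.truncated_normal(key, -2., 2., shape=(10000,))
assert samples.min() >= -2. and samples.max() <= 2.
# The 0.87962566103423978 constant used in variance_scaling is the standard
# deviation of exactly this distribution (standard normal truncated to (-2, 2)).
print(onp.std(onp.asarray(samples)))   # roughly 0.88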