Document jax.nn.initializers.

Peter Hawkins 2022-03-07 17:05:51 -05:00
parent 7061521137
commit d3d666d081
3 changed files with 397 additions and 54 deletions

View File

@@ -13,17 +13,25 @@ Initializers
This module provides common neural network layer initializers,
consistent with definitions used in Keras and Sonnet.
An initializer is a function that takes three arguments,
``(key, shape, dtype)``, and returns an array with shape ``shape`` and
data type ``dtype``. The ``key`` argument is a :class:`jax.random.PRNGKey`
used to generate the random numbers that initialize the array (a short
usage sketch follows the list below).
.. autosummary::
:toctree: _autosummary
constant
delta_orthogonal
glorot_normal
lecun_uniform
lecun_normal
he_uniform
glorot_uniform
he_normal
he_uniform
lecun_normal
lecun_uniform
normal
ones
orthogonal
uniform
variance_scaling
zeros
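A minimal sketch of this calling convention (``zeros`` and ``normal`` are
chosen purely as examples; this snippet is illustrative, not part of the
commit):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# zeros and ones are initializers themselves; most other entries are
# builders that return an initializer once given their hyperparameters.
b = jax.nn.initializers.zeros(key, (4,), jnp.float32)
w = jax.nn.initializers.normal(stddev=1e-2)(key, (3, 4), jnp.float32)
print(b.shape, w.shape)  # (4,) (3, 4)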

View File

@@ -18,7 +18,7 @@ used in Keras and Sonnet.
"""
from functools import partial
from typing import Any, Callable, Sequence, Union
import numpy as np
@@ -29,22 +29,90 @@ from jax import core
from jax._src.util import prod
from jax import dtypes
DType = Any
def zeros(key, shape, dtype: DType = jnp.float_):
"""An initializer that returns a constant array full of zeros.
The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
"""
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
def ones(key, shape, dtype: DType = jnp.float_):
"""An initializer that returns a constant array full of ones.
The ``key`` argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
DeviceArray([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)
"""
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
def constant(value, dtype: DType = jnp.float_) -> Callable:
"""Builds an initializer that returns arrays full of a constant ``value``.
Args:
value: the constant value with which to fill the initializer.
dtype: optional; the initializer's default dtype.
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return jnp.full(shape, value, dtype=dtype)
return init
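Note the ``dtype=dtype`` default on the inner ``init``: the dtype given to
the builder is only a default, and a caller can still override it per call.
A small sketch of that behavior (illustrative only):

import jax
import jax.numpy as jnp

init = jax.nn.initializers.constant(2, dtype=jnp.float32)
x = init(jax.random.PRNGKey(0), (2,), jnp.int32)  # per-call dtype override
print(x.dtype)  # int32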
def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
"""Builds an initializer that returns real uniformly-distributed random arrays.
Args:
scale: optional; the upper bound of the random distribution.
dtype: optional; the initializer's default dtype.
Returns:
An initializer that returns arrays whose values are uniformly distributed in
the range ``[0, scale)``.
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return random.uniform(key, shape, dtype) * scale
return init
def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
"""Builds an initializer that returns real normally-distributed random arrays.
Args:
stddev: optional; the standard deviation of the distribution.
dtype: optional; the initializer's default dtype.
Returns:
An initializer that returns arrays whose values are normally distributed
with mean ``0`` and standard deviation ``stddev``.
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return random.normal(key, shape, dtype) * stddev
@@ -98,40 +166,49 @@ def _complex_truncated_normal(key, upper, shape, dtype):
theta = 2 * jnp.pi * random.uniform(key_theta, shape, dtype)
return r * jnp.exp(1j * theta)
def variance_scaling(scale, mode: str, distribution: str,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
r"""
Initializer that adapts its scale to the shape of the weights tensor.
With `distribution="truncated_normal" or "normal"`, samples are
drawn from a truncated/untruncated normal distribution with a mean of zero and
a standard deviation (after truncation, if used) `stddev = sqrt(scale / n)`,
where `n` is:
- number of input units in the weights tensor, if `mode="fan_in"`
- number of output units, if `mode="fan_out"`
- average of the numbers of input and output units, if `mode="fan_avg"`
With ``distribution="truncated_normal"`` or ``distribution="normal"``, samples
are drawn from a (truncated) normal distribution with a mean of zero
and a standard deviation (after truncation, if applicable) of
:math:`\sqrt{\frac{scale}{n}}`, where ``n`` is:
* the number of input units in the weights tensor, if ``mode="fan_in"``,
* the number of output units, if ``mode="fan_out"``, or
* the average of the numbers of input and output units, if ``mode="fan_avg"``.
With `distribution="truncated_normal"`, the absolute values of the samples are
truncated below 2 standard deviations before truncation.
This initializer can be configured with ``in_axis``, ``out_axis``, and
``batch_axis`` to work with general convolutional or dense layers; axes that
are not in any of those arguments are assumed to be the "receptive field"
(convolution kernel spatial axes).
With `distribution="uniform"`, samples are drawn from:
- a uniform interval, if `dtype` is real
- a uniform disk, if `dtype` is complex
with a mean of zero and a standard deviation of `stddev`.
With ``distribution="truncated_normal"``, the absolute values of the samples
are truncated at 2 standard deviations before scaling.
With ``distribution="uniform"``, samples are drawn from:
* a uniform interval, if ``dtype`` is real, or
* a uniform disk, if ``dtype`` is complex,
with a mean of zero and a standard deviation of ``stddev``.
Args:
scale: scaling factor (positive float).
mode: one of "fan_in", "fan_out", and "fan_avg".
distribution: random distribution to use. One of "truncated_normal",
"normal" and "uniform".
in_axis: axis or sequence of axes of the input dimension in the weights tensor.
out_axis: axis or sequence of axes of the output dimension in the weights tensor.
batch_axis: axis or sequence of axes in the weight tensor that should be ignored.
mode: one of ``"fan_in"``, ``"fan_out"``, and ``"fan_avg"``.
distribution: random distribution to use. One of ``"truncated_normal"``,
``"normal"`` and ``"uniform"``.
in_axis: axis or sequence of axes of the input dimension in the weights
array.
out_axis: axis or sequence of axes of the output dimension in the weights
array.
batch_axis: axis or sequence of axes in the weight array that should be
ignored.
dtype: the dtype of the weights.
"""
@@ -168,19 +245,250 @@ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1,
return init
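As a usage sketch (not part of the commit), the named initializers defined
below are thin wrappers over ``variance_scaling``; for instance, a fan-in
scaled truncated normal targets a standard deviation of
``sqrt(scale / fan_in)``:

import jax
import jax.numpy as jnp

init = jax.nn.initializers.variance_scaling(2.0, "fan_in", "truncated_normal")
w = init(jax.random.PRNGKey(0), (64, 32), jnp.float32)
# With the default in_axis=-2, fan_in for a (64, 32) matrix is 64, so the
# target standard deviation is sqrt(2.0 / 64) ≈ 0.177.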
def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
"""Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A `Glorot uniform initializer`_ is a specialization of
:func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
``mode="fan_avg"``, and ``distribution="uniform"``.
Args:
in_axis: axis or sequence of axes of the input dimension in the weights
array.
out_axis: axis or sequence of axes of the output dimension in the weights
array.
batch_axis: axis or sequence of axes in the weight array that should be
ignored.
dtype: the dtype of the weights.
Returns:
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[ 0.50350785, 0.8088631 , 0.81566876],
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
.. _Glorot uniform initializer: http://proceedings.mlr.press/v9/glorot10a.html
"""
return variance_scaling(1.0, "fan_avg", "uniform", in_axis=in_axis,
out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
xavier_uniform = glorot_uniform
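Because ``glorot_uniform`` simply forwards to ``variance_scaling``, the two
are interchangeable given the same key; a quick check of the documented
equivalence:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
a = jax.nn.initializers.glorot_uniform()(key, (2, 3), jnp.float32)
b = jax.nn.initializers.variance_scaling(
    1.0, "fan_avg", "uniform")(key, (2, 3), jnp.float32)
print(bool((a == b).all()))  # True: identical draws for the same key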
def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
"""Builds a Glorot normal initializer (aka Xavier normal initializer).
A `Glorot normal initializer`_ is a specialization of
:func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
``mode="fan_avg"``, and ``distribution="truncated_normal"``.
Args:
in_axis: axis or sequence of axes of the input dimension in the weights
array.
out_axis: axis or sequence of axes of the output dimension in the weights
array.
batch_axis: axis or sequence of axes in the weight array that should be
ignored.
dtype: the dtype of the weights.
Returns:
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[ 0.41770416, 0.75262755, 0.7619329 ],
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
.. _Glorot normal initializer: http://proceedings.mlr.press/v9/glorot10a.html
"""
return variance_scaling(1.0, "fan_avg", "truncated_normal", in_axis=in_axis,
out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
xavier_normal = glorot_normal
def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
"""Builds a Lecun uniform initializer.
A `Lecun uniform initializer`_ is a specialization of
:func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
``mode="fan_in"``, and ``distribution="uniform"``.
Args:
in_axis: axis or sequence of axes of the input dimension in the weights
array.
out_axis: axis or sequence of axes of the output dimension in the weights
array.
batch_axis: axis or sequence of axes in the weight array that should be
ignored.
dtype: the dtype of the weights.
Returns:
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[ 0.56293887, 0.90433645, 0.9119454 ],
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
.. _LeCun uniform initializer: https://arxiv.org/abs/1706.02515
"""
return variance_scaling(1.0, "fan_in", "uniform", in_axis=in_axis,
out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
"""Builds a Lecun normal initializer.
A `Lecun normal initializer`_ is a specialization of
:func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
``mode="fan_in"``, and ``distribution="truncated_normal"``.
Args:
in_axis: axis or sequence of axes of the input dimension in the weights
array.
out_axis: axis or sequence of axes of the output dimension in the weights
array.
batch_axis: axis or sequence of axes in the weight array that should be
ignored.
dtype: the dtype of the weights.
Returns:
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[ 0.46700746, 0.8414632 , 0.8518669 ],
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
.. _LeCun normal initializer: https://arxiv.org/abs/1706.02515
"""
return variance_scaling(1.0, "fan_in", "truncated_normal", in_axis=in_axis,
out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
"""Builds a He uniform initializer (aka Kaiming uniform initializer).
A `He uniform initializer`_ is a specialization of
:func:`jax.nn.initializers.variance_scaling` where ``scale = 2.0``,
``mode="fan_in"``, and ``distribution="uniform"``.
Args:
in_axis: axis or sequence of axes of the input dimension in the weights
array.
out_axis: axis or sequence of axes of the output dimension in the weights
array.
batch_axis: axis or sequence of axes in the weight array that should be
ignored.
dtype: the dtype of the weights.
Returns:
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[ 0.79611576, 1.2789248 , 1.2896855 ],
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
.. _He uniform initializer: https://arxiv.org/abs/1502.01852
"""
return variance_scaling(2.0, "fan_in", "uniform", in_axis=in_axis,
out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
kaiming_uniform = he_uniform
def he_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
"""Builds a He normal initializer (aka Kaiming normal initializer).
A `He normal initializer`_ is a specialization of
:func:`jax.nn.initializers.variance_scaling` where ``scale = 2.0``,
``mode="fan_in"``, and ``distribution="truncated_normal"``.
Args:
in_axis: axis or sequence of axes of the input dimension in the weights
array.
out_axis: axis or sequence of axes of the output dimension in the weights
array.
batch_axis: axis or sequence of axes in the weight array that should be
ignored.
dtype: the dtype of the weights.
Returns:
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[ 0.6604483 , 1.1900088 , 1.2047218 ],
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
.. _He normal initializer: https://arxiv.org/abs/1502.01852
"""
return variance_scaling(2.0, "fan_in", "truncated_normal", in_axis=in_axis,
out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
kaiming_normal = he_normal
def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
"""
Builds an initializer that returns uniformly distributed orthogonal matrices.
If the shape is not square, the matrices will have orthonormal rows or columns
depending on which side is smaller.
Args:
scale: the upper bound of the uniform distribution.
column_axis: the axis that contains the columns that should be orthogonal.
dtype: the default dtype of the weights.
Returns:
An orthogonal initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
@@ -199,11 +507,38 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
return init
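A short sketch (not from the commit) of the property documented above: for a
wide matrix, the rows come out orthonormal, so ``q @ q.T`` recovers the
identity:

import jax
import jax.numpy as jnp

init = jax.nn.initializers.orthogonal()
q = init(jax.random.PRNGKey(0), (2, 3), jnp.float32)
print(bool(jnp.allclose(q @ q.T, jnp.eye(2), atol=1e-5)))  # True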
def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
"""
Builds an initializer for delta orthogonal kernels.
Args:
scale: the upper bound of the uniform distribution.
column_axis: the axis that contains the columns that should be orthogonal.
dtype: the default dtype of the weights.
Returns:
A `delta orthogonal initializer`_. The shape passed to the initializer must
be 3D, 4D, or 5D.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]],
<BLANKLINE>
[[ 0.27858758, -0.7949833 , -0.53887904],
[ 0.9120717 , 0.04322892, 0.40774566],
[-0.30085585, -0.6050892 , 0.73712474]],
<BLANKLINE>
[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]]], dtype=float32)
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
"""
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
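The zero slices in the example above are the point of the scheme: only the
spatial center tap is nonzero, so at initialization the convolution reduces
to a pure orthogonal mixing of channels. A sketch of that property, assuming
the same axis layout as the 3D example above:

import jax
import jax.numpy as jnp

init = jax.nn.initializers.delta_orthogonal()
k = init(jax.random.PRNGKey(0), (3, 8, 8), jnp.float32)
center = k[1]  # the only nonzero spatial tap
print(bool(jnp.allclose(k[0], 0.0) and jnp.allclose(k[2], 0.0)))  # True
print(bool(jnp.allclose(center @ center.T, jnp.eye(8), atol=1e-5)))  # True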

View File

@@ -146,7 +146,7 @@ complex128 = cdouble = _make_scalar_type(np.complex128)
int_ = int32 if dtypes.int_ == np.int32 else int64
uint = uint32 if dtypes.uint == np.uint32 else uint64
float_: Any = float32 if dtypes.float_ == np.float32 else float64
complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
generic = np.generic
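For context, ``float_`` tracks JAX's configurable default precision: under
the standard configuration (``jax_enable_x64`` disabled) arrays built with
it are canonicalized to ``float32``, switching to ``float64`` when x64 is
enabled. A minimal check, assuming the default configuration:

import jax.numpy as jnp

# Prints float32 under the default (x64-disabled) configuration.
print(jnp.zeros((2,), jnp.float_).dtype)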