Document jax.nn.initializers.

commit d3d666d081
parent 7061521137
@@ -13,17 +13,25 @@ Initializers
 This module provides common neural network layer initializers,
 consistent with definitions used in Keras and Sonnet.
 
+An initializer is a function that takes three arguments:
+``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and
+data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random
+key used when generating random numbers to initialize the array.
+
 .. autosummary::
   :toctree: _autosummary
 
-    zeros
-    ones
-    uniform
-    normal
-    variance_scaling
-    glorot_uniform
+    constant
+    delta_orthogonal
     glorot_normal
-    lecun_uniform
-    lecun_normal
-    he_uniform
+    glorot_uniform
     he_normal
+    he_uniform
+    lecun_normal
+    lecun_uniform
+    normal
+    ones
+    orthogonal
+    uniform
+    variance_scaling
+    zeros
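As a quick illustration of the ``(key, shape, dtype)`` convention documented above (a sketch, not part of this commit; the helper name ``small_uniform`` is hypothetical):

    import jax
    import jax.numpy as jnp

    # A custom initializer just needs the (key, shape, dtype) signature and
    # must return an array of the requested shape and dtype.
    def small_uniform(key, shape, dtype=jnp.float32):  # hypothetical helper
      return jax.random.uniform(key, shape, dtype, minval=-0.05, maxval=0.05)

    key = jax.random.PRNGKey(0)
    w = small_uniform(key, (3, 4))           # custom initializer
    b = jax.nn.initializers.ones(key, (4,))  # built-ins share the same signature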
@@ -18,7 +18,8 @@ used in Keras and Sonnet.
 """
 
 
 from functools import partial
+from typing import Any, Callable, Sequence, Union
 
 import numpy as np
 
@@ -29,22 +29,90 @@ from jax import core
 from jax._src.util import prod
 from jax import dtypes
 
-def zeros(key, shape, dtype=jnp.float_): return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
-def ones(key, shape, dtype=jnp.float_): return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
+DType = Any
 
-def constant(value, dtype=jnp.float_):
+def zeros(key, shape, dtype: DType = jnp.float_):
+  """An initializer that returns a constant array full of zeros.
+
+  The ``key`` argument is ignored.
+
+  >>> import jax, jax.numpy as jnp
+  >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[0., 0., 0.],
+               [0., 0., 0.]], dtype=float32)
+  """
+  return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
+
+def ones(key, shape, dtype: DType = jnp.float_):
+  """An initializer that returns a constant array full of ones.
+
+  The ``key`` argument is ignored.
+
+  >>> import jax, jax.numpy as jnp
+  >>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
+  DeviceArray([[1., 1.],
+               [1., 1.],
+               [1., 1.]], dtype=float32)
+  """
+  return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
+
+def constant(value, dtype: DType = jnp.float_) -> Callable:
+  """Builds an initializer that returns arrays full of a constant ``value``.
+
+  Args:
+    value: the constant value with which to fill the initializer.
+    dtype: optional; the initializer's default dtype.
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.constant(-7)
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[-7., -7., -7.],
+               [-7., -7., -7.]], dtype=float32)
+  """
   def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     return jnp.full(shape, value, dtype=dtype)
   return init
 
-def uniform(scale=1e-2, dtype=jnp.float_):
+def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
+  """Builds an initializer that returns real uniformly-distributed random arrays.
+
+  Args:
+    scale: optional; the upper bound of the random distribution.
+    dtype: optional; the initializer's default dtype.
+
+  Returns:
+    An initializer that returns arrays whose values are uniformly distributed in
+    the range ``[0, scale)``.
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.uniform(10.0)
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[7.298188 , 8.691938 , 8.7230015],
+               [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
+  """
   def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     return random.uniform(key, shape, dtype) * scale
   return init
 
-def normal(stddev=1e-2, dtype=jnp.float_):
+def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
+  """Builds an initializer that returns real normally-distributed random arrays.
+
+  Args:
+    stddev: optional; the standard deviation of the distribution.
+    dtype: optional; the initializer's default dtype.
+
+  Returns:
+    An initializer that returns arrays whose values are normally distributed
+    with mean ``0`` and standard deviation ``stddev``.
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.normal(5.0)
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 3.0613258 ,  5.6129413 ,  5.6866574 ],
+               [-4.063663  , -4.4520254 ,  0.63115686]], dtype=float32)
+  """
   def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     return random.normal(key, shape, dtype) * stddev
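Note the shared factory pattern above: ``constant``, ``uniform``, and ``normal`` each return a closure ``init`` whose ``dtype`` default is captured at construction time but can still be overridden per call. A quick sketch of that behaviour (not part of this commit):

    import jax
    import jax.numpy as jnp

    init = jax.nn.initializers.normal(stddev=5.0)  # dtype defaults to jnp.float_
    key = jax.random.PRNGKey(42)

    x_default = init(key, (2, 3))           # default dtype (float32 unless x64 is enabled)
    x_f16 = init(key, (2, 3), jnp.float16)  # per-call dtype override
    print(x_default.dtype, x_f16.dtype)     # float32 float16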
@@ -98,40 +166,49 @@ def _complex_truncated_normal(key, upper, shape, dtype):
   theta = 2 * jnp.pi * random.uniform(key_theta, shape, dtype)
   return r * jnp.exp(1j * theta)
 
-def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1,
-                     batch_axis=(), dtype=jnp.float_):
-  """
-  Initializer capable of adapting its scale to the shape of the weights tensor.
+def variance_scaling(scale, mode: str, distribution: str,
+                     in_axis: Union[int, Sequence[int]] = -2,
+                     out_axis: Union[int, Sequence[int]] = -1,
+                     batch_axis: Sequence[int] = (),
+                     dtype: DType = jnp.float_) -> Callable:
+  r"""
+  Initializer that adapts its scale to the shape of the weights tensor.
 
-  With `distribution="truncated_normal" or "normal"`, samples are
-  drawn from a truncated/untruncated normal distribution with a mean of zero and
-  a standard deviation (after truncation, if used) `stddev = sqrt(scale / n)`,
-  where `n` is:
-    - number of input units in the weights tensor, if `mode="fan_in"`
-    - number of output units, if `mode="fan_out"`
-    - average of the numbers of input and output units, if `mode="fan_avg"`
+  With ``distribution="truncated_normal"`` or ``distribution="normal"``, samples
+  are drawn from a (truncated) normal distribution with a mean of zero
+  and a standard deviation (after truncation, if applicable) of
+  :math:`\sqrt{\frac{scale}{n}}`, where `n` is:
 
-  This initializer can be configured with in_axis, out_axis, and batch_axis to
-  work with general convolutional or dense layers; axes that are not in any of
-  those arguments are assumed to be the "receptive field" (convolution kernel
-  spatial axes).
+  * the number of input units in the weights tensor, if ``mode="fan_in"``,
+  * the number of output units, if ``mode="fan_out"``, or
+  * the average of the numbers of input and output units, if ``mode="fan_avg"``.
 
-  With `distribution="truncated_normal"`, the absolute values of the samples are
-  truncated below 2 standard deviations before truncation.
+  This initializer can be configured with ``in_axis``, ``out_axis``, and
+  ``batch_axis`` to work with general convolutional or dense layers; axes that
+  are not in any of those arguments are assumed to be the "receptive field"
+  (convolution kernel spatial axes).
 
-  With `distribution="uniform"`, samples are drawn from:
-    - a uniform interval, if `dtype` is real
-    - a uniform disk, if `dtype` is complex
-  with a mean of zero and a standard deviation of `stddev`.
+  With ``distribution="truncated_normal"``, the absolute values of the samples
+  are truncated at 2 standard deviations before scaling.
+
+  With ``distribution="uniform"``, samples are drawn from:
+
+  * a uniform interval, if `dtype` is real, or
+  * a uniform disk, if `dtype` is complex,
+
+  with a mean of zero and a standard deviation of ``stddev``.
 
   Args:
     scale: scaling factor (positive float).
-    mode: one of "fan_in", "fan_out", and "fan_avg".
-    distribution: random distribution to use. One of "truncated_normal",
-      "normal" and "uniform".
-    in_axis: axis or sequence of axes of the input dimension in the weights tensor.
-    out_axis: axis or sequence of axes of the output dimension in the weights tensor.
-    batch_axis: axis or sequence of axes in the weight tensor that should be ignored.
+    mode: one of ``"fan_in"``, ``"fan_out"``, and ``"fan_avg"``.
+    distribution: random distribution to use. One of ``"truncated_normal"``,
+      ``"normal"`` and ``"uniform"``.
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
   """
 
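As a rough numerical check of the ``stddev = sqrt(scale / n)`` relationship described in this docstring (a sketch, not part of this commit):

    import jax
    import jax.numpy as jnp

    # For a (fan_in, fan_out) = (1024, 256) array with mode="fan_in", n = 1024,
    # so the target stddev is sqrt(1.0 / 1024) ~= 0.03125.
    init = jax.nn.initializers.variance_scaling(1.0, "fan_in", "normal")
    w = init(jax.random.PRNGKey(0), (1024, 256), jnp.float32)
    print(jnp.std(w))  # empirically close to 0.03125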
@@ -168,19 +245,250 @@ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1,
 
   return init
 
-xavier_uniform = glorot_uniform = partial(variance_scaling, 1.0, "fan_avg", "uniform")
-xavier_normal = glorot_normal = partial(variance_scaling, 1.0, "fan_avg", "truncated_normal")
-lecun_uniform = partial(variance_scaling, 1.0, "fan_in", "uniform")
-lecun_normal = partial(variance_scaling, 1.0, "fan_in", "truncated_normal")
-kaiming_uniform = he_uniform = partial(variance_scaling, 2.0, "fan_in", "uniform")
-kaiming_normal = he_normal = partial(variance_scaling, 2.0, "fan_in", "truncated_normal")
+def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
+                   out_axis: Union[int, Sequence[int]] = -1,
+                   batch_axis: Sequence[int] = (),
+                   dtype: DType = jnp.float_) -> Callable:
+  """Builds a Glorot uniform initializer (aka Xavier uniform initializer).
 
-def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
+  A `Glorot uniform initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
+  ``mode="fan_avg"``, and ``distribution="uniform"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.glorot_uniform()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.50350785,  0.8088631 ,  0.81566876],
+               [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
+
+  .. _Glorot uniform initializer: http://proceedings.mlr.press/v9/glorot10a.html
   """
-  Construct an initializer for uniformly distributed orthogonal matrices.
+  return variance_scaling(1.0, "fan_avg", "uniform", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+xavier_uniform = glorot_uniform
+
+
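The rewrite from a ``partial`` alias to a wrapper function preserves behaviour; a quick equivalence check (a sketch, not part of this commit):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(42)
    a = jax.nn.initializers.glorot_uniform()(key, (2, 3), jnp.float32)
    b = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "uniform")(key, (2, 3), jnp.float32)
    print(jnp.allclose(a, b))  # True: same parameters and key give identical samples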
+def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
+                  out_axis: Union[int, Sequence[int]] = -1,
+                  batch_axis: Sequence[int] = (),
+                  dtype: DType = jnp.float_) -> Callable:
+  """Builds a Glorot normal initializer (aka Xavier normal initializer).
+
+  A `Glorot normal initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
+  ``mode="fan_avg"``, and ``distribution="truncated_normal"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.glorot_normal()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.41770416,  0.75262755,  0.7619329 ],
+               [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
+
+  .. _Glorot normal initializer: http://proceedings.mlr.press/v9/glorot10a.html
+  """
+  return variance_scaling(1.0, "fan_avg", "truncated_normal", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+xavier_normal = glorot_normal
+
+def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
+                  out_axis: Union[int, Sequence[int]] = -1,
+                  batch_axis: Sequence[int] = (),
+                  dtype: DType = jnp.float_) -> Callable:
+  """Builds a Lecun uniform initializer.
+
+  A `Lecun uniform initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
+  ``mode="fan_in"``, and ``distribution="uniform"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.lecun_uniform()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.56293887,  0.90433645,  0.9119454 ],
+               [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
+
+  .. _Lecun uniform initializer: https://arxiv.org/abs/1706.02515
+  """
+  return variance_scaling(1.0, "fan_in", "uniform", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
+                 out_axis: Union[int, Sequence[int]] = -1,
+                 batch_axis: Sequence[int] = (),
+                 dtype: DType = jnp.float_) -> Callable:
+  """Builds a Lecun normal initializer.
+
+  A `Lecun normal initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``,
+  ``mode="fan_in"``, and ``distribution="truncated_normal"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.lecun_normal()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.46700746,  0.8414632 ,  0.8518669 ],
+               [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
+
+  .. _Lecun normal initializer: https://arxiv.org/abs/1706.02515
+  """
+  return variance_scaling(1.0, "fan_in", "truncated_normal", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+
+def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
+               out_axis: Union[int, Sequence[int]] = -1,
+               batch_axis: Sequence[int] = (),
+               dtype: DType = jnp.float_) -> Callable:
+  """Builds a He uniform initializer (aka Kaiming uniform initializer).
+
+  A `He uniform initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 2.0``,
+  ``mode="fan_in"``, and ``distribution="uniform"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.kaiming_uniform()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.79611576,  1.2789248 ,  1.2896855 ],
+               [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
+
+  .. _He uniform initializer: https://arxiv.org/abs/1502.01852
+  """
+  return variance_scaling(2.0, "fan_in", "uniform", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+kaiming_uniform = he_uniform
+
+
+def he_normal(in_axis: Union[int, Sequence[int]] = -2,
+              out_axis: Union[int, Sequence[int]] = -1,
+              batch_axis: Sequence[int] = (),
+              dtype: DType = jnp.float_) -> Callable:
+  """Builds a He normal initializer (aka Kaiming normal initializer).
+
+  A `He normal initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 2.0``,
+  ``mode="fan_in"``, and ``distribution="truncated_normal"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.kaiming_normal()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
+               [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
+
+  .. _He normal initializer: https://arxiv.org/abs/1502.01852
+  """
+  return variance_scaling(2.0, "fan_in", "truncated_normal", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+kaiming_normal = he_normal
+
+
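The ``scale = 2.0`` used by the He initializers compensates for ReLU zeroing roughly half of its inputs; a sketch of the effect on one wide layer (not part of this commit):

    import jax
    import jax.numpy as jnp

    key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
    fan_in = 4096
    w = jax.nn.initializers.he_normal()(key_w, (fan_in, 4096), jnp.float32)
    x = jax.random.normal(key_x, (fan_in,))  # unit-variance inputs

    y = jax.nn.relu(x @ w)
    # With scale=2.0 the post-ReLU second moment stays near 1.0; with scale=1.0
    # (as in lecun_normal) it would be roughly halved at every layer.
    print(jnp.mean(y ** 2))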
+def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
+  """
+  Builds an initializer that returns uniformly distributed orthogonal matrices.
 
   If the shape is not square, the matrices will have orthonormal rows or columns
   depending on which side is smaller.
+
+  Args:
+    scale: the upper bound of the uniform distribution.
+    column_axis: the axis that contains the columns that should be orthogonal.
+    dtype: the default dtype of the weights.
+
+  Returns:
+    An orthogonal initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.orthogonal()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
+  DeviceArray([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
+               [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
   """
   def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
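A quick check of the orthonormality property described in the docstring above (a sketch, not part of this commit):

    import jax
    import jax.numpy as jnp

    w = jax.nn.initializers.orthogonal()(jax.random.PRNGKey(0), (2, 3), jnp.float32)
    # The shape is not square, so the smaller side (the 2 rows) is orthonormal.
    print(jnp.allclose(w @ w.T, jnp.eye(2), atol=1e-5))  # True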
@@ -199,11 +507,38 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
   return init
 
 
-def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
+def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
   """
-  Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
+  Builds an initializer for delta orthogonal kernels.
 
-  The shape must be 3D, 4D or 5D.
+  Args:
+    scale: the upper bound of the uniform distribution.
+    column_axis: the axis that contains the columns that should be orthogonal.
+    dtype: the default dtype of the weights.
+
+  Returns:
+    A `delta orthogonal initializer`_. The shape passed to the initializer must
+    be 3D, 4D, or 5D.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.delta_orthogonal()
+  >>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32)  # doctest: +SKIP
+  DeviceArray([[[ 0.        ,  0.        ,  0.        ],
+                [ 0.        ,  0.        ,  0.        ],
+                [ 0.        ,  0.        ,  0.        ]],
+  <BLANKLINE>
+               [[ 0.27858758, -0.7949833 , -0.53887904],
+                [ 0.9120717 ,  0.04322892,  0.40774566],
+                [-0.30085585, -0.6050892 ,  0.73712474]],
+  <BLANKLINE>
+               [[ 0.        ,  0.        ,  0.        ],
+                [ 0.        ,  0.        ,  0.        ],
+                [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
+
+
+  .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
   """
   def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
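The doctest above shows the structure delta_orthogonal produces: the kernel is zero everywhere except for an orthogonal matrix at the central spatial position. A sketch that checks this directly (not part of this commit):

    import jax
    import jax.numpy as jnp

    k = jax.nn.initializers.delta_orthogonal()(jax.random.PRNGKey(0), (5, 3, 3), jnp.float32)
    center = k[2]  # for 5 spatial taps, only the middle one is nonzero
    print(jnp.allclose(center @ center.T, jnp.eye(3), atol=1e-5))  # True
    print(int(jnp.count_nonzero(k[0])), int(jnp.count_nonzero(k[4])))  # 0 0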
@@ -146,7 +146,7 @@ complex128 = cdouble = _make_scalar_type(np.complex128)
 
 int_ = int32 if dtypes.int_ == np.int32 else int64
 uint = uint32 if dtypes.uint == np.uint32 else uint64
-float_ = float32 if dtypes.float_ == np.float32 else float64
+float_: Any = float32 if dtypes.float_ == np.float32 else float64
 complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
 
 generic = np.generic