Add truncated normal initializer to jax.nn

PiperOrigin-RevId: 563576354
This commit is contained in:
jax authors 2023-09-07 16:23:03 -07:00
parent bfd79b84e4
commit 311dc9cfde
4 changed files with 41 additions and 1 deletions

View File

@ -31,6 +31,7 @@ key used when generating random numbers to initialize the array.
normal
ones
orthogonal
truncated_normal
uniform
variance_scaling
zeros

View File

@ -155,6 +155,43 @@ def normal(stddev: RealNumeric = 1e-2,
return random.normal(key, shape, dtype) * stddev
return init
@export
def truncated_normal(stddev: RealNumeric = 1e-2,
                     dtype: DTypeLikeInexact = jnp.float_,
                     lower: RealNumeric = -2.0,
                     upper: RealNumeric = 2.0) -> Initializer:
  r"""Builds an initializer that returns truncated-normal random arrays.

  Args:
    stddev: optional; the standard deviation of the untruncated distribution.
      Note that this function does not apply the stddev correction as is done in
      the ``variance_scaling`` initializers, and users are expected to apply
      this correction themselves via the ``stddev`` arg if they wish to employ
      it.
    dtype: optional; the initializer's default dtype.
    lower: Float representing the lower bound for truncation. Applied before
      the output is multiplied by the stddev.
    upper: Float representing the upper bound for truncation. Applied before
      the output is multiplied by the stddev.

  Returns:
    An initializer that returns arrays whose values follow a normal
    distribution with mean ``0`` and (untruncated) standard deviation
    ``stddev``, truncated to the range
    :math:`\rm{lower * stddev} < x < \rm{upper * stddev}`.

  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.truncated_normal(5.0)
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
  Array([[ 2.9047365,  5.2338114,  5.29852  ],
         [-3.836303 , -4.192359 ,  0.6022964]], dtype=float32)
  """

  def init(key: KeyArray,
           shape: core.Shape,
           dtype: DTypeLikeInexact = dtype) -> Array:
    dtype = dtypes.canonicalize_dtype(dtype)
    # Truncation bounds are applied to the unit-variance samples; scaling by
    # stddev happens afterwards, so the effective range is
    # (lower * stddev, upper * stddev).
    samples = random.truncated_normal(key, lower, upper, shape, dtype)
    # Cast stddev to the requested dtype so that a strongly typed stddev
    # (e.g. a NumPy scalar or an array) cannot promote the result to a
    # different dtype than the caller asked for.
    return samples * jnp.array(stddev, dtype)

  return init
@export
def _compute_fans(shape: core.NamedShape,
in_axis: Union[int, Sequence[int]] = -2,

View File

@ -35,6 +35,7 @@ from jax._src.nn.initializers import (
normal as normal,
ones as ones,
orthogonal as orthogonal,
truncated_normal as truncated_normal,
uniform as uniform,
variance_scaling as variance_scaling,
xavier_normal as xavier_normal,

View File

@ -298,6 +298,7 @@ INITIALIZER_RECS = [
initializer_record("lecun_normal", nn.initializers.lecun_normal, jtu.dtypes.inexact),
initializer_record("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact),
initializer_record("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, 2, 2),
initializer_record("truncated_normal", nn.initializers.truncated_normal, jtu.dtypes.floating, 1),
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
]