diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst index 9b864289f..d96ba43f0 100644 --- a/docs/jax.nn.initializers.rst +++ b/docs/jax.nn.initializers.rst @@ -31,6 +31,7 @@ key used when generating random numbers to initialize the array. normal ones orthogonal + truncated_normal uniform variance_scaling zeros diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 3edd404c9..3d2cd4852 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -147,7 +147,7 @@ def normal(stddev: RealNumeric = 1e-2, >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) - """ + """ def init(key: KeyArray, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: @@ -155,6 +155,43 @@ def normal(stddev: RealNumeric = 1e-2, return random.normal(key, shape, dtype) * stddev return init +@export +def truncated_normal(stddev: RealNumeric = 1e-2, + dtype: DTypeLikeInexact = jnp.float_, + lower: RealNumeric = -2.0, + upper: RealNumeric = 2.0) -> Initializer: + r"""Builds an initializer that returns truncated-normal random arrays. + + Args: + stddev: optional; the standard deviation of the untruncated distribution. + Note that this function does not apply the stddev correction as is done in + the variance scaling initializers, and users are expected to apply this + correction themselves via the stddev arg if they wish to employ it. + dtype: optional; the initializer's default dtype. + lower: Float representing the lower bound for truncation. Applied before + the output is multiplied by the stddev. + upper: Float representing the upper bound for truncation. Applied before + the output is multiplied by the stddev.
+ + Returns: + An initializer that returns arrays whose values follow the truncated normal + distribution with mean ``0`` and standard deviation ``stddev``, and range + :math:`\rm{lower * stddev} < x < \rm{upper * stddev}`. + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.truncated_normal(5.0) + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + Array([[ 2.9047365, 5.2338114, 5.29852 ], + [-3.836303 , -4.192359 , 0.6022964]], dtype=float32) + """ + + def init(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = dtype) -> Array: + dtype = dtypes.canonicalize_dtype(dtype) + return random.truncated_normal(key, lower, upper, shape, dtype) * stddev + return init + @export def _compute_fans(shape: core.NamedShape, in_axis: Union[int, Sequence[int]] = -2, diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py index c1d7c936b..6c73356ce 100644 --- a/jax/nn/initializers.py +++ b/jax/nn/initializers.py @@ -35,6 +35,7 @@ from jax._src.nn.initializers import ( normal as normal, ones as ones, orthogonal as orthogonal, + truncated_normal as truncated_normal, uniform as uniform, variance_scaling as variance_scaling, xavier_normal as xavier_normal, diff --git a/tests/nn_test.py b/tests/nn_test.py index 1a931e55e..5a3524e80 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -298,6 +298,7 @@ INITIALIZER_RECS = [ initializer_record("lecun_normal", nn.initializers.lecun_normal, jtu.dtypes.inexact), initializer_record("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact), initializer_record("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, 2, 2), + initializer_record("truncated_normal", nn.initializers.truncated_normal, jtu.dtypes.floating, 1), initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4) ]