mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add truncated normal initializer to jax.nn
PiperOrigin-RevId: 563576354
This commit is contained in:
parent
bfd79b84e4
commit
311dc9cfde
@ -31,6 +31,7 @@ key used when generating random numbers to initialize the array.
|
||||
normal
|
||||
ones
|
||||
orthogonal
|
||||
truncated_normal
|
||||
uniform
|
||||
variance_scaling
|
||||
zeros
|
||||
|
@ -155,6 +155,43 @@ def normal(stddev: RealNumeric = 1e-2,
|
||||
return random.normal(key, shape, dtype) * stddev
|
||||
return init
|
||||
|
||||
@export
def truncated_normal(stddev: RealNumeric = 1e-2,
                     dtype: DTypeLikeInexact = jnp.float_,
                     lower: RealNumeric = -2.0,
                     upper: RealNumeric = 2.0) -> Initializer:
  r"""Builds an initializer that returns truncated-normal random arrays.

  Args:
    stddev: optional; the standard deviation of the untruncated distribution.
      Note that this function does not apply the stddev correction as is done in
      the variance-scaling initializers, and users are expected to apply this
      correction themselves via the ``stddev`` arg if they wish to employ it.
    dtype: optional; the initializer's default dtype.
    lower: optional; the lower bound for truncation. Applied before the output
      is multiplied by ``stddev``.
    upper: optional; the upper bound for truncation. Applied before the output
      is multiplied by ``stddev``.

  Returns:
    An initializer that returns arrays whose values are drawn from a
    zero-centered normal distribution truncated to ``[lower, upper]`` and then
    multiplied by ``stddev``, so that they fall in the range
    :math:`\rm{lower * stddev} < x < \rm{upper * stddev}`.

  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.truncated_normal(5.0)
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
  Array([[ 2.9047365,  5.2338114,  5.29852  ],
         [-3.836303 , -4.192359 ,  0.6022964]], dtype=float32)
  """

  def init(key: KeyArray,
           shape: core.Shape,
           dtype: DTypeLikeInexact = dtype) -> Array:
    # Resolve the requested dtype to its canonical form before sampling.
    dtype = dtypes.canonicalize_dtype(dtype)
    # The truncation bounds apply to the unit-scale sample; scaling by
    # ``stddev`` happens afterwards, which is why the effective output range
    # is (lower * stddev, upper * stddev).
    return random.truncated_normal(key, lower, upper, shape, dtype) * stddev

  return init
|
||||
|
||||
@export
|
||||
def _compute_fans(shape: core.NamedShape,
|
||||
in_axis: Union[int, Sequence[int]] = -2,
|
||||
|
@ -35,6 +35,7 @@ from jax._src.nn.initializers import (
|
||||
normal as normal,
|
||||
ones as ones,
|
||||
orthogonal as orthogonal,
|
||||
truncated_normal as truncated_normal,
|
||||
uniform as uniform,
|
||||
variance_scaling as variance_scaling,
|
||||
xavier_normal as xavier_normal,
|
||||
|
@ -298,6 +298,7 @@ INITIALIZER_RECS = [
|
||||
initializer_record("lecun_normal", nn.initializers.lecun_normal, jtu.dtypes.inexact),
|
||||
initializer_record("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact),
|
||||
initializer_record("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, 2, 2),
|
||||
initializer_record("truncated_normal", nn.initializers.truncated_normal, jtu.dtypes.floating, 1),
|
||||
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
|
||||
]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user