jax.nn.initializers package
===========================

.. currentmodule:: jax.nn.initializers

.. automodule:: jax.nn.initializers

Initializers
------------

This module provides common neural network layer initializers,
consistent with definitions used in Keras and Sonnet.

An initializer is a function that takes three arguments,
``(key, shape, dtype)``, and returns an array with dimensions ``shape`` and
data type ``dtype``. The ``key`` argument is a PRNG key (for example, one
created with :func:`jax.random.PRNGKey`) used to generate the random numbers
that initialize the array.

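
The ``(key, shape, dtype)`` calling convention described above can be
illustrated with a short sketch. The shapes, the seed, and the
``scaled_uniform`` helper are hypothetical and chosen only for illustration.
The sketch assumes that ``glorot_normal`` is a factory that returns an
initializer, while ``zeros`` can be called as an initializer directly.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax.nn import initializers

   # Derive two independent keys so each array gets its own randomness.
   key = jax.random.PRNGKey(0)
   k1, k2 = jax.random.split(key)

   # glorot_normal() builds an initializer; calling it yields the array.
   init = initializers.glorot_normal()
   w = init(k1, (3, 4), jnp.float32)

   # zeros already follows the (key, shape, dtype) convention.
   b = initializers.zeros(k2, (4,), jnp.float32)

   # Hypothetical custom initializer (not part of jax.nn.initializers)
   # that follows the same convention.
   def scaled_uniform(scale=0.1):
       def init(key, shape, dtype=jnp.float32):
           return jax.random.uniform(
               key, shape, dtype, minval=-scale, maxval=scale
           )
       return init

   w2 = scaled_uniform(0.5)(k1, (2, 2), jnp.float32)

Any function with this signature can be used wherever an initializer is
expected, so a custom one can stand in for the built-in ones.
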

.. autosummary::
   :toctree: _autosummary

   constant
   delta_orthogonal
   glorot_normal
   glorot_uniform
   he_normal
   he_uniform
   lecun_normal
   lecun_uniform
   normal
   ones
   orthogonal
   uniform
   variance_scaling
   zeros