Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg

PiperOrigin-RevId: 721928532
This commit is contained in:
jax authors 2025-01-31 15:41:50 -08:00
commit 872e6c0ec4
2 changed files with 14 additions and 7 deletions

View File

@ -263,7 +263,7 @@ def _complex_truncated_normal(key: Array, upper: ArrayLike,
@export
def variance_scaling(
scale: RealNumeric,
mode: Literal["fan_in"] | Literal["fan_out"] | Literal["fan_avg"],
mode: Literal["fan_in"] | Literal["fan_out"] | Literal["fan_avg"] | Literal["fan_geo_avg"],
distribution: (Literal["truncated_normal"] | Literal["normal"] |
Literal["uniform"]),
in_axis: int | Sequence[int] = -2,
@ -277,11 +277,12 @@ def variance_scaling(
With ``distribution="truncated_normal"`` or ``distribution="normal"``, samples
are drawn from a (truncated) normal distribution with a mean of zero
and a standard deviation (after truncation, if applicable) of
:math:`\sqrt{\frac{scale}{n}}`, where `n` is:
:math:`\sqrt{\frac{scale}{n}}`, where `n` is, for each ``mode``:
* the number of input units in the weights tensor, if ``mode="fan_in"``,
* the number of output units, if ``mode="fan_out"``, or
* the average of the numbers of input and output units, if ``mode="fan_avg"``.
* ``"fan_in"``: the number of inputs
* ``"fan_out"``: the number of outputs
* ``"fan_avg"``: the arithmetic average of the numbers of inputs and outputs
* ``"fan_geo_avg"``: the geometric average of the numbers of inputs and outputs
This initializer can be configured with ``in_axis``, ``out_axis``, and
``batch_axis`` to work with general convolutional or dense layers; axes that
@ -301,7 +302,7 @@ def variance_scaling(
Args:
scale: scaling factor (positive float).
mode: one of ``"fan_in"``, ``"fan_out"``, and ``"fan_avg"``.
mode: one of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``, and ``"fan_geo_avg"``.
distribution: random distribution to use. One of ``"truncated_normal"``,
``"normal"`` and ``"uniform"``.
in_axis: axis or sequence of axes of the input dimension in the weights
@ -322,6 +323,7 @@ def variance_scaling(
if mode == "fan_in": denominator = fan_in
elif mode == "fan_out": denominator = fan_out
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
elif mode == "fan_geo_avg": denominator = (fan_in * fan_out) ** 0.5
else:
raise ValueError(
f"invalid mode for variance scaling initializer: {mode}")

View File

@ -601,7 +601,12 @@ INITIALIZER_RECS = [
initializer_record("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact),
initializer_record("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, 2, 2),
initializer_record("truncated_normal", nn.initializers.truncated_normal, jtu.dtypes.floating, 1),
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4),
initializer_record(
"variance_scaling_fan_geo_avg",
partial(nn.initializers.variance_scaling, 1, "fan_geo_avg", "normal"),
jtu.dtypes.floating,
),
]