mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg
PiperOrigin-RevId: 721928532
This commit is contained in:
commit
872e6c0ec4
@ -263,7 +263,7 @@ def _complex_truncated_normal(key: Array, upper: ArrayLike,
|
||||
@export
|
||||
def variance_scaling(
|
||||
scale: RealNumeric,
|
||||
mode: Literal["fan_in"] | Literal["fan_out"] | Literal["fan_avg"],
|
||||
mode: Literal["fan_in"] | Literal["fan_out"] | Literal["fan_avg"] | Literal["fan_geo_avg"],
|
||||
distribution: (Literal["truncated_normal"] | Literal["normal"] |
|
||||
Literal["uniform"]),
|
||||
in_axis: int | Sequence[int] = -2,
|
||||
@ -277,11 +277,12 @@ def variance_scaling(
|
||||
With ``distribution="truncated_normal"`` or ``distribution="normal"``, samples
|
||||
are drawn from a (truncated) normal distribution with a mean of zero
|
||||
and a standard deviation (after truncation, if applicable) of
|
||||
:math:`\sqrt{\frac{scale}{n}}`, where `n` is:
|
||||
:math:`\sqrt{\frac{scale}{n}}`, where `n` is, for each ``mode``:
|
||||
|
||||
* the number of input units in the weights tensor, if ``mode="fan_in"``,
|
||||
* the number of output units, if ``mode="fan_out"``, or
|
||||
* the average of the numbers of input and output units, if ``mode="fan_avg"``.
|
||||
* ``"fan_in"``: the number of inputs
|
||||
* ``"fan_out"``: the number of outputs
|
||||
* ``"fan_avg"``: the arithmetic average of the numbers of inputs and outputs
|
||||
* ``"fan_geo_avg"``: the geometric average of the numbers of inputs and outputs
|
||||
|
||||
This initializer can be configured with ``in_axis``, ``out_axis``, and
|
||||
``batch_axis`` to work with general convolutional or dense layers; axes that
|
||||
@ -301,7 +302,7 @@ def variance_scaling(
|
||||
|
||||
Args:
|
||||
scale: scaling factor (positive float).
|
||||
mode: one of ``"fan_in"``, ``"fan_out"``, and ``"fan_avg"``.
|
||||
mode: one of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``, and ``"fan_geo_avg"``.
|
||||
distribution: random distribution to use. One of ``"truncated_normal"``,
|
||||
``"normal"`` and ``"uniform"``.
|
||||
in_axis: axis or sequence of axes of the input dimension in the weights
|
||||
@ -322,6 +323,7 @@ def variance_scaling(
|
||||
if mode == "fan_in": denominator = fan_in
|
||||
elif mode == "fan_out": denominator = fan_out
|
||||
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
|
||||
elif mode == "fan_geo_avg": denominator = (fan_in * fan_out) ** 0.5
|
||||
else:
|
||||
raise ValueError(
|
||||
f"invalid mode for variance scaling initializer: {mode}")
|
||||
|
@ -601,7 +601,12 @@ INITIALIZER_RECS = [
|
||||
initializer_record("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact),
|
||||
initializer_record("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, 2, 2),
|
||||
initializer_record("truncated_normal", nn.initializers.truncated_normal, jtu.dtypes.floating, 1),
|
||||
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
|
||||
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4),
|
||||
initializer_record(
|
||||
"variance_scaling_fan_geo_avg",
|
||||
partial(nn.initializers.variance_scaling, 1, "fan_geo_avg", "normal"),
|
||||
jtu.dtypes.floating,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user