Add batch_axis to variance scaling initializers

PiperOrigin-RevId: 426522731
Author: James Bradbury, 2022-02-04 17:01:36 -08:00 (committed by jax authors)
parent 086a607d8c
commit 5dd1c75969
2 changed files with 32 additions and 4 deletions


@@ -44,7 +44,14 @@ def normal(stddev=1e-2, dtype=jnp.float_):
     return random.normal(key, shape, dtype) * stddev
   return init
 
-def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1):
+def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1,
+                  batch_axis=()):
+  """
+  Compute effective input and output sizes for a linear or convolutional layer.
+
+  Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the
+  "receptive field" of a convolution (kernel spatial dimensions).
+  """
   if isinstance(in_axis, int):
     in_size = shape[in_axis]
   else:
@@ -53,7 +60,11 @@ def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1):
     out_size = shape[out_axis]
   else:
     out_size = int(np.prod([shape[i] for i in out_axis]))
-  receptive_field_size = shape.total / in_size / out_size
+  if isinstance(batch_axis, int):
+    batch_size = shape[batch_axis]
+  else:
+    batch_size = int(np.prod([shape[i] for i in batch_axis]))
+  receptive_field_size = shape.total / in_size / out_size / batch_size
   fan_in = in_size * receptive_field_size
   fan_out = out_size * receptive_field_size
   return fan_in, fan_out
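
(As a quick sanity check, not part of the commit: the new fan arithmetic can be traced by hand for the shape and axes used in the test at the bottom of this diff. For a plain tuple, np.prod(shape) stands in for shape.total.)

import numpy as np

# Hand-check of _compute_fans for shape (2, 3, 4, 5) with
# in_axis=0, out_axis=(2, 3), batch_axis=1.
shape = (2, 3, 4, 5)
in_size = shape[0]                                   # 2
out_size = int(np.prod([shape[i] for i in (2, 3)]))  # 4 * 5 = 20
batch_size = shape[1]                                # 3
# No axes are left over, so the receptive field is trivial.
receptive_field_size = np.prod(shape) / in_size / out_size / batch_size  # 1.0
fan_in = in_size * receptive_field_size    # 2.0
fan_out = out_size * receptive_field_size  # 20.0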
@@ -81,7 +92,8 @@ def _complex_truncated_normal(key, upper, shape, dtype):
   theta = 2 * jnp.pi * random.uniform(key_theta, shape, dtype)
   return r * jnp.exp(1j * theta)
 
-def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
+def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1,
+                     batch_axis=(), dtype=jnp.float_):
   """
   Initializer capable of adapting its scale to the shape of the weights tensor.
@@ -93,6 +105,11 @@ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=j
   - number of output units, if `mode="fan_out"`
   - average of the numbers of input and output units, if `mode="fan_avg"`
 
+  This initializer can be configured with in_axis, out_axis, and batch_axis to
+  work with general convolutional or dense layers; axes that are not in any of
+  those arguments are assumed to be the "receptive field" (convolution kernel
+  spatial axes).
+
   With `distribution="truncated_normal"`, the absolute values of the samples
   are truncated below 2 standard deviations before scaling.
@@ -108,13 +125,14 @@ def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=j
       "normal" and "uniform".
     in_axis: axis or sequence of axes of the input dimension in the weights tensor.
     out_axis: axis or sequence of axes of the output dimension in the weights tensor.
+    batch_axis: axis or sequence of axes in the weight tensor that should be ignored.
     dtype: the dtype of the weights.
   """
 
   def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     shape = core.as_named_shape(shape)
-    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
+    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
     if mode == "fan_in": denominator = fan_in
     elif mode == "fan_out": denominator = fan_out
     elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
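
(A hypothetical usage sketch, not part of the commit; the shape and axis choices are illustrative. For a weight tensor of shape (heads, in_features, out_features), passing the heads axis as batch_axis keeps fan_in at in_features rather than heads * in_features.)

from jax import random
from jax.nn import initializers

# Per-head projection weights: axis 0 is a batch of independent heads.
init = initializers.variance_scaling(
    scale=1.0, mode='fan_in', distribution='truncated_normal',
    in_axis=1, out_axis=2, batch_axis=0)
w = init(random.PRNGKey(0), (8, 512, 64))  # fan_in = 512, not 8 * 512
print(w.shape)  # (8, 512, 64)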


@@ -277,6 +277,16 @@ class NNInitializersTest(jtu.JaxTestCase):
     self.assertEqual(shape, jnp.shape(val))
 
+  def testVarianceScalingBatchAxis(self):
+    rng = random.PRNGKey(0)
+    shape = (2, 3, 4, 5)
+    initializer = nn.initializers.variance_scaling(
+        scale=1.0, mode='fan_avg', distribution='truncated_normal',
+        in_axis=0, out_axis=(2, 3), batch_axis=1)
+    val = initializer(rng, shape)
+    self.assertEqual(shape, jnp.shape(val))
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
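
(A further check one could run by hand, hypothetical and not in the commit: with mode='fan_in' and scale=1.0, the empirical variance of the samples should be close to 1 / fan_in, independent of the batch axis size.)

import jax.numpy as jnp
from jax import random
from jax.nn import initializers

init = initializers.variance_scaling(
    scale=1.0, mode='fan_in', distribution='normal',
    in_axis=1, out_axis=2, batch_axis=0)
w = init(random.PRNGKey(42), (4, 256, 128))
print(jnp.var(w))  # ~1 / 256, i.e. about 0.0039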