Merge pull request #13995 from chiamp:compute_fans

PiperOrigin-RevId: 503180128
This commit is contained in:
jax authors 2023-01-19 09:12:54 -08:00
commit 97d23dc9d7
2 changed files with 17 additions and 0 deletions

View File

@ -157,6 +157,10 @@ def _compute_fans(shape: core.NamedShape,
Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the
"receptive field" of a convolution (kernel spatial dimensions).
"""
if shape.rank <= 1:
raise ValueError(f"Can't compute input and output sizes of a {shape.rank}"
"-dimensional weights tensor. Must be at least 2D.")
if isinstance(in_axis, int):
in_size = shape[in_axis]
else:

View File

@ -314,6 +314,19 @@ class NNInitializersTest(jtu.JaxTestCase):
self.assertEqual(shape, jnp.shape(val))
def testVarianceScalingError(self):
rng = random.PRNGKey(0)
shape = (5,)
initializer = nn.initializers.variance_scaling(
scale=1.0, mode='fan_avg', distribution='truncated_normal')
with self.assertRaisesRegex(
ValueError,
"Can't compute input and output sizes of a 1"
"-dimensional weights tensor. Must be at least 2D."
):
initializer(rng, shape)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())