mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13995 from chiamp:compute_fans
PiperOrigin-RevId: 503180128
This commit is contained in:
commit
97d23dc9d7
@ -157,6 +157,10 @@ def _compute_fans(shape: core.NamedShape,
|
||||
Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the
|
||||
"receptive field" of a convolution (kernel spatial dimensions).
|
||||
"""
|
||||
if shape.rank <= 1:
|
||||
raise ValueError(f"Can't compute input and output sizes of a {shape.rank}"
|
||||
"-dimensional weights tensor. Must be at least 2D.")
|
||||
|
||||
if isinstance(in_axis, int):
|
||||
in_size = shape[in_axis]
|
||||
else:
|
||||
|
@ -314,6 +314,19 @@ class NNInitializersTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(shape, jnp.shape(val))
|
||||
|
||||
def testVarianceScalingError(self):
|
||||
rng = random.PRNGKey(0)
|
||||
shape = (5,)
|
||||
initializer = nn.initializers.variance_scaling(
|
||||
scale=1.0, mode='fan_avg', distribution='truncated_normal')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Can't compute input and output sizes of a 1"
|
||||
"-dimensional weights tensor. Must be at least 2D."
|
||||
):
|
||||
initializer(rng, shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user