add test case per reviewer comment

This commit is contained in:
Matthew Johnson 2019-10-31 13:20:32 -07:00
parent 213b899ef1
commit d09571ebce

View File

@ -1117,6 +1117,7 @@ class APITest(jtu.JaxTestCase):
X = onp.random.randn(10, 4)
U = onp.random.randn(10, 2)
self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
@ -1127,6 +1128,17 @@ class APITest(jtu.JaxTestCase):
"arg 1 has an axis to be mapped of size 2",
lambda: api.vmap(h, in_axes=(0, 1))(X, U))
self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
r"arg 2 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
"so\n"
"args 0, 2 have axes to be mapped of size 10\n"
"arg 1 has an axis to be mapped of size 2",
lambda: api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X))
self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"