mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add test case per reviewer comment
This commit is contained in:
parent
213b899ef1
commit
d09571ebce
@ -1117,6 +1117,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
X = onp.random.randn(10, 4)
|
||||
U = onp.random.randn(10, 2)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
@ -1127,6 +1128,17 @@ class APITest(jtu.JaxTestCase):
|
||||
"arg 1 has an axis to be mapped of size 2",
|
||||
lambda: api.vmap(h, in_axes=(0, 1))(X, U))
|
||||
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
|
||||
r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
|
||||
r"arg 2 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
|
||||
"so\n"
|
||||
"args 0, 2 have axes to be mapped of size 10\n"
|
||||
"arg 1 has an axis to be mapped of size 2",
|
||||
lambda: api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X))
|
||||
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
|
Loading…
x
Reference in New Issue
Block a user