vmap: better errors for mismatched axis in keyword arguments

This commit is contained in:
Jake VanderPlas 2022-06-29 14:31:03 -07:00
parent eb0052bdf2
commit cb25a96d43
2 changed files with 20 additions and 3 deletions

View File

@ -1593,10 +1593,15 @@ def _mapped_axis_size(tree, vals, dims, name, *, kws=False):
# in which case we can produce an error message based on argument indices,
# or if it has nested containers.
if kws:
# if keyword arguments are included in the tree, we make adapt the error
position_only_tree, leaf = treedef_children(tree)
if not treedef_is_leaf(leaf):
sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]
sizes = tree_unflatten(tree, sizes)
raise ValueError(msg.format(f"the tree of axis sizes is:\n{sizes}")) from None
# if keyword arguments are included in the tree, we adapt the error
# message only to be about the positional arguments
tree, leaf = treedef_children(tree)
assert treedef_is_leaf(leaf)
tree = position_only_tree
# TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
if tree == tree_flatten((0,) * tree.num_leaves)[1]:
lines1 = [f"arg {i} has shape {np.shape(x)} and axis {d} is to be mapped"

View File

@ -2586,6 +2586,18 @@ class APITest(jtu.JaxTestCase):
ans = vfun(lambda x: x + 1, jnp.arange(3))
self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False)
def test_vmap_mismatched_keyword(self):
# https://github.com/google/jax/issues/10193
@jax.vmap
def f(x, y):
return x + y
with self.assertRaisesRegex(
ValueError, "vmap got inconsistent sizes for array axes to be mapped:\n"
"the tree of axis sizes is:\n"
r"\(\(1,\), \{'y': 2\}\)"):
f(jnp.array([1]), y=jnp.array([1, 2]))
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
# https://github.com/google/jax/issues/705
def h(a, b):