mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
vmap: better errors for mismatched axis in keyword arguments
This commit is contained in:
parent
eb0052bdf2
commit
cb25a96d43
@ -1593,10 +1593,15 @@ def _mapped_axis_size(tree, vals, dims, name, *, kws=False):
|
||||
# in which case we can produce an error message based on argument indices,
|
||||
# or if it has nested containers.
|
||||
if kws:
|
||||
# if keyword arguments are included in the tree, we make adapt the error
|
||||
position_only_tree, leaf = treedef_children(tree)
|
||||
if not treedef_is_leaf(leaf):
|
||||
sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]
|
||||
sizes = tree_unflatten(tree, sizes)
|
||||
raise ValueError(msg.format(f"the tree of axis sizes is:\n{sizes}")) from None
|
||||
# if keyword arguments are included in the tree, we adapt the error
|
||||
# message only to be about the positional arguments
|
||||
tree, leaf = treedef_children(tree)
|
||||
assert treedef_is_leaf(leaf)
|
||||
tree = position_only_tree
|
||||
|
||||
# TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
|
||||
if tree == tree_flatten((0,) * tree.num_leaves)[1]:
|
||||
lines1 = [f"arg {i} has shape {np.shape(x)} and axis {d} is to be mapped"
|
||||
|
@ -2586,6 +2586,18 @@ class APITest(jtu.JaxTestCase):
|
||||
ans = vfun(lambda x: x + 1, jnp.arange(3))
|
||||
self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False)
|
||||
|
||||
def test_vmap_mismatched_keyword(self):
|
||||
# https://github.com/google/jax/issues/10193
|
||||
@jax.vmap
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "vmap got inconsistent sizes for array axes to be mapped:\n"
|
||||
"the tree of axis sizes is:\n"
|
||||
r"\(\(1,\), \{'y': 2\}\)"):
|
||||
f(jnp.array([1]), y=jnp.array([1, 2]))
|
||||
|
||||
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
|
||||
# https://github.com/google/jax/issues/705
|
||||
def h(a, b):
|
||||
|
Loading…
x
Reference in New Issue
Block a user