improve vmap error messages

fixes #705
This commit is contained in:
Matthew Johnson 2019-10-30 17:31:37 -07:00
parent cbadfd41ce
commit eae47b2330
3 changed files with 58 additions and 4 deletions

View File

@ -647,12 +647,43 @@ def vmap(fun, in_axes=0, out_axes=0):
"or a (nested) tuple of those types, got {} and {} respectively.")
raise TypeError(msg.format(type(in_axes), type(out_axes)))
def _check_axis_sizes(tree, vals, dims):
try:
sizes, = {x.shape[d] for x, d in zip(vals, dims) if d is not None}
except ValueError:
msg = "vmap got inconsistent sizes for array axes to be mapped:\n{}"
# we switch the error message based on whether args is a tuple of arrays,
# in which case we can produce an error message based on argument indices,
# or if it has nested containers.
# TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
if tree == tree_flatten((core.unit,) * tree.num_leaves)[1]:
lines1 = ["arg {} has shape {} and axis {} is to be mapped"
.format(i, x.shape, d) for i, (x, d) in enumerate(zip(vals, dims))]
sizes = collections.defaultdict(list)
for i, (x, d) in enumerate(zip(vals, dims)):
if d is not None:
sizes[x.shape[d]].append(i)
lines2 = ["{} {} {} {} to be mapped of size {}".format(
"args" if len(idxs) > 1 else "arg",
", ".join(map(str, idxs)),
"have" if len(idxs) > 1 else "has",
"axes" if len(idxs) > 1 else "an axis",
size)
for size, idxs in sizes.items()]
raise ValueError(msg.format("\n".join(lines1 + ["so"] + lines2)))
else:
sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]
sizes = tree_unflatten(tree, sizes)
raise ValueError(msg.format("the tree of axis sizes is:\n{}".format(sizes)))
@wraps(fun, docstr=docstr)
def batched_fun(*args):
args_flat, in_tree = tree_flatten(args)
f = lu.wrap_init(fun)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
out_flat = batching.batch(flat_fun, args_flat, _flatten_axes(in_tree, in_axes),
in_axes_flat = _flatten_axes(in_tree, in_axes)
_check_axis_sizes(in_tree, args_flat, in_axes_flat)
out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
lambda: _flatten_axes(out_tree(), out_axes))
return tree_unflatten(out_tree(), out_flat)

View File

@ -38,8 +38,8 @@ map = safe_map
def batch(fun, in_vals, in_dims, out_dim_dests):
out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
def batch_fun(fun, in_vals, in_dims):
@ -163,8 +163,8 @@ def get_primitive_batcher(p):
try:
return primitive_batchers[p]
except KeyError:
raise NotImplementedError(
"Batching rule for '{}' not implemented".format(p))
msg = "Batching rule for '{}' not implemented"
raise NotImplementedError(msg.format(p))
def defvectorized(prim):
primitive_batchers[prim] = partial(vectorized_batcher, prim)

View File

@ -1106,6 +1106,29 @@ class APITest(jtu.JaxTestCase):
ans = vfun(lambda x: x + 1, np.arange(3))
self.assertAllClose(ans, onp.arange(1, 4), check_dtypes=False)
def test_vmap_error_message_issue_705(self):
# https://github.com/google/jax/issues/705
def h(a, b):
return np.sum(a) + np.sum(b)
X = onp.random.randn(10, 4)
U = onp.random.randn(10, 2)
self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
"arg 0 has shape \(10, 4\) and axis 0 is to be mapped\n"
"arg 1 has shape \(10, 2\) and axis 1 is to be mapped\n"
"so\n"
"arg 0 has an axis to be mapped of size 10\n"
"arg 1 has an axis to be mapped of size 2",
lambda: api.vmap(h, in_axes=(0, 1))(X, U))
self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
"the tree of axis sizes is:\n",
lambda: api.vmap(h, in_axes=(0, 1))(X, [U, U]))
if __name__ == '__main__':
absltest.main()