Pass axis name to _match_axes and add to error message.

This commit is contained in:
Peter Choy 2021-04-08 14:08:51 +00:00 committed by boyentenbi
parent 454f5e67b1
commit eb9d6e4d21
2 changed files with 20 additions and 6 deletions

View File

@ -57,7 +57,8 @@ def batch_subtrace(main, in_dims, *in_vals):
yield out_vals, out_dims
@lu.transformation
def _match_axes(axis_size, in_dims, out_dims_thunk, out_dim_dests, *in_vals):
def _match_axes(axis_size, axis_name, in_dims, out_dims_thunk, out_dim_dests,
*in_vals):
if axis_size is None:
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
out_vals = yield in_vals, {}
@ -65,7 +66,10 @@ def _match_axes(axis_size, in_dims, out_dims_thunk, out_dim_dests, *in_vals):
out_dims = out_dims_thunk()
for od, od_dest in zip(out_dims, out_dim_dests):
if od is not None and not isinstance(od_dest, int):
msg = f"vmap has mapped output but out_axes is {od_dest}"
if not isinstance(axis_name, core._TempAxisName):
msg = f"vmap has mapped output (axis_name={axis_name}) but out_axes is {od_dest}"
else:
msg = f"vmap has mapped output but out_axes is {od_dest}"
raise ValueError(msg)
yield map(partial(matchaxis, axis_size), out_dims, out_dim_dests, out_vals)
@ -280,8 +284,9 @@ def batch(fun: lu.WrappedFun,
# anlogue of `jvp` in ad.py
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
fun, out_dims_thunk = batch_subtrace(fun)
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims, main_type),
axis_size, in_dims, out_dims_thunk, out_dim_dests)
return _match_axes(
batchfun(fun, axis_name, axis_size, in_dims, main_type), axis_size,
axis_name, in_dims, out_dims_thunk, out_dim_dests)
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,

View File

@ -1908,8 +1908,17 @@ class APITest(jtu.JaxTestCase):
api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(jnp.array([1., 2.]))
with self.assertRaisesRegex(
ValueError, "vmap has mapped output but out_axes is None"):
# If the output is mapped, then there must be some out_axes specified
ValueError,
r"vmap has mapped output \(axis_name=foo\) but out_axes is None"):
# If the output is mapped (user-named axis), then there must be some
# out_axes specified.
api.vmap(lambda x: x, out_axes=None, axis_name="foo")(jnp.array([1., 2.]))
with self.assertRaisesRegex(
ValueError,
"vmap has mapped output but out_axes is None"):
# If the output is mapped (unnamed axis), then there must be some out_axes
# specified.
api.vmap(lambda x: x, out_axes=None)(jnp.array([1., 2.]))
def test_vmap_structured_in_axes(self):