address reviewer comments, fix test error

This commit is contained in:
Matthew Johnson 2019-10-31 11:57:37 -07:00
parent eae47b2330
commit 14acca7b51
3 changed files with 19 additions and 27 deletions

View File

@ -572,12 +572,11 @@ def vmap(fun, in_axes=0, out_axes=0):
length equal to the number of positional arguments to ``fun``. An integer
or None indicates which array axis to map over for all arguments (with
None indicating not to map any axis), and a tuple indicates which axis to
map for each corresponding positional argument. More generally, if the
positinal arguments to ``fun`` are container types, the corresponding
element of ``in_axes`` can itself be a matching container, so that
distinct array axes can be mapped for different container elements. The
constraint is that ``in_axes`` must be a container tree prefix of the
positional argument tuple passed to ``fun``.
map for each corresponding positional argument. If the positinal arguments
to ``fun`` are container types, the corresponding element of ``in_axes``
can itself be a matching container, so that distinct array axes can be
mapped for different container elements. ``in_axes`` must be a container
tree prefix of the positional argument tuple passed to ``fun``.
out_axes: A nonnegative integer, None, or (nested) standard Python container
(tuple/list/dict) thereof indicating where the mapped axis should appear
in the output.
@ -623,19 +622,8 @@ def vmap(fun, in_axes=0, out_axes=0):
>>> z = np.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree))
[[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]]
>>> print(vfoo(tree)).shape
(6, 2, 5)
"""
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
"but with additional array axes over which {fun} is mapped.")
@ -648,8 +636,9 @@ def vmap(fun, in_axes=0, out_axes=0):
raise TypeError(msg.format(type(in_axes), type(out_axes)))
def _check_axis_sizes(tree, vals, dims):
mapped_axis_sizes = {x.shape[d] for x, d in zip(vals, dims) if d is not None}
try:
sizes, = {x.shape[d] for x, d in zip(vals, dims) if d is not None}
sizes, = mapped_axis_sizes
except ValueError:
msg = "vmap got inconsistent sizes for array axes to be mapped:\n{}"
# we switch the error message based on whether args is a tuple of arrays,

View File

@ -1099,14 +1099,14 @@ class APITest(jtu.JaxTestCase):
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3)),
ValueError, "axes specification must be a tree prefix")
def test_vmap_objects_issue_183(self):
def test_vmap_unbatched_object_passthrough_issue_183(self):
# https://github.com/google/jax/issues/183
fun = lambda f, x: f(x)
vfun = api.vmap(fun, (None, 0))
ans = vfun(lambda x: x + 1, np.arange(3))
self.assertAllClose(ans, onp.arange(1, 4), check_dtypes=False)
def test_vmap_error_message_issue_705(self):
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
# https://github.com/google/jax/issues/705
def h(a, b):
return np.sum(a) + np.sum(b)
@ -1126,7 +1126,8 @@ class APITest(jtu.JaxTestCase):
self.assertRaisesRegex(
ValueError,
"vmap got inconsistent sizes for array axes to be mapped:\n"
"the tree of axis sizes is:\n",
"the tree of axis sizes is:\n"
"\(10, \[2, 2\]\)",
lambda: api.vmap(h, in_axes=(0, 1))(X, [U, U]))

View File

@ -369,20 +369,22 @@ class BatchingTest(jtu.JaxTestCase):
def testDynamicSlice(self):
# test dynamic_slice via numpy indexing syntax
x = onp.arange(30).reshape((10, 3))
# see https://github.com/google/jax/issues/1613 for an explanation of why we
# need to use np rather than onp to create x and idx
x = np.arange(30).reshape((10, 3))
ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
expected = x[:, 1]
self.assertAllClose(ans, expected, check_dtypes=False)
idx = onp.array([0, 1, 2, 1, 0] * 2)
idx = np.array([0, 1, 2, 1, 0] * 2)
ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
expected = x[onp.arange(10), idx]
self.assertAllClose(ans, expected, check_dtypes=False)
x = onp.arange(3)
idx = onp.array([0, 1, 2, 1, 0] * 2)
x = np.arange(3)
idx = np.array([0, 1, 2, 1, 0] * 2)
ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)