improve vmap docstring and tree prefix errors

fixes #795
This commit is contained in:
Matthew Johnson 2019-10-28 14:03:52 -07:00
parent affa2dcca4
commit f5079a6281
2 changed files with 68 additions and 15 deletions

View File

@ -565,21 +565,28 @@ def vmap(fun, in_axes=0, out_axes=0):
Args:
fun: Function to be mapped over additional axes.
in_axes: Specifies which input axes to map over. Normally this is a tuple with
one axes specification for each function argument. An integer is interpreted
as a tuple with the same value for all arguments. One argument axes specification
can be an integer (0 means first dimension), None (means that the dimension is
broadcasted). If the argument is a tuple of values, then the axes specification
can be a matching tuple as well.
out_axes: Specifies which output axes to map over. These may be integers,
`None`, or (possibly nested) tuples of integers or `None`.
in_axes: A nonnegative integer, None, or (nested) standard Python container
(tuple/list/dict) thereof specifying which input array axes to map over.
If each positional argument to ``fun`` is an array, then ``in_axes`` can
be a nonnegative integer, a None, or a tuple of integers and Nones with
length equal to the number of positional arguments to ``fun``. An integer
or None indicates which array axis to map over for all arguments (with
None indicating not to map any axis), and a tuple indicates which axis to
map for each corresponding positional argument. More generally, if the
positinal arguments to ``fun`` are container types, the corresponding
element of ``in_axes`` can itself be a matching container, so that
distinct array axes can be mapped for different container elements. The
constraint is that ``in_axes`` must be a container tree prefix of the
positional argument tuple passed to ``fun``.
out_axes: A nonnegative integer, None, or (nested) standard Python container
(tuple/list/dict) thereof indicating where the mapped axis should appear
in the output.
Returns:
Batched/vectorized version of `fun` with arguments that correspond to those
of `fun`, but with extra array axes at positions indicated by `in_axes`, and
a return value that corresponds to that of `fun`, but with extra array axes
at positions indicated by `out_axes`.
Batched/vectorized version of ``fun`` with arguments that correspond to
those of ``fun``, but with extra array axes at positions indicated by
``in_axes``, and a return value that corresponds to that of ``fun``, but
with extra array axes at positions indicated by ``out_axes``.
For example, we can implement a matrix-matrix product using a vector dot
product:
@ -588,12 +595,47 @@ def vmap(fun, in_axes=0, out_axes=0):
>>> mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
Here we use `[a,b]` to indicate an array with shape (a,b). Here are some
Here we use ``[a,b]`` to indicate an array with shape (a,b). Here are some
variants:
>>> mv1 = vmap(vv, (0, 0), 0) # ([b,a], [b,a]) -> [b] (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0) # ([b,a], [a,b]) -> [b] (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0) # ([b,c,a], [a,c,b]) -> [c,b] (c is the mapped axis)
Here's an example of using container types in ``in_axes`` to specify which
axes of the container elements to map over:
>>> A, B, C, D = 2, 3, 4, 5
>>> x = np.ones((A, B))
>>> y = np.ones((B, C))
>>> z = np.ones((C, D))
>>> def foo(tree_arg):
... x, (y, z) = tree_arg
... return np.dot(x, np.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6 # batch size
>>> x = np.ones((K, A, B)) # batch axis in different locations
>>> y = np.ones((B, K, C))
>>> z = np.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree))
[[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]]
"""
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
"but with additional array axes over which {fun} is mapped.")
@ -620,7 +662,12 @@ def _flatten_axes(treedef, axis_tree):
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
axes = []
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
tree_multimap(add_leaves, _replace_nones(axis_tree), dummy)
try:
tree_multimap(add_leaves, _replace_nones(axis_tree), dummy)
except ValueError:
msg = ("axes specification must be a tree prefix of the corresponding "
"value, got specification {} for value {}.")
raise ValueError(msg.format(axis_tree, treedef))
axes = [None if a is _none_proxy else a for a in axes]
return axes

View File

@ -1093,6 +1093,12 @@ class APITest(jtu.JaxTestCase):
b = np.dot(a + np.eye(a.shape[0]), real_x)
print(gf(a, b)) # doesn't crash
def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
jtu.check_raises_regexp(
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3)),
ValueError, "axes specification must be a tree prefix")
if __name__ == '__main__':
absltest.main()