mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
affa2dcca4
commit
f5079a6281
77
jax/api.py
77
jax/api.py
@ -565,21 +565,28 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
|
||||
Args:
|
||||
fun: Function to be mapped over additional axes.
|
||||
in_axes: Specifies which input axes to map over. Normally this is a tuple with
|
||||
one axes specification for each function argument. An integer is interpreted
|
||||
as a tuple with the same value for all arguments. One argument axes specification
|
||||
can be an integer (0 means first dimension), None (means that the dimension is
|
||||
broadcasted). If the argument is a tuple of values, then the axes specification
|
||||
can be a matching tuple as well.
|
||||
|
||||
out_axes: Specifies which output axes to map over. These may be integers,
|
||||
`None`, or (possibly nested) tuples of integers or `None`.
|
||||
in_axes: A nonnegative integer, None, or (nested) standard Python container
|
||||
(tuple/list/dict) thereof specifying which input array axes to map over.
|
||||
If each positional argument to ``fun`` is an array, then ``in_axes`` can
|
||||
be a nonnegative integer, a None, or a tuple of integers and Nones with
|
||||
length equal to the number of positional arguments to ``fun``. An integer
|
||||
or None indicates which array axis to map over for all arguments (with
|
||||
None indicating not to map any axis), and a tuple indicates which axis to
|
||||
map for each corresponding positional argument. More generally, if the
|
||||
positinal arguments to ``fun`` are container types, the corresponding
|
||||
element of ``in_axes`` can itself be a matching container, so that
|
||||
distinct array axes can be mapped for different container elements. The
|
||||
constraint is that ``in_axes`` must be a container tree prefix of the
|
||||
positional argument tuple passed to ``fun``.
|
||||
out_axes: A nonnegative integer, None, or (nested) standard Python container
|
||||
(tuple/list/dict) thereof indicating where the mapped axis should appear
|
||||
in the output.
|
||||
|
||||
Returns:
|
||||
Batched/vectorized version of `fun` with arguments that correspond to those
|
||||
of `fun`, but with extra array axes at positions indicated by `in_axes`, and
|
||||
a return value that corresponds to that of `fun`, but with extra array axes
|
||||
at positions indicated by `out_axes`.
|
||||
Batched/vectorized version of ``fun`` with arguments that correspond to
|
||||
those of ``fun``, but with extra array axes at positions indicated by
|
||||
``in_axes``, and a return value that corresponds to that of ``fun``, but
|
||||
with extra array axes at positions indicated by ``out_axes``.
|
||||
|
||||
For example, we can implement a matrix-matrix product using a vector dot
|
||||
product:
|
||||
@ -588,12 +595,47 @@ def vmap(fun, in_axes=0, out_axes=0):
|
||||
>>> mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis)
|
||||
>>> mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
|
||||
|
||||
Here we use `[a,b]` to indicate an array with shape (a,b). Here are some
|
||||
Here we use ``[a,b]`` to indicate an array with shape (a,b). Here are some
|
||||
variants:
|
||||
|
||||
>>> mv1 = vmap(vv, (0, 0), 0) # ([b,a], [b,a]) -> [b] (b is the mapped axis)
|
||||
>>> mv2 = vmap(vv, (0, 1), 0) # ([b,a], [a,b]) -> [b] (b is the mapped axis)
|
||||
>>> mm2 = vmap(mv2, (1, 1), 0) # ([b,c,a], [a,c,b]) -> [c,b] (c is the mapped axis)
|
||||
|
||||
Here's an example of using container types in ``in_axes`` to specify which
|
||||
axes of the container elements to map over:
|
||||
|
||||
>>> A, B, C, D = 2, 3, 4, 5
|
||||
>>> x = np.ones((A, B))
|
||||
>>> y = np.ones((B, C))
|
||||
>>> z = np.ones((C, D))
|
||||
>>> def foo(tree_arg):
|
||||
... x, (y, z) = tree_arg
|
||||
... return np.dot(x, np.dot(y, z))
|
||||
>>> tree = (x, (y, z))
|
||||
>>> print(foo(tree))
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
>>> from jax import vmap
|
||||
>>> K = 6 # batch size
|
||||
>>> x = np.ones((K, A, B)) # batch axis in different locations
|
||||
>>> y = np.ones((B, K, C))
|
||||
>>> z = np.ones((C, D, K))
|
||||
>>> tree = (x, (y, z))
|
||||
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
|
||||
>>> print(vfoo(tree))
|
||||
[[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]
|
||||
[[12. 12. 12. 12. 12.]
|
||||
[12. 12. 12. 12. 12.]]]
|
||||
"""
|
||||
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
|
||||
"but with additional array axes over which {fun} is mapped.")
|
||||
@ -620,7 +662,12 @@ def _flatten_axes(treedef, axis_tree):
|
||||
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
|
||||
axes = []
|
||||
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
|
||||
tree_multimap(add_leaves, _replace_nones(axis_tree), dummy)
|
||||
try:
|
||||
tree_multimap(add_leaves, _replace_nones(axis_tree), dummy)
|
||||
except ValueError:
|
||||
msg = ("axes specification must be a tree prefix of the corresponding "
|
||||
"value, got specification {} for value {}.")
|
||||
raise ValueError(msg.format(axis_tree, treedef))
|
||||
axes = [None if a is _none_proxy else a for a in axes]
|
||||
return axes
|
||||
|
||||
|
@ -1093,6 +1093,12 @@ class APITest(jtu.JaxTestCase):
|
||||
b = np.dot(a + np.eye(a.shape[0]), real_x)
|
||||
print(gf(a, b)) # doesn't crash
|
||||
|
||||
def test_vmap_in_axes_tree_prefix_error(self):
|
||||
# https://github.com/google/jax/issues/795
|
||||
jtu.check_raises_regexp(
|
||||
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3)),
|
||||
ValueError, "axes specification must be a tree prefix")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user