Merge pull request #8222 from mattjj:document-vmap-axis-name

PiperOrigin-RevId: 403412220
This commit is contained in:
jax authors 2021-10-15 10:43:00 -07:00
commit a64ce45c8e

View File

@ -1329,6 +1329,8 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
(axes) of the array returned by the :func:`vmap`-ed function, which is one
more than the number of dimensions (axes) of the corresponding array
returned by ``fun``.
axis_name: Optional, a hashable Python object used to identify the mapped
axis so that parallel collectives can be applied.
Returns:
Batched/vectorized version of ``fun`` with arguments that correspond to
@ -1403,6 +1405,16 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
If the ``out_axes`` is specified for a mapped result, the result is transposed
accordingly.
Finally, here's an example using ``axis_name`` together with collectives:
>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
[12. 15. 18. 21.]
[12. 15. 18. 21.]]
See the :py:func:`jax.pmap` docstring for more examples involving collectives.
"""
_check_callable(fun)
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "