mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8222 from mattjj:document-vmap-axis-name
PiperOrigin-RevId: 403412220
This commit is contained in:
commit
a64ce45c8e
@ -1329,6 +1329,8 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
|
||||
(axes) of the array returned by the :func:`vmap`-ed function, which is one
|
||||
more than the number of dimensions (axes) of the corresponding array
|
||||
returned by ``fun``.
|
||||
axis_name: Optional, a hashable Python object used to identify the mapped
|
||||
axis so that parallel collectives can be applied.
|
||||
|
||||
Returns:
|
||||
Batched/vectorized version of ``fun`` with arguments that correspond to
|
||||
@ -1403,6 +1405,16 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
|
||||
|
||||
If the ``out_axes`` is specified for a mapped result, the result is transposed
|
||||
accordingly.
|
||||
|
||||
Finally, here's an example using ``axis_name`` together with collectives:
|
||||
|
||||
>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
|
||||
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
|
||||
[[12. 15. 18. 21.]
|
||||
[12. 15. 18. 21.]
|
||||
[12. 15. 18. 21.]]
|
||||
|
||||
See the :py:func:`jax.pmap` docstring for more examples involving collectives.
|
||||
"""
|
||||
_check_callable(fun)
|
||||
docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
|
||||
|
Loading…
x
Reference in New Issue
Block a user