Clarifying docstring for devices argument of pmap.

PiperOrigin-RevId: 383486168
This commit is contained in:
James Martens 2021-07-07 13:50:14 -07:00 committed by jax authors
parent 56087dc863
commit f925b62ea0

View File

@ -1423,8 +1423,10 @@ def pmap(
static. Defaults to ().
devices: This is an experimental feature and the API is likely to change.
Optional, a sequence of Devices to map over. (Available devices can be
retrieved via jax.devices()). If specified, the size of the mapped axis
must be equal to the number of local devices in the sequence. Nested
retrieved via jax.devices()). Must be given identically for each process
in multi-process settings (and will therefore include devices across
processes). If specified, the size of the mapped axis must be equal to
the number of devices in the sequence local to the given process. Nested
:py:func:`pmap` s with ``devices`` specified in either the inner or outer
:py:func:`pmap` are not yet supported.
backend: This is an experimental feature and the API is likely to change.