mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Clarifying docstring for devices
argument of pmap
.
PiperOrigin-RevId: 383486168
This commit is contained in:
parent
56087dc863
commit
f925b62ea0
@ -1423,8 +1423,10 @@ def pmap(
|
||||
static. Defaults to ().
|
||||
devices: This is an experimental feature and the API is likely to change.
|
||||
Optional, a sequence of Devices to map over. (Available devices can be
|
||||
retrieved via jax.devices()). If specified, the size of the mapped axis
|
||||
must be equal to the number of local devices in the sequence. Nested
|
||||
retrieved via jax.devices()). Must be given identically for each process
|
||||
in multi-process settings (and will therefore include devices across
|
||||
processes). If specified, the size of the mapped axis must be equal to
|
||||
the number of devices in the sequence local to the given process. Nested
|
||||
:py:func:`pmap` s with ``devices`` specified in either the inner or outer
|
||||
:py:func:`pmap` are not yet supported.
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
|
Loading…
x
Reference in New Issue
Block a user