Mention in the pmap() documentation that all devices must be identical.

Fixes https://github.com/google/jax/issues/13203
This commit is contained in:
Peter Hawkins 2022-11-14 10:35:41 -05:00
parent 7600cc8a8e
commit ce17ce0550

View File

@ -1788,6 +1788,11 @@ def pmap(
:py:func:`pmap` compiles ``fun``, so while it can be combined with
:py:func:`jit`, it's usually unnecessary.
:py:func:`pmap` requires that all of the participating devices are identical.
For example, it is not possible to use :py:func:`pmap` to parallelize a
computation across two different models of GPU. It is currently an error for
the same device to participate twice in the same `pmap`.
**Multi-process platforms:** On multi-process platforms such as TPU pods,
:py:func:`pmap` is designed to be used in SPMD Python programs, where every
process is running the same Python code such that all processes run the same