mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Mention in the pmap() documentation that all devices must be identical.
Fixes https://github.com/google/jax/issues/13203
This commit is contained in:
parent
7600cc8a8e
commit
ce17ce0550
@ -1788,6 +1788,11 @@ def pmap(
|
||||
:py:func:`pmap` compiles ``fun``, so while it can be combined with
|
||||
:py:func:`jit`, it's usually unnecessary.
|
||||
|
||||
:py:func:`pmap` requires that all of the participating devices are identical.
|
||||
For example, it is not possible to use :py:func:`pmap` to parallelize a
|
||||
computation across two different models of GPU. It is currently an error for
|
||||
the same device to participate twice in the same `pmap`.
|
||||
|
||||
**Multi-process platforms:** On multi-process platforms such as TPU pods,
|
||||
:py:func:`pmap` is designed to be used in SPMD Python programs, where every
|
||||
process is running the same Python code such that all processes run the same
|
||||
|
Loading…
x
Reference in New Issue
Block a user