Merge pull request #11803 from hawkinsp:multigpu

PiperOrigin-RevId: 466374926
This commit is contained in:
jax authors 2022-08-09 09:02:21 -07:00
commit d95b27ce1c
4 changed files with 124 additions and 77 deletions

View File

@ -432,9 +432,6 @@ operating system, CUDA, and CuDNN are possible, but require [building from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
* CUDA 11.1 or newer is *required*.
* You may be able to use older CUDA versions if you [build from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source),
but there are known bugs in CUDA in all CUDA versions older than 11.1, so we
do not ship prebuilt binaries for older CUDA versions.
* The supported cuDNN versions for the prebuilt wheels are:
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
installation is new enough, since it supports additional functionality.

View File

@ -8,4 +8,5 @@ jax.distributed module
.. autosummary::
:toctree: _autosummary
initialize
initialize
shutdown

View File

@ -2,72 +2,114 @@
## Introduction
This guide explains how to use JAX in environments such as [Cloud
TPU](https://cloud.google.com/tpu) pods where accelerators are spread across
multiple CPU hosts or JAX processes. Well refer to these as “multi-process”
environments.
This guide explains how to use JAX in environments such as
GPU clusters and [Cloud TPU](https://cloud.google.com/tpu) pods where
accelerators are spread across multiple CPU hosts or JAX processes. Well refer
to these as “multi-process” environments.
This guide specifically focuses on how to use collective communication
operations (e.g. {func}`jax.lax.psum`) in multi-process settings, although other
communication methods may be useful too depending on your use case (e.g. RPC,
[mpi4jax](https://github.com/mpi4jax/mpi4jax)). If youre not already familiar
with JAXs collective operations, we recommend starting with the
{doc}`/jax-101/06-parallelism` notebook. An important requirement of multi-process
environments in JAX is direct communication links between accelerators, e.g. the
high-speed interconnects for Cloud TPUs or
[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links are what allow
collective operations to run across multiple processes worth of accelerators.
operations (e.g. {func}`jax.lax.psum` ) in multi-process settings, although
other communication methods may be useful too depending on your use case (e.g.
RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If youre not already
familiar with JAXs collective operations, we recommend starting with the
{doc}`/jax-101/06-parallelism` notebook. An important requirement of
multi-process environments in JAX is direct communication links between
accelerators, e.g. the high-speed interconnects for Cloud TPUs or
[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow
collective operations to run across multiple processes worth of accelerators
with high performance.
## Multi-process programming model
Key concepts:
* You must run at least one JAX process per host.
* Each process has a distinct set of _local_ devices it can address. The
_global_ devices are the set of all devices across all processes.
* Use standard JAX parallelism APIs like {func}`~jax.pmap` and
{func}`~jax.experimental.maps.xmap`. Each process “sees” _local_ input and
output to parallelized functions, but communication inside the computations
is _global_.
* Make sure all processes run the same parallel computations in the same
order.
* You must run at least one JAX process per host.
* You should initialize the cluster with {func}`jax.distributed.initialize`.
* Each process has a
distinct set of *local* devices it can address. The *global* devices are the set
of all devices across all processes.
* Use standard JAX parallelism APIs like {func}`~jax.pmap` and
{func}`~jax.experimental.maps.xmap` . Each process “sees” *local* input and
output to parallelized functions, but communication inside the computations
is *global*.
* Make sure all processes run the same parallel computations in the same
order.
### Launching JAX processes
Unlike other distributed systems where a single controller node manages many
worker nodes, JAX uses a “multi-controller” programming model where each JAX
Python process runs independently, sometimes referred to as a
{term}`Single Program, Multiple Data (SPMD)<SPMD>` model. Generally, the same
JAX Python program is run in each process, with only slight differences between
each processs execution (e.g. different processes will load different input
data). Furthermore, **you must manually run your JAX program on each host!** JAX
Python process runs independently, sometimes referred to as a {term}`Single
Program, Multiple Data (SPMD)<SPMD>` model. Generally, the same JAX Python
program is run in each process, with only slight differences between each
processs execution (e.g. different processes will load different input data).
Furthermore, **you must manually run your JAX program on each host!** JAX
doesnt automatically start multiple processes from a single program invocation.
(This is why this guide isnt offered as a notebook -- we dont currently have a
good way to manage multiple Python processes from a single notebook.)
(The requirement for multiple processes is why this guide isnt offered as a
notebook -- we dont currently have a good way to manage multiple Python
processes from a single notebook.)
### Initializing the cluster
To initialize the cluster, you should call {func}`jax.distributed.initialize` at
the start of each process. {func}`jax.distributed.initialize` must be called
early in the program, before any JAX computations are executed.
The API {func}`jax.distributed.initialize` takes several arguments, namely:
* `coordinator_address`: the IP address of process 0 in your cluster, together
with a port available on that process. Process 0 will start a JAX service
exposed via that IP address and port, to which the other processes in the
cluster will connect.
* `num_processes`: the number of processes in the cluster
* `process_id`: the ID number of this process, in the range `[0 ..
num_processes)`.
For example on GPU, a typical usage is:
```python
import jax
jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
num_processes=2,
process_id=0)
```
On Cloud TPU, you can simply call {func}`jax.distributed.initialize()` with no
arguments. Default values for the arguments will be chosen automatically using
the TPU pod metadata:
```python
import jax
jax.distributed.initialize()
```
On TPU at present calling {func}`jax.distributed.initialize` is optional, but
recommanded since it enables additional checkpointing and health checking features.
### Local vs. global devices
Before we get to running multi-process computations from your program, its
important to understand the distinction between _local_ and _global_ devices.
important to understand the distinction between *local* and *global* devices.
**A processs _local_ devices are those that it can directly address and launch
computations on.** For example, in a Cloud TPU pod, each host can only launch
computations on the 8 TPU cores attached directly to that host (see the [Cloud
TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture)
**A processs *local* devices are those that it can directly address and launch
computations on.** For example, on a GPU cluster, each host can only launch
computations on the directly attached GPUs. On a Cloud TPU pod, each host can
only launch computations on the 8 TPU cores attached directly to that host (see
the
[Cloud TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture)
documentation for more details). You can see a processs local devices via
{func}`jax.local_devices()`.
{func}`jax.local_devices()` .
**The _global_ devices are the devices across all processes.** A computation can
**The *global* devices are the devices across all processes.** A computation can
span devices across processes and perform collective operations via the direct
communication links between devices, as long as each process launches the
computation on its local devices. You can see all available global devices via
{func}`jax.devices()`. A processs local devices are always a subset of the
{func}`jax.devices()` . A processs local devices are always a subset of the
global devices.
### Running multi-process computations
So how do you actually run a computation involving cross-process communication?
@ -77,22 +119,23 @@ For example, {func}`~jax.pmap` can be used to run a parallel computation across
multiple processes. (If youre not already familiar with how to use
{func}`~jax.pmap` to run across multiple devices within a single process, check
out the {doc}`/jax-101/06-parallelism` notebook.) Each process should call the
same pmapped function and pass in arguments to be mapped across its _local_
devices (i.e., the pmapped axis size is equal to the number of local
devices). Similarly, the function will return outputs sharded across _local_
devices only. Inside the function, however, collective communication operations
are run across all _global_ devices, across all processes. Conceptually, this
can be thought of as running a pmap over a single array sharded across hosts,
where each host “sees” only its local shard of the input and output.
same pmapped function and pass in arguments to be mapped across its *local*
devices (i.e., the pmapped axis size is equal to the number of local devices).
Similarly, the function will return outputs sharded across *local* devices only.
Inside the function, however, collective communication operations are run across
all *global* devices, across all processes. Conceptually, this can be thought of
as running a pmap over a single array sharded across hosts, where each host
“sees” only its local shard of the input and output.
Heres an example of multi-process pmap in action:
```python
# The following is run in parallel on each host in a Cloud TPU v3-32 pod slice
# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.device_count() # total number of TPU cores in pod slice
>>> jax.distributed.initialize() # On GPU, see above for the necessary arguments.
>>> jax.device_count() # total number of accelerator devices in the cluster
32
>>> jax.local_device_count() # number of TPU cores attached to this host
>>> jax.local_device_count() # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
@ -102,12 +145,10 @@ ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
{func}`~jax.experimental.maps.xmap` works similarly when using a physical
hardware mesh (see the {doc}`xmap tutorial</notebooks/xmap_tutorial>` if youre
not familiar with the single-process version). Like {func}`~jax.pmap`, the
not familiar with the single-process version). Like {func}`~jax.pmap` , the
inputs and outputs are local and any parallel communication inside the xmapped
function is global. The mesh is also global.
TODO: xmap example
**Its very important that all processes run the same cross-process computations
in the same order.** Running the same JAX Python program in each process is
usually sufficient. Some common pitfalls to look out for that may cause

View File

@ -102,45 +102,53 @@ global_state = State()
def initialize(coordinator_address: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None):
"""Initialize distributed system for topology discovery.
"""Initializes the JAX distributed system.
Currently, calling ``initialize`` sets up the multi-host GPU backend and Cloud
TPU backend.
Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
multi-host GPU and Cloud TPU. :func:`~jax.distributed.initialize` must be
called before performing any JAX computations.
If you are on GPU platform, you will have to provide the coordinator_address
and other args to the `initialize` API.
The JAX distributed system serves a number of roles:
If you are on TPU platform, the coordinator_address and other args will be
auto detected but you have the option to provide it too.
* it allows JAX processes to discover each other and share topology information,
* it performs health checking, ensuring that all processes shut down if any process dies, and
* it is used for distributed checkpointing.
If you are using GPU, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
If you are using TPU, all arguments are optional: if omitted, they
will be chosen automatically from the Cloud TPU metadata.
Args:
coordinator_address: IP address and port of the coordinator. The choice of
coordinator_address: the IP address of process `0` and a port on which that
process should launch a coordinator service. The choice of
port does not matter, so long as the port is available on the coordinator
and all processes agree on the port.
Can be None only for TPU platform. If coordinator_address is None on TPU,
then it will be auto detected.
num_processes: Number of processes. Can be None only for TPU platform and
if None will be determined from the TPU slice metadata.
process_id: Id of the current process. Can be None only for TPU platform and
if None will default to the current TPU worker id determined via the TPU
slice metadata.
May be ``None`` only on TPU, in which case it will be chosen automatically.
num_processes: Number of processes. May be ``None`` only on TPU, in
which case it will be chosen automatically based on the TPU slice.
process_id: The ID number of the current process. The ``process_id`` values across
the cluster must be a dense range ``0``, ``1``, ..., ``num_processes - 1``.
May be ``None`` only on TPU; if ``None`` it will be chosen from the TPU slice
metadata.
Raises:
RuntimeError: If `distributed.initialize` is called more than once.
RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
Example:
Suppose there are two GPU hosts, and host 0 is the designated coordinator
Suppose there are two GPU processs, and process 0 is the designated coordinator
with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the
following commands before anything else.
On host 0:
On process 0:
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 0) # doctest: +SKIP
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0) # doctest: +SKIP
On host 1:
On process 1:
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1) # doctest: +SKIP
"""
global_state.initialize(coordinator_address, num_processes, process_id)
atexit.register(shutdown)