Merge pull request #11803 from hawkinsp:multigpu
PiperOrigin-RevId: 466374926
Commit: d95b27ce1c
@@ -432,9 +432,6 @@ operating system, CUDA, and CuDNN are possible, but require [building from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).

* CUDA 11.1 or newer is *required*.
  * You may be able to use older CUDA versions if you [build from
    source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source),
    but all CUDA versions older than 11.1 have known CUDA bugs, so we do not
    ship prebuilt binaries for older CUDA versions.
* The supported cuDNN versions for the prebuilt wheels are:
  * cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
    installation is new enough, since it supports additional functionality.

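A quick way to check that the CUDA and cuDNN installation is actually visible to JAX is shown below; this is a minimal sketch, and the exact device names printed vary by JAX version and hardware:

```python
import jax

# If the CUDA/cuDNN install is picked up correctly, the default backend is
# "gpu" and jax.devices() lists one entry per visible GPU.
print(jax.default_backend())  # expected: "gpu"
print(jax.devices())
```
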
@@ -8,4 +8,5 @@ jax.distributed module

.. autosummary::
  :toctree: _autosummary

  initialize
  shutdown

@@ -2,72 +2,114 @@
## Introduction

This guide explains how to use JAX in environments such as
GPU clusters and [Cloud TPU](https://cloud.google.com/tpu) pods where
accelerators are spread across multiple CPU hosts or JAX processes. We’ll refer
to these as “multi-process” environments.

This guide specifically focuses on how to use collective communication
operations (e.g. {func}`jax.lax.psum`) in multi-process settings, although
other communication methods may be useful too depending on your use case (e.g.
RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If you’re not already
familiar with JAX’s collective operations, we recommend starting with the
{doc}`/jax-101/06-parallelism` notebook. An important requirement of
multi-process environments in JAX is direct communication links between
accelerators, e.g. the high-speed interconnects for Cloud TPUs or
[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow
collective operations to run across multiple processes’ worth of accelerators
with high performance.

## Multi-process programming model

Key concepts:

* You must run at least one JAX process per host.
* You should initialize the cluster with {func}`jax.distributed.initialize`.
* Each process has a distinct set of *local* devices it can address. The
  *global* devices are the set of all devices across all processes.
* Use standard JAX parallelism APIs like {func}`~jax.pmap` and
  {func}`~jax.experimental.maps.xmap`. Each process “sees” *local* input and
  output to parallelized functions, but communication inside the computations
  is *global*.
* Make sure all processes run the same parallel computations in the same
  order (a minimal per-process sketch follows this list).

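To make these concepts concrete, here is a minimal per-process sketch. The coordinator address and process ID are placeholder values supplied by your own launcher; on Cloud TPU the `initialize` arguments can be omitted, as described below.

```python
# Minimal sketch: run this same program once per host.
import jax

# Placeholder arguments -- in practice each process is launched with its own
# process_id, and on Cloud TPU the arguments can be omitted entirely.
jax.distributed.initialize(coordinator_address="10.0.0.1:1234",
                           num_processes=2,
                           process_id=0)

# Each process feeds in one value per *local* device...
xs = jax.numpy.ones(jax.local_device_count())

# ...but the psum inside pmap runs over the *global* devices of all processes.
print(jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs))
```
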
### Launching JAX processes

Unlike other distributed systems where a single controller node manages many
worker nodes, JAX uses a “multi-controller” programming model where each JAX
Python process runs independently, sometimes referred to as a {term}`Single
Program, Multiple Data (SPMD)<SPMD>` model. Generally, the same JAX Python
program is run in each process, with only slight differences between each
process’s execution (e.g. different processes will load different input data).
Furthermore, **you must manually run your JAX program on each host!** JAX
doesn’t automatically start multiple processes from a single program invocation.

(The requirement for multiple processes is why this guide isn’t offered as a
notebook -- we don’t currently have a good way to manage multiple Python
processes from a single notebook.)

### Initializing the cluster

To initialize the cluster, you should call {func}`jax.distributed.initialize` at
the start of each process. {func}`jax.distributed.initialize` must be called
early in the program, before any JAX computations are executed.

The API {func}`jax.distributed.initialize` takes several arguments, namely:

* `coordinator_address`: the IP address of process 0 in your cluster, together
  with a port available on that process. Process 0 will start a JAX service
  exposed via that IP address and port, to which the other processes in the
  cluster will connect.
* `num_processes`: the number of processes in the cluster.
* `process_id`: the ID number of this process, in the range `[0 ..
  num_processes)`.

For example on GPU, a typical usage is:

```python
import jax

jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0)
```

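In practice, these three arguments are usually filled in by whatever launcher starts the processes. A minimal sketch, assuming the launcher exports hypothetical `COORDINATOR_ADDRESS`, `NUM_PROCESSES`, and `PROCESS_ID` environment variables (these names are illustrative; JAX does not read them itself):

```python
import os
import jax

# Hypothetical environment variables set by your cluster launcher; they are
# only used here to fill in the initialize() arguments.
jax.distributed.initialize(
    coordinator_address=os.environ["COORDINATOR_ADDRESS"],  # e.g. "192.168.0.1:1234"
    num_processes=int(os.environ["NUM_PROCESSES"]),
    process_id=int(os.environ["PROCESS_ID"]),
)
```
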
On Cloud TPU, you can simply call {func}`jax.distributed.initialize()` with no
arguments. Default values for the arguments will be chosen automatically using
the TPU pod metadata:

```python
import jax

jax.distributed.initialize()
```

On TPU at present calling {func}`jax.distributed.initialize` is optional, but
recommended since it enables additional checkpointing and health checking
features.

### Local vs. global devices

Before we get to running multi-process computations from your program, it’s
important to understand the distinction between *local* and *global* devices.

**A process’s *local* devices are those that it can directly address and launch
computations on.** For example, on a GPU cluster, each host can only launch
computations on the directly attached GPUs. On a Cloud TPU pod, each host can
only launch computations on the 8 TPU cores attached directly to that host (see
the [Cloud TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture)
documentation for more details). You can see a process’s local devices via
{func}`jax.local_devices()`.

**The *global* devices are the devices across all processes.** A computation can
span devices across processes and perform collective operations via the direct
communication links between devices, as long as each process launches the
computation on its local devices. You can see all available global devices via
{func}`jax.devices()`. A process’s local devices are always a subset of the
global devices.

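As a small illustration, you can inspect both sets after initialization; the counts in the comments assume the 32-device cluster used in the example below:

```python
import jax

# Assumes jax.distributed.initialize() has already been called in this process.
print(len(jax.devices()))        # number of global devices, e.g. 32
print(len(jax.local_devices()))  # number of local devices on this host, e.g. 8

# Every local device is also one of the global devices.
assert all(d in jax.devices() for d in jax.local_devices())
```
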
### Running multi-process computations

So how do you actually run a computation involving cross-process communication?

@@ -77,22 +119,23 @@ For example, {func}`~jax.pmap` can be used to run a parallel computation across
multiple processes. (If you’re not already familiar with how to use
{func}`~jax.pmap` to run across multiple devices within a single process, check
out the {doc}`/jax-101/06-parallelism` notebook.) Each process should call the
same pmapped function and pass in arguments to be mapped across its *local*
devices (i.e., the pmapped axis size is equal to the number of local devices).
Similarly, the function will return outputs sharded across *local* devices only.
Inside the function, however, collective communication operations are run across
all *global* devices, across all processes. Conceptually, this can be thought of
as running a pmap over a single array sharded across hosts, where each host
“sees” only its local shard of the input and output.

Here’s an example of multi-process pmap in action:

```python
# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
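# Illustrative completion: the diff hunk is cut off here, but applying the psum
# through pmap is what produces the ShardedDeviceArray of 32s shown in the next
# hunk's context line.
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
```
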
@@ -102,12 +145,10 @@ ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)

{func}`~jax.experimental.maps.xmap` works similarly when using a physical
hardware mesh (see the {doc}`xmap tutorial</notebooks/xmap_tutorial>` if you’re
not familiar with the single-process version). Like {func}`~jax.pmap`, the
inputs and outputs are local and any parallel communication inside the xmapped
function is global. The mesh is also global.

TODO: xmap example

**It’s very important that all processes run the same cross-process computations
in the same order.** Running the same JAX Python program in each process is
usually sufficient. Some common pitfalls to look out for that may cause
@@ -102,45 +102,53 @@ global_state = State()

def initialize(coordinator_address: Optional[str] = None,
               num_processes: Optional[int] = None,
               process_id: Optional[int] = None):
  """Initializes the JAX distributed system.

  Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
  multi-host GPU and Cloud TPU. :func:`~jax.distributed.initialize` must be
  called before performing any JAX computations.

  The JAX distributed system serves a number of roles:

    * it allows JAX processes to discover each other and share topology information,
    * it performs health checking, ensuring that all processes shut down if any process dies, and
    * it is used for distributed checkpointing.

  If you are using GPU, you must provide the ``coordinator_address``,
  ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.

  If you are using TPU, all arguments are optional: if omitted, they
  will be chosen automatically from the Cloud TPU metadata.

  Args:
    coordinator_address: the IP address of process ``0`` and a port on which that
      process should launch a coordinator service. The choice of
      port does not matter, so long as the port is available on the coordinator
      and all processes agree on the port.
      May be ``None`` only on TPU, in which case it will be chosen automatically.
    num_processes: Number of processes. May be ``None`` only on TPU, in
      which case it will be chosen automatically based on the TPU slice.
    process_id: The ID number of the current process. The ``process_id`` values across
      the cluster must be a dense range ``0``, ``1``, ..., ``num_processes - 1``.
      May be ``None`` only on TPU; if ``None`` it will be chosen from the TPU slice
      metadata.

  Raises:
    RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.

  Example:

  Suppose there are two GPU processes, and process 0 is the designated coordinator
  with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the
  following commands before anything else.

  On process 0:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  # doctest: +SKIP

  On process 1:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)  # doctest: +SKIP
  """
  global_state.initialize(coordinator_address, num_processes, process_id)
  atexit.register(shutdown)