Add Open MPI automatic distributed initialization

Nicolas Castet 2023-01-09 11:42:20 -06:00
parent ced7332587
commit b86030d86f
5 changed files with 105 additions and 4 deletions

View File

@@ -78,11 +78,12 @@ jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
process_id=0)
```
-On Cloud TPU and Slurm environments, you can simply call {func}`jax.distributed.initialize()` with no
+On Cloud TPU, Slurm and Open MPI environments, you can simply call {func}`jax.distributed.initialize()` with no
arguments. Default values for the arguments will be chosen automatically.
-When running on GPUs with Slurm, it is assumed that one process is started per GPU, i.e. each process will
+When running on GPUs with Slurm and Open MPI, it is assumed that one process is started per GPU, i.e. each process will
be assigned only one visible local device. Otherwise it is assumed that one process is started per host,
i.e. each process will be assigned all local devices.
+The Open MPI auto-initialization is only used when the JAX processes are launched via `mpirun`/`mpiexec`.
```python
import jax
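
# --- Illustrative sketch, not part of the committed example ------------------
# Assuming the program is launched through Open MPI, e.g.
#   mpirun -n 8 python train.py
# (the script name and process count are placeholders), the no-argument call
# picks up the coordinator address, process count and rank from the OMPI_*
# variables that mpirun/mpiexec set.
jax.distributed.initialize()

# With one process started per GPU, each rank sees a single local device while
# the global device count spans all ranks.
print(jax.process_index(), jax.local_device_count(), jax.device_count())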

View File

@@ -20,5 +20,6 @@ from .cluster import ClusterEnv
# the user did not explicitly provide the arguments
# to :func:`jax.distributed.initialize`, the first
# available one from the list will be picked.
+from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster
from .cloud_tpu_cluster import TpuCluster
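
The comment above notes that when the user does not pass arguments to `jax.distributed.initialize`, the first available environment from this list is picked. Below is a minimal sketch of that detection pattern, assuming the other cluster classes implement the same `is_env_present` classmethod shown in `ompi_cluster.py`; the candidate list and helper are illustrative, not JAX's actual internals.

```python
from typing import Optional, Type

from jax._src.clusters import ClusterEnv, OmpiCluster, SlurmCluster, TpuCluster

# Candidates in the order they are imported above.
_CANDIDATES = [OmpiCluster, SlurmCluster, TpuCluster]

def detect_cluster() -> Optional[Type[ClusterEnv]]:
  # Pick the first environment whose markers are present; OmpiCluster, for
  # example, checks for OMPI_MCA_orte_hnp_uri, which only mpirun/mpiexec set.
  for env in _CANDIDATES:
    if env.is_env_present():
      return env
  return None
```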

View File

@@ -0,0 +1,59 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import Optional

from jax._src.clusters import ClusterEnv

# OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec
_ORTE_URI = 'OMPI_MCA_orte_hnp_uri'
_PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'

class OmpiCluster(ClusterEnv):
  @classmethod
  def is_env_present(cls) -> bool:
    return _ORTE_URI in os.environ

  @classmethod
  def get_coordinator_address(cls) -> str:
    # Examples of orte_uri:
    # 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
    # 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
    orte_uri = os.environ[_ORTE_URI]
    job_id_str = orte_uri.split('.', maxsplit=1)[0]
    # The jobid is always a multiple of 2^12, let's divide it by 2^12
    # to reduce likelihood of port conflict between jobs
    job_id = int(job_id_str) // 2**12
    # Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
    port = job_id % 2**12 + (65535 - 2**12 + 1)
    launcher_ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
    if launcher_ip_match is None:
      raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.')
    launcher_ip = next(i for i in launcher_ip_match.groups() if i is not None)
    return f'{launcher_ip}:{port}'

  @classmethod
  def get_process_count(cls) -> int:
    return int(os.environ[_PROCESS_COUNT])

  @classmethod
  def get_process_id(cls) -> int:
    return int(os.environ[_PROCESS_ID])

  @classmethod
  def get_local_process_id(cls) -> Optional[int]:
    return int(os.environ[_LOCAL_PROCESS_ID])
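
As a worked example of the derivation in `get_coordinator_address`, here is the same arithmetic and regular expression applied to the first sample `orte_uri` from the comments above (a standalone re-derivation, not a call into JAX):

```python
import re

orte_uri = '1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911'

job_id = int(orte_uri.split('.', maxsplit=1)[0]) // 2**12   # 1531576320 // 4096 = 373920
port = job_id % 2**12 + (65535 - 2**12 + 1)                 # 373920 % 4096 = 1184; 61440 + 1184 = 62624
match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
ip = next(g for g in match.groups() if g is not None)       # first address in the list: '10.96.0.1'

print(f'{ip}:{port}')   # -> 10.96.0.1:62624
```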

View File

@@ -119,7 +119,7 @@ def initialize(coordinator_address: Optional[str] = None,
* it performs health checking, ensuring that all processes shut down if any process dies, and
* it is used for distributed checkpointing.
-If you are using TPU or Slurm, all arguments are optional: if omitted, they
+If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
will be chosen automatically.
Otherwise, you must provide the ``coordinator_address``,
@@ -140,7 +140,7 @@ def initialize(coordinator_address: Optional[str] = None,
May be ``None`` only on supported environments; if ``None`` it will be chosen automatically.
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
If ``None``, defaults to all local devices being visible to the process except when processes
-are launched via Slurm on GPUs. In that case, it will default to a single device per process.
+are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
Raises:
RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
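
A short sketch contrasting the two call styles this docstring describes; the toggle, address, process count, and device id below are placeholders, with keyword names taken from the public `jax.distributed.initialize` signature:

```python
import jax

ON_SUPPORTED_ENV = True  # illustrative toggle, not a JAX flag

if ON_SUPPORTED_ENV:
  # TPU, Slurm or Open MPI (mpirun/mpiexec): everything is auto-detected, and
  # on GPUs each process is assigned a single visible device.
  jax.distributed.initialize()
else:
  # Elsewhere, the coordinator and process layout must be given explicitly.
  jax.distributed.initialize(coordinator_address='192.168.0.1:1234',
                             num_processes=2,
                             process_id=0,
                             local_device_ids=[0])
```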

View File

@@ -14,6 +14,7 @@
import contextlib
import os
import shutil
import subprocess
import sys
import threading
@@ -187,6 +188,45 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
        for proc in subprocesses:
          proc.kill()

  def test_gpu_ompi_distributed_initialize(self):
    if jtu.device_under_test() != 'gpu':
      raise unittest.SkipTest('Tests only for GPU.')

    if shutil.which('mpirun') is None:
      raise unittest.SkipTest('Tests only for MPI (mpirun not found).')

    num_gpus = 4
    num_gpus_per_task = 1

    with contextlib.ExitStack() as exit_stack:
      args = [
          'mpirun',
          '--oversubscribe',
          '--allow-run-as-root',
          '-n',
          str(num_gpus),
          sys.executable,
          '-c',
          ('import jax, os; '
           'jax.distributed.initialize(); '
           'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
          )
      ]
      env = os.environ.copy()
      # In case the job was launched via Slurm,
      # prevent OpenMPI from detecting Slurm environment
      env.pop('SLURM_JOBID', None)
      proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
      proc = exit_stack.enter_context(proc)

      try:
        out, _ = proc.communicate()
        self.assertEqual(proc.returncode, 0)
        self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
      finally:
        proc.kill()


@unittest.skipIf(
    os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
    "Slurm environment with at least two nodes needed!")