From b86030d86fc461e4323f933ad5081043cff3a994 Mon Sep 17 00:00:00 2001
From: Nicolas Castet
Date: Mon, 9 Jan 2023 11:42:20 -0600
Subject: [PATCH] Add Open MPI automatic distributed initialization

---
 docs/multi_process.md             |  5 +--
 jax/_src/clusters/__init__.py     |  1 +
 jax/_src/clusters/ompi_cluster.py | 59 +++++++++++++++++++++++++++++++
 jax/_src/distributed.py           |  4 +--
 tests/multiprocess_gpu_test.py    | 40 +++++++++++++++++++++
 5 files changed, 105 insertions(+), 4 deletions(-)
 create mode 100644 jax/_src/clusters/ompi_cluster.py

diff --git a/docs/multi_process.md b/docs/multi_process.md
index ce896f13c..d59d39d05 100644
--- a/docs/multi_process.md
+++ b/docs/multi_process.md
@@ -78,11 +78,12 @@ jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                            process_id=0)
 ```
 
-On Cloud TPU and Slurm environments, you can simply call {func}`jax.distributed.initialize()` with no
+On Cloud TPU, Slurm, and Open MPI environments, you can simply call {func}`jax.distributed.initialize()` with no
 arguments. Default values for the arguments will be chosen automatically.
-When running on GPUs with Slurm, it is assumed that one process is started per GPU, i.e. each process will
+When running on GPUs with Slurm or Open MPI, it is assumed that one process is started per GPU, i.e. each process will
 be assigned only one visible local device. Otherwise it is assumed that one process is started per host,
 i.e. each process will be assigned all local devices.
+The Open MPI auto-initialization is used only when the JAX processes are launched via `mpirun`/`mpiexec`.
 
 ```python
 import jax
diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py
index eec249df0..c2afe9a7f 100644
--- a/jax/_src/clusters/__init__.py
+++ b/jax/_src/clusters/__init__.py
@@ -20,5 +20,6 @@ from .cluster import ClusterEnv
 # the user did not explicitly provide the arguments
 # to :func:`jax.distributed.initialize`, the first
 # available one from the list will be picked.
+from .ompi_cluster import OmpiCluster
 from .slurm_cluster import SlurmCluster
 from .cloud_tpu_cluster import TpuCluster
diff --git a/jax/_src/clusters/ompi_cluster.py b/jax/_src/clusters/ompi_cluster.py
new file mode 100644
index 000000000..76cac8153
--- /dev/null
+++ b/jax/_src/clusters/ompi_cluster.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from typing import Optional
+from jax._src.clusters import ClusterEnv
+
+# OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec
+_ORTE_URI = 'OMPI_MCA_orte_hnp_uri'
+_PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
+_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
+_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
+
+class OmpiCluster(ClusterEnv):
+  @classmethod
+  def is_env_present(cls) -> bool:
+    return _ORTE_URI in os.environ
+
+  @classmethod
+  def get_coordinator_address(cls) -> str:
+    # Examples of orte_uri:
+    # 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
+    # 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
+    orte_uri = os.environ[_ORTE_URI]
+    job_id_str = orte_uri.split('.', maxsplit=1)[0]
+    # The job id is always a multiple of 2^12; divide it out first so that
+    # job_id % 2^12 below varies per job, reducing port conflicts between jobs.
+    job_id = int(job_id_str) // 2**12
+    # Pick a port in the ephemeral range [(65535 - 2^12 + 1), 65535].
+    port = job_id % 2**12 + (65535 - 2**12 + 1)
+    launcher_ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
+    if launcher_ip_match is None:
+      raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.')
+    launcher_ip = next(i for i in launcher_ip_match.groups() if i is not None)
+    return f'{launcher_ip}:{port}'
+
+  @classmethod
+  def get_process_count(cls) -> int:
+    return int(os.environ[_PROCESS_COUNT])
+
+  @classmethod
+  def get_process_id(cls) -> int:
+    return int(os.environ[_PROCESS_ID])
+
+  @classmethod
+  def get_local_process_id(cls) -> Optional[int]:
+    return int(os.environ[_LOCAL_PROCESS_ID])
diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index 86ed43356..af0692202 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -119,7 +119,7 @@ def initialize(coordinator_address: Optional[str] = None,
   * it performs health checking, ensuring that all processes shut down if any process dies, and
   * it is used for distributed checkpointing.
 
-  If you are using TPU or Slurm, all arguments are optional: if omitted, they
+  If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
   will be chosen automatically.
 
   Otherwise, you must provide the ``coordinator_address``,
@@ -140,7 +140,7 @@ def initialize(coordinator_address: Optional[str] = None,
       May be ``None`` only on supported environments; if ``None`` it will be chosen automatically.
     local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
       If ``None``, defaults to all local devices being visible to the process except when processes
-      are launched via Slurm on GPUs. In that case, it will default to a single device per process.
+      are launched via Slurm or Open MPI on GPUs. In that case, it will default to a single device per process.
 
   Raises:
     RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
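To make the address derivation in `ompi_cluster.py` concrete, here is a standalone sketch of the same arithmetic applied to the two example `orte_uri` values from the code comments. The `coordinator_address` helper is illustrative only, not part of the patch:

```python
import re

def coordinator_address(orte_uri: str) -> str:
    # Mirrors OmpiCluster.get_coordinator_address: strip the ORTE job id off
    # the front of the URI, divide out its 2**12 factor, and map the result
    # into the top 2**12 ports of the ephemeral range [61440, 65535].
    job_id = int(orte_uri.split('.', maxsplit=1)[0]) // 2**12
    port = job_id % 2**12 + (65535 - 2**12 + 1)
    # Take the first address of the tcp:// list, or of the bracketed tcp6:// list.
    m = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
    if m is None:
        raise RuntimeError('Could not parse coordinator IP address.')
    ip = next(g for g in m.groups() if g is not None)
    return f'{ip}:{port}'

# 1531576320 // 2**12 = 373920; 373920 % 2**12 = 1184; 1184 + 61440 = 62624
print(coordinator_address('1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911'))
# -> 10.96.0.1:62624

# 1314521088 // 2**12 = 320928; 320928 % 2**12 = 1440; 1440 + 61440 = 62880
print(coordinator_address(
    '1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370'))
# -> fe80::b9b:ac5d:9cf0:b858:62880
```

Every process sees the same `OMPI_MCA_orte_hnp_uri`, so each one computes an identical coordinator address locally and no extra communication is needed to agree on it.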
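From the user's perspective, the documentation change above reduces multi-process setup under `mpirun`/`mpiexec` to a single no-argument call. A minimal sketch of such a program (the script name and process count are illustrative, not part of the patch):

```python
# train.py -- relies on the Open MPI auto-initialization added by this patch.
import jax

# With OMPI_MCA_orte_hnp_uri set by mpirun/mpiexec, this fills in
# coordinator_address, num_processes, process_id, and (on GPU) restricts
# each process to a single local device -- no arguments required.
jax.distributed.initialize()

print(f'process {jax.process_index()} of {jax.process_count()}: '
      f'{jax.local_device_count()} local device(s), '
      f'{jax.device_count()} global devices')
```

Launched as, e.g., `mpirun -n 4 python train.py` on a host with four GPUs, each process should report one local device and four global devices.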
diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index 40702fa55..fc62608c5 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -14,6 +14,7 @@
 
 import contextlib
 import os
+import shutil
 import subprocess
 import sys
 import threading
@@ -187,6 +188,45 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
       for proc in subprocesses:
         proc.kill()
 
+  def test_gpu_ompi_distributed_initialize(self):
+    if jtu.device_under_test() != 'gpu':
+      raise unittest.SkipTest('Tests only for GPU.')
+    if shutil.which('mpirun') is None:
+      raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
+
+    num_gpus = 4
+    num_gpus_per_task = 1
+
+    with contextlib.ExitStack() as exit_stack:
+      args = [
+          'mpirun',
+          '--oversubscribe',
+          '--allow-run-as-root',
+          '-n',
+          str(num_gpus),
+          sys.executable,
+          '-c',
+          ('import jax, os; '
+           'jax.distributed.initialize(); '
+           'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
+          )
+      ]
+      env = os.environ.copy()
+      # In case the job was launched via Slurm, prevent Open MPI from
+      # detecting the Slurm environment.
+      env.pop('SLURM_JOBID', None)
+      proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
+                              stderr=subprocess.PIPE, universal_newlines=True)
+      proc = exit_stack.enter_context(proc)
+
+      try:
+        out, _ = proc.communicate()
+        self.assertEqual(proc.returncode, 0)
+        self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
+      finally:
+        proc.kill()
+
+
 @unittest.skipIf(
     os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
     "Slurm environment with at least two nodes needed!")
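The integration test above requires a GPU host with `mpirun` installed. The detection logic can also be exercised at the unit level by faking the environment that `mpirun` would create; the following sketch is hypothetical and not part of the patch:

```python
import os
from jax._src.clusters.ompi_cluster import OmpiCluster

# Simulate the variables mpirun would set for rank 0 of a 4-process job,
# using the example HNP URI from the ompi_cluster.py comments.
os.environ.update({
    'OMPI_MCA_orte_hnp_uri': '1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911',
    'OMPI_COMM_WORLD_SIZE': '4',
    'OMPI_COMM_WORLD_RANK': '0',
    'OMPI_COMM_WORLD_LOCAL_RANK': '0',
})

assert OmpiCluster.is_env_present()
assert OmpiCluster.get_coordinator_address() == '10.96.0.1:62624'  # see derivation above
assert OmpiCluster.get_process_count() == 4
assert OmpiCluster.get_process_id() == 0
assert OmpiCluster.get_local_process_id() == 0
```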