Mirror of https://github.com/ROCm/jax.git
Add Open MPI automatic distributed initialization
This commit is contained in: parent ced7332587, commit b86030d86f
@@ -78,11 +78,12 @@ jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                            process_id=0)
 ```

-On Cloud TPU and Slurm environments, you can simply call {func}`jax.distributed.initialize()` with no
+On Cloud TPU, Slurm and Open MPI environments, you can simply call {func}`jax.distributed.initialize()` with no
 arguments. Default values for the arguments will be chosen automatically.
-When running on GPUs with Slurm, it is assumed that one process is started per GPU, i.e. each process will
+When running on GPUs with Slurm or Open MPI, it is assumed that one process is started per GPU, i.e. each process will
 be assigned only one visible local device. Otherwise it is assumed that one process is started per host,
 i.e. each process will be assigned all local devices.
+The Open MPI auto-initialization is only used when the JAX processes are launched via `mpirun`/`mpiexec`.

 ```python
 import jax
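For context, here is a minimal sketch of what the documented Open MPI path looks like end to end. The script name, the process count, and the launch command in the comment are illustrative, not part of this change:

```python
# Illustrative only: save as run_distributed.py (hypothetical name) and launch
# with something like `mpirun -n 4 python run_distributed.py`.
import jax

# No arguments needed under mpirun/mpiexec: the coordinator address, process
# count, process id and local device are derived from the OMPI_* variables.
jax.distributed.initialize()

if jax.process_index() == 0:
  print(jax.local_device_count(), jax.device_count())
```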
@@ -20,5 +20,6 @@ from .cluster import ClusterEnv
 # the user did not explicitly provide the arguments
 # to :func:`jax.distributed.initialize`, the first
 # available one from the list will be picked.
+from .ompi_cluster import OmpiCluster
 from .slurm_cluster import SlurmCluster
 from .cloud_tpu_cluster import TpuCluster
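The comment above refers to first-match selection among the registered cluster environments. As a rough sketch (a hypothetical helper, not the actual selection code inside jax._src.clusters), the idea is:

```python
# Hypothetical illustration of first-match detection; the real selection logic
# lives inside jax._src.clusters and is not reproduced here.
from jax._src.clusters import OmpiCluster, SlurmCluster, TpuCluster

def pick_cluster_env():
  # The first environment that reports itself present wins, so the import
  # order above determines detection priority.
  for env in (OmpiCluster, SlurmCluster, TpuCluster):
    if env.is_env_present():
      return env
  return None
```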
jax/_src/clusters/ompi_cluster.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from typing import Optional
+from jax._src.clusters import ClusterEnv
+
+# OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec.
+_ORTE_URI = 'OMPI_MCA_orte_hnp_uri'
+_PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
+_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
+_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
+
+class OmpiCluster(ClusterEnv):
+  @classmethod
+  def is_env_present(cls) -> bool:
+    return _ORTE_URI in os.environ
+
+  @classmethod
+  def get_coordinator_address(cls) -> str:
+    # Examples of orte_uri:
+    # 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
+    # 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
+    orte_uri = os.environ[_ORTE_URI]
+    job_id_str = orte_uri.split('.', maxsplit=1)[0]
+    # The job id is always a multiple of 2^12, so divide it by 2^12
+    # to reduce the likelihood of port conflicts between jobs.
+    job_id = int(job_id_str) // 2**12
+    # Pick a port in the ephemeral range [(65535 - 2^12 + 1), 65535].
+    port = job_id % 2**12 + (65535 - 2**12 + 1)
+    launcher_ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
+    if launcher_ip_match is None:
+      raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.')
+    launcher_ip = next(i for i in launcher_ip_match.groups() if i is not None)
+    return f'{launcher_ip}:{port}'
+
+  @classmethod
+  def get_process_count(cls) -> int:
+    return int(os.environ[_PROCESS_COUNT])
+
+  @classmethod
+  def get_process_id(cls) -> int:
+    return int(os.environ[_PROCESS_ID])
+
+  @classmethod
+  def get_local_process_id(cls) -> Optional[int]:
+    return int(os.environ[_LOCAL_PROCESS_ID])
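To make the address derivation above concrete, the snippet below (an editor's sketch, not part of the change) restates the same parsing and port arithmetic as a standalone function and applies it to the first example orte_uri from the comments:

```python
import re

# Standalone restatement of OmpiCluster.get_coordinator_address, for illustration only.
def coordinator_address_from_orte_uri(orte_uri: str) -> str:
  # The job id prefix is a multiple of 2^12; fold it into the ephemeral
  # port range [61440, 65535] to reduce port clashes between jobs.
  job_id = int(orte_uri.split('.', maxsplit=1)[0]) // 2**12
  port = job_id % 2**12 + (65535 - 2**12 + 1)
  # Take the first address of the launcher, IPv4 or bracketed IPv6.
  match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
  if match is None:
    raise RuntimeError('Could not parse coordinator IP address.')
  ip = next(g for g in match.groups() if g is not None)
  return f'{ip}:{port}'

print(coordinator_address_from_orte_uri(
    '1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911'))
# Prints: 10.96.0.1:62624
# (1531576320 // 2**12 = 373920; 373920 % 2**12 = 1184; 1184 + 61440 = 62624)
```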
@@ -119,7 +119,7 @@ def initialize(coordinator_address: Optional[str] = None,
   * it performs health checking, ensuring that all processes shut down if any process dies, and
   * it is used for distributed checkpointing.

-  If you are using TPU or Slurm, all arguments are optional: if omitted, they
+  If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
   will be chosen automatically.

   Otherwise, you must provide the ``coordinator_address``,
@@ -140,7 +140,7 @@ def initialize(coordinator_address: Optional[str] = None,
       May be ``None`` only on supported environments; if ``None`` it will be chosen automatically.
     local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
       If ``None``, defaults to all local devices being visible to the process except when processes
-      are launched via Slurm on GPUs. In that case, it will default to a single device per process.
+      are launched via Slurm or Open MPI on GPUs. In that case, it will default to a single device per process.

   Raises:
     RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
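For comparison with the automatic path, an explicit call spelling out the arguments this docstring describes would look roughly like the following; the address and device ids are the illustrative values already used in the docs, not values required by this change:

```python
import jax

# On TPU, Slurm, or Open MPI all of these can be omitted; shown here only to
# make the per-argument descriptions above concrete.
jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0,
                           local_device_ids=[0])
```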
@@ -14,6 +14,7 @@

 import contextlib
 import os
+import shutil
 import subprocess
 import sys
 import threading
@@ -187,6 +188,45 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
       for proc in subprocesses:
         proc.kill()

+  def test_gpu_ompi_distributed_initialize(self):
+    if jtu.device_under_test() != 'gpu':
+      raise unittest.SkipTest('Tests only for GPU.')
+    if shutil.which('mpirun') is None:
+      raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
+
+    num_gpus = 4
+    num_gpus_per_task = 1
+
+    with contextlib.ExitStack() as exit_stack:
+      args = [
+          'mpirun',
+          '--oversubscribe',
+          '--allow-run-as-root',
+          '-n',
+          str(num_gpus),
+          sys.executable,
+          '-c',
+          ('import jax, os; '
+           'jax.distributed.initialize(); '
+           'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
+          )
+      ]
+      env = os.environ.copy()
+      # In case the job was launched via Slurm,
+      # prevent Open MPI from detecting the Slurm environment.
+      env.pop('SLURM_JOBID', None)
+      proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
+                              stderr=subprocess.PIPE, universal_newlines=True)
+      proc = exit_stack.enter_context(proc)
+
+      try:
+        out, _ = proc.communicate()
+        self.assertEqual(proc.returncode, 0)
+        self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
+      finally:
+        proc.kill()
+
+
   @unittest.skipIf(
       os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
       "Slurm environment with at least two nodes needed!")
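As a quick way to see which environment variables this test (and OmpiCluster) depends on, one can run a trivial probe under mpirun; this is an editor's sketch, not part of the change:

```python
import os

# Under mpirun/mpiexec all four variables are set; outside of it,
# OMPI_MCA_orte_hnp_uri is absent and the Open MPI auto-initialization is skipped.
for name in ('OMPI_MCA_orte_hnp_uri', 'OMPI_COMM_WORLD_SIZE',
             'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK'):
  print(name, '=', os.environ.get(name, '<not set>'))
```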