Merge pull request #20174 from coreyjadams:main

PiperOrigin-RevId: 650334673
This commit is contained in:
jax authors 2024-07-08 12:19:18 -07:00
commit 0d57c72644
9 changed files with 208 additions and 2 deletions

View File

@ -942,6 +942,7 @@ pytype_strict_library(
"_src/clusters/__init__.py", "_src/clusters/__init__.py",
"_src/clusters/cloud_tpu_cluster.py", "_src/clusters/cloud_tpu_cluster.py",
"_src/clusters/cluster.py", "_src/clusters/cluster.py",
"_src/clusters/mpi4py_cluster.py",
"_src/clusters/ompi_cluster.py", "_src/clusters/ompi_cluster.py",
"_src/clusters/slurm_cluster.py", "_src/clusters/slurm_cluster.py",
"_src/distributed.py", "_src/distributed.py",

View File

@ -22,5 +22,6 @@ from .cluster import ClusterEnv
# available one from the list will be picked. # available one from the list will be picked.
from .ompi_cluster import OmpiCluster from .ompi_cluster import OmpiCluster
from .slurm_cluster import SlurmCluster from .slurm_cluster import SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster from .cloud_tpu_cluster import GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster from .cloud_tpu_cluster import GceTpuCluster

View File

@ -74,6 +74,9 @@ def has_megascale_address():
return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
class BaseTpuCluster(clusters.ClusterEnv): class BaseTpuCluster(clusters.ClusterEnv):
name: str = "tpu"
"""Abstract cluster supports both single and multislice TPU environments. """Abstract cluster supports both single and multislice TPU environments.
If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology. If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
@ -169,6 +172,9 @@ class BaseTpuCluster(clusters.ClusterEnv):
raise NotImplementedError() raise NotImplementedError()
class GceTpuCluster(BaseTpuCluster): class GceTpuCluster(BaseTpuCluster):
name: str = "gcetpu"
@classmethod @classmethod
def is_env_present(cls) -> bool: def is_env_present(cls) -> bool:
if not running_in_cloud_tpu_vm: if not running_in_cloud_tpu_vm:
@ -194,6 +200,9 @@ class GceTpuCluster(BaseTpuCluster):
return [worker.split(':')[2] for worker in workers] return [worker.split(':')[2] for worker in workers]
class GkeTpuCluster(BaseTpuCluster): class GkeTpuCluster(BaseTpuCluster):
name: str = "gketpu"
@classmethod @classmethod
def is_env_present(cls) -> bool: def is_env_present(cls) -> bool:
if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None: if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:

View File

@ -31,11 +31,13 @@ class ClusterEnv:
""" """
_cluster_types: list[type[ClusterEnv]] = [] _cluster_types: list[type[ClusterEnv]] = []
opt_in_only_method: bool = False # Override this in derived classes if necessary
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
cls._cluster_types.append(cls) cls._cluster_types.append(cls)
@classmethod @classmethod
# pytype: disable=bad-return-type # pytype: disable=bad-return-type
def auto_detect_unset_distributed_params(cls, def auto_detect_unset_distributed_params(cls,
@ -43,14 +45,33 @@ class ClusterEnv:
num_processes: int | None, num_processes: int | None,
process_id: int | None, process_id: int | None,
local_device_ids: Sequence[int] | None, local_device_ids: Sequence[int] | None,
cluster_detection_method: str | None,
initialization_timeout: int | None, initialization_timeout: int | None,
) -> tuple[str | None, int | None, int | None, ) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]: Sequence[int] | None]:
if all(p is not None for p in (coordinator_address, num_processes, if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)): process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id, return (coordinator_address, num_processes, process_id,
local_device_ids) local_device_ids)
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
# First, we check the requested cluster detection method because it will ignore submitted values
# if it succeeds.
if cluster_detection_method is not None:
env = next( (env for env in cls._cluster_types if env.name == cluster_detection_method), None ) # pytype: disable=attribute-error
if env is None:
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {cluster_detection_method} is not supported.")
elif not env.is_env_present():
logger.error(f"Automatic Distributed initialization can not proceed:"
f" {cluster_detection_method} is supported but not functional in this environment.")
else:
env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None)
# Above: I have wrapped the env selection in a conditional to go through
# opt-in methods first (currently only mpi4py) but to check all possible options
# otherwise. Passing no cluster_detection_method results in the default, original behavior.
if env: if env:
logger.debug('Initializing distributed JAX environment via %s', env.__name__) logger.debug('Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None: if coordinator_address is None:

View File

@ -0,0 +1,93 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from jax._src import clusters
import socket
from importlib.util import find_spec
class Mpi4pyCluster(clusters.ClusterEnv):
name: str = "mpi4py"
opt_in_only_method: bool = True
@classmethod
def is_env_present(cls) -> bool:
# Relies on mpi4py:
return find_spec("mpi4py") is not None
@classmethod
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
# Using mpi4py, figure out rank 0 and it's hostname.
# Then broadcast the hostname and port.
from mpi4py import MPI #type: ignore
# Get the global communicator:
COMM_WORLD = MPI.COMM_WORLD
# On rank 0, get the hostname:
if COMM_WORLD.Get_rank() == 0:
# Order all the hostnames, and find unique ones
hostname = socket.gethostname()
# Apparently, we want to pick a port in an ephemeral range...
port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)
hostname = f'{hostname}:{port_id}'
else:
hostname = "None"
# Broadcast the host_ip to all ranks:
hostname = COMM_WORLD.bcast(hostname, root=0)
return hostname
@classmethod
def get_process_count(cls) -> int:
from mpi4py import MPI # pytype: disable=import-error
return int(MPI.COMM_WORLD.Get_size())
@classmethod
def get_process_id(cls) -> int:
from mpi4py import MPI # pytype: disable=import-error
return int(MPI.COMM_WORLD.Get_rank())
@classmethod
def get_local_process_id(cls) -> int | None:
# Using mpi4py, split the global communicator into sub communicators
# based on hostname. mpi will assign them ranks and that will allow
# a selection of the local process ID.
from mpi4py import MPI # pytype: disable=import-error
COMM_WORLD = MPI.COMM_WORLD
# This is the alternative method that is simpler:
new_comm = COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
# The rank in the new communicator - which is host-local only - IS the local rank:
return int(new_comm.Get_rank())

View File

@ -25,6 +25,9 @@ _PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK' _LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
class OmpiCluster(clusters.ClusterEnv): class OmpiCluster(clusters.ClusterEnv):
name: str = "ompi"
@classmethod @classmethod
def is_env_present(cls) -> bool: def is_env_present(cls) -> bool:
return _ORTE_URI in os.environ return _ORTE_URI in os.environ

View File

@ -25,6 +25,9 @@ _LOCAL_PROCESS_ID = 'SLURM_LOCALID'
_NUM_NODES = 'SLURM_STEP_NUM_NODES' _NUM_NODES = 'SLURM_STEP_NUM_NODES'
class SlurmCluster(clusters.ClusterEnv): class SlurmCluster(clusters.ClusterEnv):
name: str = "slurm"
@classmethod @classmethod
def is_env_present(cls) -> bool: def is_env_present(cls) -> bool:
return _JOBID_PARAM in os.environ return _JOBID_PARAM in os.environ

View File

@ -41,6 +41,7 @@ class State:
num_processes: int | None = None, num_processes: int | None = None,
process_id: int | None = None, process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None, local_device_ids: int | Sequence[int] | None = None,
cluster_detection_method: str | None = None,
initialization_timeout: int = 300, initialization_timeout: int = 300,
coordinator_bind_address: str | None = None): coordinator_bind_address: str | None = None):
coordinator_address = (coordinator_address or coordinator_address = (coordinator_address or
@ -48,12 +49,14 @@ class State:
if isinstance(local_device_ids, int): if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids] local_device_ids = [local_device_ids]
(coordinator_address, num_processes, process_id, local_device_ids) = ( (coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params( clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address, coordinator_address,
num_processes, num_processes,
process_id, process_id,
local_device_ids, local_device_ids,
cluster_detection_method,
initialization_timeout, initialization_timeout,
) )
) )
@ -84,6 +87,18 @@ class State:
self.process_id = process_id self.process_id = process_id
# Emit a warning about PROXY variables if they are in the user's env:
proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()]
if len(proxy_vars) > 0:
vars = " ".join(proxy_vars) + ". "
warning = (
f'JAX detected proxy variable(s) in the environment as distributed setup: {vars}'
'On some systems, this may cause a hang of distributed.initialize and '
'you may need to unset these ENV variable(s)'
)
logger.warning(warning)
if process_id == 0: if process_id == 0:
if self.service is not None: if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.') raise RuntimeError('distributed.initialize should only be called once.')
@ -130,6 +145,7 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None, num_processes: int | None = None,
process_id: int | None = None, process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None, local_device_ids: int | Sequence[int] | None = None,
cluster_detection_method: str | None = None,
initialization_timeout: int = 300, initialization_timeout: int = 300,
coordinator_bind_address: str | None = None): coordinator_bind_address: str | None = None):
"""Initializes the JAX distributed system. """Initializes the JAX distributed system.
@ -147,9 +163,20 @@ def initialize(coordinator_address: str | None = None,
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
will be chosen automatically. will be chosen automatically.
The ``cluster_detection_method`` may be used to choose a specific method for detecting those
distributed arguments. You may pass the name of any of the automatic cluster detection methods to this
argument though it is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI
installations, if you have a functional ``mpi4py`` installed, you may pass
``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.
Otherwise, you must provide the ``coordinator_address``, Otherwise, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`. ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
Please note: on some systems, particularly HPC clusters that only access external networks
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
:func:`~jax.distributed.initialize` may time out. You may need to unset these variables
prior to application launch.
Args: Args:
coordinator_address: the IP address of process `0` and a port on which that coordinator_address: the IP address of process `0` and a port on which that
process should launch a coordinator service. The choice of process should launch a coordinator service. The choice of
@ -166,6 +193,10 @@ def initialize(coordinator_address: str | None = None,
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``. local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
If ``None``, defaults to all local devices being visible to the process except when processes If ``None``, defaults to all local devices being visible to the process except when processes
are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process. are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed
run. Note that the "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
and to launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
Legacy auto-detect options (OMPI, Slurm) remain enabled.
initialization_timeout: Time period (in seconds) for which connection will initialization_timeout: Time period (in seconds) for which connection will
be retried. If the initialization takes more than the timeout specified, be retried. If the initialization takes more than the timeout specified,
the initialization will error. Defaults to 300 secs i.e. 5 mins. the initialization will error. Defaults to 300 secs i.e. 5 mins.
@ -197,7 +228,8 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before " raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.") "any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id, global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout, coordinator_bind_address) local_device_ids, cluster_detection_method,
initialization_timeout, coordinator_bind_address)
atexit.register(shutdown) atexit.register(shutdown)

View File

@ -33,6 +33,9 @@ from jax._src import util
from jax.experimental import pjit from jax.experimental import pjit
import jax.numpy as jnp import jax.numpy as jnp
# Used to test for mpi4py installation and skip tests if not installed
import importlib.util
try: try:
import portpicker import portpicker
except ImportError: except ImportError:
@ -218,6 +221,46 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
finally: finally:
proc.kill() proc.kill()
def test_gpu_mpi4py_distributed_initialize(self):
  """Check that ``jax.distributed.initialize`` works with the mpi4py method.

  Launches ``num_gpus`` processes with ``mpirun``; each one initializes the
  distributed runtime with ``cluster_detection_method="mpi4py"``. Process 0
  prints ``<local_device_count>,<device_count>``, which is compared with the
  expected values. The test is skipped when not running on GPU, when
  ``mpirun`` is not on PATH, or when ``mpi4py`` is not installed.
  """
  if not jtu.test_device_matches(['gpu']):
    raise unittest.SkipTest('Tests only for GPU.')
  if shutil.which('mpirun') is None:
    raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
  if importlib.util.find_spec("mpi4py") is None:
    raise unittest.SkipTest('Test of mpi4py initialize only possible with mpi4py installed.')

  num_gpus = 4
  num_gpus_per_task = 1

  with contextlib.ExitStack() as exit_stack:
    args = [
        'mpirun',
        '--oversubscribe',
        '--allow-run-as-root',
        '-n',
        str(num_gpus),
        sys.executable,
        '-c',
        # The keyword must match the parameter accepted by
        # jax.distributed.initialize: ``cluster_detection_method``.
        ('import jax, os; '
         'jax.distributed.initialize(cluster_detection_method="mpi4py"); '
         'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
        )
    ]
    env = os.environ.copy()
    # In case the job was launched via Slurm,
    # prevent OpenMPI from detecting Slurm environment
    env.pop('SLURM_JOBID', None)
    proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, universal_newlines=True)
    proc = exit_stack.enter_context(proc)

    try:
      out, _ = proc.communicate()
      self.assertEqual(proc.returncode, 0)
      self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
    finally:
      proc.kill()
@unittest.skipIf( @unittest.skipIf(
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2", os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",