mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #20174 from coreyjadams:main
PiperOrigin-RevId: 650334673
This commit is contained in:
commit
0d57c72644
@ -942,6 +942,7 @@ pytype_strict_library(
|
|||||||
"_src/clusters/__init__.py",
|
"_src/clusters/__init__.py",
|
||||||
"_src/clusters/cloud_tpu_cluster.py",
|
"_src/clusters/cloud_tpu_cluster.py",
|
||||||
"_src/clusters/cluster.py",
|
"_src/clusters/cluster.py",
|
||||||
|
"_src/clusters/mpi4py_cluster.py",
|
||||||
"_src/clusters/ompi_cluster.py",
|
"_src/clusters/ompi_cluster.py",
|
||||||
"_src/clusters/slurm_cluster.py",
|
"_src/clusters/slurm_cluster.py",
|
||||||
"_src/distributed.py",
|
"_src/distributed.py",
|
||||||
|
@ -22,5 +22,6 @@ from .cluster import ClusterEnv
|
|||||||
# available one from the list will be picked.
|
# available one from the list will be picked.
|
||||||
from .ompi_cluster import OmpiCluster
|
from .ompi_cluster import OmpiCluster
|
||||||
from .slurm_cluster import SlurmCluster
|
from .slurm_cluster import SlurmCluster
|
||||||
|
from .mpi4py_cluster import Mpi4pyCluster
|
||||||
from .cloud_tpu_cluster import GkeTpuCluster
|
from .cloud_tpu_cluster import GkeTpuCluster
|
||||||
from .cloud_tpu_cluster import GceTpuCluster
|
from .cloud_tpu_cluster import GceTpuCluster
|
||||||
|
@ -74,6 +74,9 @@ def has_megascale_address():
|
|||||||
return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
|
return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
|
||||||
|
|
||||||
class BaseTpuCluster(clusters.ClusterEnv):
|
class BaseTpuCluster(clusters.ClusterEnv):
|
||||||
|
|
||||||
|
name: str = "tpu"
|
||||||
|
|
||||||
"""Abstract cluster supports both single and multislice TPU environments.
|
"""Abstract cluster supports both single and multislice TPU environments.
|
||||||
|
|
||||||
If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
|
If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
|
||||||
@ -169,6 +172,9 @@ class BaseTpuCluster(clusters.ClusterEnv):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
class GceTpuCluster(BaseTpuCluster):
|
class GceTpuCluster(BaseTpuCluster):
|
||||||
|
|
||||||
|
name: str = "gcetpu"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_env_present(cls) -> bool:
|
def is_env_present(cls) -> bool:
|
||||||
if not running_in_cloud_tpu_vm:
|
if not running_in_cloud_tpu_vm:
|
||||||
@ -194,6 +200,9 @@ class GceTpuCluster(BaseTpuCluster):
|
|||||||
return [worker.split(':')[2] for worker in workers]
|
return [worker.split(':')[2] for worker in workers]
|
||||||
|
|
||||||
class GkeTpuCluster(BaseTpuCluster):
|
class GkeTpuCluster(BaseTpuCluster):
|
||||||
|
|
||||||
|
name: str = "gketpu"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_env_present(cls) -> bool:
|
def is_env_present(cls) -> bool:
|
||||||
if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:
|
if running_in_cloud_tpu_vm and os.environ.get("TPU_WORKER_HOSTNAMES") is not None:
|
||||||
|
@ -31,11 +31,13 @@ class ClusterEnv:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_cluster_types: list[type[ClusterEnv]] = []
|
_cluster_types: list[type[ClusterEnv]] = []
|
||||||
|
opt_in_only_method: bool = False # Override this in derived classes if necessary
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
cls._cluster_types.append(cls)
|
cls._cluster_types.append(cls)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
# pytype: disable=bad-return-type
|
# pytype: disable=bad-return-type
|
||||||
def auto_detect_unset_distributed_params(cls,
|
def auto_detect_unset_distributed_params(cls,
|
||||||
@ -43,14 +45,33 @@ class ClusterEnv:
|
|||||||
num_processes: int | None,
|
num_processes: int | None,
|
||||||
process_id: int | None,
|
process_id: int | None,
|
||||||
local_device_ids: Sequence[int] | None,
|
local_device_ids: Sequence[int] | None,
|
||||||
|
cluster_detection_method: str | None,
|
||||||
initialization_timeout: int | None,
|
initialization_timeout: int | None,
|
||||||
) -> tuple[str | None, int | None, int | None,
|
) -> tuple[str | None, int | None, int | None,
|
||||||
Sequence[int] | None]:
|
Sequence[int] | None]:
|
||||||
|
|
||||||
if all(p is not None for p in (coordinator_address, num_processes,
|
if all(p is not None for p in (coordinator_address, num_processes,
|
||||||
process_id, local_device_ids)):
|
process_id, local_device_ids)):
|
||||||
return (coordinator_address, num_processes, process_id,
|
return (coordinator_address, num_processes, process_id,
|
||||||
local_device_ids)
|
local_device_ids)
|
||||||
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
|
|
||||||
|
# First, we check the spec detection method because it will ignore submitted values
|
||||||
|
# if it succeeds.
|
||||||
|
if cluster_detection_method is not None:
|
||||||
|
env = next( (env for env in cls._cluster_types if env.name == cluster_detection_method), None ) # pytype: disable=attribute-error
|
||||||
|
if env is None:
|
||||||
|
logger.error(f"Automatic Distributed initialization can not proceed:"
|
||||||
|
f" {cluster_detection_method} is not supported.")
|
||||||
|
elif not env.is_env_present():
|
||||||
|
logger.error(f"Automatic Distributed initialization can not proceed:"
|
||||||
|
f" {cluster_detection_method} is supported but not functional in this environment.")
|
||||||
|
else:
|
||||||
|
env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None)
|
||||||
|
|
||||||
|
# Above: I have wrapped the env selection in a conditional to go through
|
||||||
|
# opt-in methods first (currently only mpi4py) but to check all possible options
|
||||||
|
# otherwise. Passing no cluster_detection_method results in the default, original behavior.
|
||||||
|
|
||||||
if env:
|
if env:
|
||||||
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
|
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
|
||||||
if coordinator_address is None:
|
if coordinator_address is None:
|
||||||
|
93
jax/_src/clusters/mpi4py_cluster.py
Normal file
93
jax/_src/clusters/mpi4py_cluster.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
# Copyright 2024 The JAX Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from jax._src import clusters
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from importlib.util import find_spec
|
||||||
|
|
||||||
|
|
||||||
|
class Mpi4pyCluster(clusters.ClusterEnv):
  """Cluster environment that derives distributed parameters through mpi4py.

  Every query is answered by asking ``MPI.COMM_WORLD`` directly, so the
  application must be started by an MPI-compatible launcher. ``mpi4py`` is
  imported lazily inside each method so that importing this module never
  requires the package to be installed.
  """

  # Identifier used to request this detection method explicitly.
  name: str = "mpi4py"
  # Marks this method as one that must be requested by name rather than
  # being picked up by the default automatic detection.
  opt_in_only_method: bool = True

  @classmethod
  def is_env_present(cls) -> bool:
    """Returns True when the ``mpi4py`` package can be located for import."""
    # find_spec only locates the package; it does not initialise MPI.
    return find_spec("mpi4py") is not None

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None) -> str:
    """Returns the ``hostname:port`` string chosen by rank 0.

    Rank 0 builds the address from its own hostname and broadcasts it, so
    every rank returns the same string.

    Args:
      timeout_secs: unused by this implementation.
    """
    from mpi4py import MPI  # type: ignore

    comm = MPI.COMM_WORLD

    if comm.Get_rank() == 0:
      hostname = socket.gethostname()

      # Pick a port in the top 2**12 ports (61440-65535), derived from the
      # hostname. String hashes can differ between interpreter runs, which is
      # harmless here: only rank 0 computes the port and all other ranks
      # receive it through the broadcast below.
      port_id = hash(hostname) % 2**12 + (65535 - 2**12 + 1)

      address = f'{hostname}:{port_id}'
    else:
      # Placeholder only; replaced by rank 0's value in the broadcast.
      address = "None"

    # Broadcast rank 0's address to all ranks.
    return comm.bcast(address, root=0)

  @classmethod
  def get_process_count(cls) -> int:
    """Returns the size of ``MPI.COMM_WORLD``."""
    from mpi4py import MPI  # pytype: disable=import-error
    return int(MPI.COMM_WORLD.Get_size())

  @classmethod
  def get_process_id(cls) -> int:
    """Returns this process's rank in ``MPI.COMM_WORLD``."""
    from mpi4py import MPI  # pytype: disable=import-error
    return int(MPI.COMM_WORLD.Get_rank())

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Returns this process's rank among the processes sharing its node.

    The global communicator is split with ``MPI.COMM_TYPE_SHARED``; the rank
    inside the resulting shared-memory communicator is the local process id.
    """
    from mpi4py import MPI  # pytype: disable=import-error

    local_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
    try:
      # The rank in the host-local communicator IS the local rank.
      return int(local_comm.Get_rank())
    finally:
      # Release the temporary communicator; MPI handles are not reclaimed by
      # Python's garbage collector.
      local_comm.Free()
|
@ -25,6 +25,9 @@ _PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
|
|||||||
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
|
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
|
||||||
|
|
||||||
class OmpiCluster(clusters.ClusterEnv):
|
class OmpiCluster(clusters.ClusterEnv):
|
||||||
|
|
||||||
|
name: str = "ompi"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_env_present(cls) -> bool:
|
def is_env_present(cls) -> bool:
|
||||||
return _ORTE_URI in os.environ
|
return _ORTE_URI in os.environ
|
||||||
|
@ -25,6 +25,9 @@ _LOCAL_PROCESS_ID = 'SLURM_LOCALID'
|
|||||||
_NUM_NODES = 'SLURM_STEP_NUM_NODES'
|
_NUM_NODES = 'SLURM_STEP_NUM_NODES'
|
||||||
|
|
||||||
class SlurmCluster(clusters.ClusterEnv):
|
class SlurmCluster(clusters.ClusterEnv):
|
||||||
|
|
||||||
|
name: str = "slurm"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_env_present(cls) -> bool:
|
def is_env_present(cls) -> bool:
|
||||||
return _JOBID_PARAM in os.environ
|
return _JOBID_PARAM in os.environ
|
||||||
|
@ -41,6 +41,7 @@ class State:
|
|||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
process_id: int | None = None,
|
process_id: int | None = None,
|
||||||
local_device_ids: int | Sequence[int] | None = None,
|
local_device_ids: int | Sequence[int] | None = None,
|
||||||
|
cluster_detection_method: str | None = None,
|
||||||
initialization_timeout: int = 300,
|
initialization_timeout: int = 300,
|
||||||
coordinator_bind_address: str | None = None):
|
coordinator_bind_address: str | None = None):
|
||||||
coordinator_address = (coordinator_address or
|
coordinator_address = (coordinator_address or
|
||||||
@ -48,12 +49,14 @@ class State:
|
|||||||
if isinstance(local_device_ids, int):
|
if isinstance(local_device_ids, int):
|
||||||
local_device_ids = [local_device_ids]
|
local_device_ids = [local_device_ids]
|
||||||
|
|
||||||
|
|
||||||
(coordinator_address, num_processes, process_id, local_device_ids) = (
|
(coordinator_address, num_processes, process_id, local_device_ids) = (
|
||||||
clusters.ClusterEnv.auto_detect_unset_distributed_params(
|
clusters.ClusterEnv.auto_detect_unset_distributed_params(
|
||||||
coordinator_address,
|
coordinator_address,
|
||||||
num_processes,
|
num_processes,
|
||||||
process_id,
|
process_id,
|
||||||
local_device_ids,
|
local_device_ids,
|
||||||
|
cluster_detection_method,
|
||||||
initialization_timeout,
|
initialization_timeout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -84,6 +87,18 @@ class State:
|
|||||||
|
|
||||||
self.process_id = process_id
|
self.process_id = process_id
|
||||||
|
|
||||||
|
# Emit a warning about PROXY variables if they are in the user's env:
|
||||||
|
proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()]
|
||||||
|
|
||||||
|
if len(proxy_vars) > 0:
|
||||||
|
vars = " ".join(proxy_vars) + ". "
|
||||||
|
warning = (
|
||||||
|
f'JAX detected proxy variable(s) in the environment as distributed setup: {vars}'
|
||||||
|
'On some systems, this may cause a hang of distributed.initialize and '
|
||||||
|
'you may need to unset these ENV variable(s)'
|
||||||
|
)
|
||||||
|
logger.warning(warning)
|
||||||
|
|
||||||
if process_id == 0:
|
if process_id == 0:
|
||||||
if self.service is not None:
|
if self.service is not None:
|
||||||
raise RuntimeError('distributed.initialize should only be called once.')
|
raise RuntimeError('distributed.initialize should only be called once.')
|
||||||
@ -130,6 +145,7 @@ def initialize(coordinator_address: str | None = None,
|
|||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
process_id: int | None = None,
|
process_id: int | None = None,
|
||||||
local_device_ids: int | Sequence[int] | None = None,
|
local_device_ids: int | Sequence[int] | None = None,
|
||||||
|
cluster_detection_method: str | None = None,
|
||||||
initialization_timeout: int = 300,
|
initialization_timeout: int = 300,
|
||||||
coordinator_bind_address: str | None = None):
|
coordinator_bind_address: str | None = None):
|
||||||
"""Initializes the JAX distributed system.
|
"""Initializes the JAX distributed system.
|
||||||
@ -147,9 +163,20 @@ def initialize(coordinator_address: str | None = None,
|
|||||||
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
|
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
|
||||||
will be chosen automatically.
|
will be chosen automatically.
|
||||||
|
|
||||||
|
The ``cluster_detection_method`` may be used to choose a specific method for detecting those
|
||||||
|
distributed arguments. You may pass the name of any of the automatic detection methods to this
|
||||||
|
argument though it is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI
|
||||||
|
installations, if you have a functional ``mpi4py`` installed, you may pass
|
||||||
|
``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.
|
||||||
|
|
||||||
Otherwise, you must provide the ``coordinator_address``,
|
Otherwise, you must provide the ``coordinator_address``,
|
||||||
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
|
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
|
||||||
|
|
||||||
|
Please note: on some systems, particularly HPC clusters that only access external networks
|
||||||
|
through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
|
||||||
|
:func:`~jax.distributed.initialize` may timeout. You may need to unset these variables
|
||||||
|
prior to application launch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
coordinator_address: the IP address of process `0` and a port on which that
|
coordinator_address: the IP address of process `0` and a port on which that
|
||||||
process should launch a coordinator service. The choice of
|
process should launch a coordinator service. The choice of
|
||||||
@ -166,6 +193,10 @@ def initialize(coordinator_address: str | None = None,
|
|||||||
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
|
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
|
||||||
If ``None``, defaults to all local devices being visible to the process except when processes
|
If ``None``, defaults to all local devices being visible to the process except when processes
|
||||||
are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
|
are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
|
||||||
|
cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed
|
||||||
|
run. Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
|
||||||
|
and launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
|
||||||
|
Legacy auto-detect options (OMPI, Slurm) remain enabled.
|
||||||
initialization_timeout: Time period (in seconds) for which connection will
|
initialization_timeout: Time period (in seconds) for which connection will
|
||||||
be retried. If the initialization takes more than the timeout specified,
|
be retried. If the initialization takes more than the timeout specified,
|
||||||
the initialization will error. Defaults to 300 secs i.e. 5 mins.
|
the initialization will error. Defaults to 300 secs i.e. 5 mins.
|
||||||
@ -197,7 +228,8 @@ def initialize(coordinator_address: str | None = None,
|
|||||||
raise RuntimeError("jax.distributed.initialize() must be called before "
|
raise RuntimeError("jax.distributed.initialize() must be called before "
|
||||||
"any JAX computations are executed.")
|
"any JAX computations are executed.")
|
||||||
global_state.initialize(coordinator_address, num_processes, process_id,
|
global_state.initialize(coordinator_address, num_processes, process_id,
|
||||||
local_device_ids, initialization_timeout, coordinator_bind_address)
|
local_device_ids, cluster_detection_method,
|
||||||
|
initialization_timeout, coordinator_bind_address)
|
||||||
atexit.register(shutdown)
|
atexit.register(shutdown)
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,6 +33,9 @@ from jax._src import util
|
|||||||
from jax.experimental import pjit
|
from jax.experimental import pjit
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
# Used to test for mpi4py installation and skip tests if not installed
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import portpicker
|
import portpicker
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -218,6 +221,46 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
|
|||||||
finally:
|
finally:
|
||||||
proc.kill()
|
proc.kill()
|
||||||
|
|
||||||
|
def test_gpu_mpi4py_distributed_initialize(self):
  # Launches `num_gpus` processes under mpirun, has each one call
  # jax.distributed.initialize with the mpi4py detection method, and checks
  # the local/global device counts printed by process 0.
  if not jtu.test_device_matches(['gpu']):
    raise unittest.SkipTest('Tests only for GPU.')
  if shutil.which('mpirun') is None:
    raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
  if importlib.util.find_spec("mpi4py") is None:
    raise unittest.SkipTest('Test of mpi4py initialize only possible with mpi4py installed.')

  num_gpus = 4
  num_gpus_per_task = 1

  with contextlib.ExitStack() as exit_stack:
    args = [
        'mpirun',
        '--oversubscribe',
        '--allow-run-as-root',
        '-n',
        str(num_gpus),
        sys.executable,
        '-c',
        ('import jax, os; '
         # The keyword must match the public parameter name,
         # `cluster_detection_method`; any other name raises TypeError in
         # the child process.
         'jax.distributed.initialize(cluster_detection_method="mpi4py"); '
         'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
        )
    ]
    env = os.environ.copy()
    # In case the job was launched via Slurm,
    # prevent OpenMPI from detecting Slurm environment
    env.pop('SLURM_JOBID', None)
    proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, universal_newlines=True)
    proc = exit_stack.enter_context(proc)

    try:
      out, err = proc.communicate()
      # Surface the child's stderr so a failing launch is diagnosable.
      self.assertEqual(proc.returncode, 0, msg=err)
      self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
    finally:
      proc.kill()
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
|
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user