mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add generic interface for auto initialization of distributed JAX service
* Also add slurm cluster support
This commit is contained in:
parent
640e15fe07
commit
412a5379c1
@ -64,7 +64,9 @@ The API {func}`jax.distributed.initialize` takes several arguments, namely:
|
||||
cluster will connect.
|
||||
* `num_processes`: the number of processes in the cluster
|
||||
* `process_id`: the ID number of this process, in the range `[0 ..
|
||||
num_processes)`.
|
||||
num_processes)`.
|
||||
* `local_device_ids`: Restricts the visible devices of the current process to
|
||||
``local_device_ids``.
|
||||
|
||||
For example on GPU, a typical usage is:
|
||||
|
||||
@ -76,9 +78,11 @@ jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
|
||||
process_id=0)
|
||||
```
|
||||
|
||||
On Cloud TPU, you can simply call {func}`jax.distributed.initialize()` with no
|
||||
arguments. Default values for the arguments will be chosen automatically using
|
||||
the TPU pod metadata:
|
||||
On Cloud TPU and Slurm environments, you can simply call {func}`jax.distributed.initialize()` with no
|
||||
arguments. Default values for the arguments will be chosen automatically.
|
||||
When running on GPUs with Slurm, it is assumed that one process is started per GPU, i.e. each process will
|
||||
be assigned only one visible local device. Otherwise it is assumed that one process is started per host,
|
||||
i.e. each process will be assigned all local devices.
|
||||
|
||||
```python
|
||||
import jax
|
||||
|
24
jax/_src/clusters/__init__.py
Normal file
24
jax/_src/clusters/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cluster import ClusterEnv
|
||||
|
||||
# Order of declaration of the cluster environments
|
||||
# will dictate the order in which they will be checked.
|
||||
# Therefore, if multiple environments are available and
|
||||
# the user did not explicitly provide the arguments
|
||||
# to :func:`jax.distributed.initialize`, the first
|
||||
# available one from the list will be picked.
|
||||
from .slurm_cluster import SlurmCluster
|
||||
from .cloud_tpu_cluster import TpuCluster
|
47
jax/_src/clusters/cloud_tpu_cluster.py
Normal file
47
jax/_src/clusters/cloud_tpu_cluster.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from jax._src.clusters import ClusterEnv
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm, get_metadata
|
||||
|
||||
class TpuCluster(ClusterEnv):
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
return running_in_cloud_tpu_vm
|
||||
|
||||
@classmethod
|
||||
def get_coordinator_address(cls) -> str:
|
||||
return cls._get_worker_endpoints()[0].split(':')[2] + ':8476'
|
||||
|
||||
@classmethod
|
||||
def get_process_count(cls) -> int:
|
||||
return xla_bridge.process_count()
|
||||
|
||||
@classmethod
|
||||
def get_process_id(cls) -> int:
|
||||
if cls.get_process_count() != len(cls._get_worker_endpoints()):
|
||||
raise RuntimeError('Number of workers does not equal the number of '
|
||||
'processes. Auto detecting process_id is not possible.'
|
||||
'Please pass process_id to jax.distributed.initialize() manually.')
|
||||
return int(get_metadata('agent-worker-number'))
|
||||
|
||||
@classmethod
|
||||
def get_local_process_id(cls) -> Optional[int]:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_worker_endpoints() -> str:
|
||||
return get_metadata('worker-network-endpoints').split(',')
|
109
jax/_src/clusters/cluster.py
Normal file
109
jax/_src/clusters/cluster.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Type, Sequence, Tuple
|
||||
from absl import logging
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
|
||||
|
||||
class ClusterEnv:
|
||||
"""Interface for defining a cluster environment.
|
||||
|
||||
To enable auto bootrapping (aka :func:`jax.distributed.initialize()`),
|
||||
cluster environments need to derive from :class:`ClusterEnv` and implement
|
||||
:func:`is_env_present`, :func:`get_coordinator_address`,
|
||||
:func:`get_process_count`, and :func:`get_process_id`.
|
||||
:class:`ClusterEnv` subclasses are automatically detected when imported.
|
||||
"""
|
||||
|
||||
_cluster_types: List[Type['ClusterEnv']] = []
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls._cluster_types.append(cls)
|
||||
|
||||
@classmethod
|
||||
# pytype: disable=bad-return-type
|
||||
def auto_detect_unset_distributed_params(cls,
|
||||
coordinator_address: Optional[str],
|
||||
num_processes: Optional[int],
|
||||
process_id: Optional[int],
|
||||
local_device_ids: Optional[Sequence[int]]
|
||||
) -> Tuple[Optional[str], Optional[int], Optional[int],
|
||||
Optional[Sequence[int]]]:
|
||||
if all(p is not None for p in (coordinator_address, num_processes,
|
||||
process_id, local_device_ids)):
|
||||
return (coordinator_address, num_processes, process_id,
|
||||
local_device_ids)
|
||||
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
|
||||
if env:
|
||||
logging.vlog(1, 'Initializing distributed JAX environment via %s', env.__name__)
|
||||
if coordinator_address is None:
|
||||
coordinator_address = env.get_coordinator_address()
|
||||
if num_processes is None:
|
||||
num_processes = env.get_process_count()
|
||||
if process_id is None:
|
||||
process_id = env.get_process_id()
|
||||
# Never automatically set local_device_ids on TPUs
|
||||
# Defaults to single process per device if local_process_id is available.
|
||||
# This only runs if we're in a managed distributed environment.
|
||||
# Otherwise local_device_ids will remain unset,
|
||||
# which will default to all devices being visible.
|
||||
if (local_device_ids is None and not running_in_cloud_tpu_vm and
|
||||
env.get_local_process_id() is not None):
|
||||
local_device_ids = [env.get_local_process_id()] # type: ignore[list-item]
|
||||
else:
|
||||
logging.vlog(1, 'Could not find a known environment for initializing distributed JAX. '
|
||||
'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
|
||||
return (coordinator_address, num_processes, process_id, local_device_ids)
|
||||
# pytype: enable=bad-return-type
|
||||
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
"""Returns True if process is running in this cluster environment.
|
||||
"""
|
||||
raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")
|
||||
|
||||
@classmethod
|
||||
def get_coordinator_address(cls) -> str:
|
||||
"""Returns address and port used by JAX to bootstrap.
|
||||
|
||||
Process id 0 will open a tcp socket at "hostname:port" where
|
||||
all the proccesses will connect to initialize the distributed JAX service.
|
||||
The selected port needs to be free.
|
||||
:func:`get_coordinator_address` needs to return the same hostname and port on all the processes.
|
||||
|
||||
Returns:
|
||||
"hostname:port"
|
||||
"""
|
||||
raise NotImplementedError("ClusterEnv subclasses must implement get_coordinator_address")
|
||||
|
||||
@classmethod
|
||||
def get_process_count(cls) -> int:
|
||||
raise NotImplementedError("ClusterEnv subclasses must implement get_process_count")
|
||||
|
||||
@classmethod
|
||||
def get_process_id(cls) -> int:
|
||||
raise NotImplementedError("ClusterEnv subclasses must implement get_process_id")
|
||||
|
||||
@classmethod
|
||||
def get_local_process_id(cls) -> Optional[int]:
|
||||
""" Get index of current process inside a host.
|
||||
|
||||
The method is only useful to support single device per process.
|
||||
In that case, each process will see a local device whose ID is
|
||||
the same as its local process ID.
|
||||
If None, JAX will not restrict the visible devices.
|
||||
"""
|
||||
return None
|
62
jax/_src/clusters/slurm_cluster.py
Normal file
62
jax/_src/clusters/slurm_cluster.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from jax._src.clusters import ClusterEnv
|
||||
|
||||
_JOBID_PARAM = 'SLURM_JOB_ID'
|
||||
_NODE_LIST = 'SLURM_STEP_NODELIST'
|
||||
_PROCESS_COUNT = 'SLURM_NTASKS'
|
||||
_PROCESS_ID = 'SLURM_PROCID'
|
||||
_LOCAL_PROCESS_ID = 'SLURM_LOCALID'
|
||||
_NUM_NODES = 'SLURM_STEP_NUM_NODES'
|
||||
|
||||
class SlurmCluster(ClusterEnv):
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
return _JOBID_PARAM in os.environ
|
||||
|
||||
@classmethod
|
||||
def get_coordinator_address(cls) -> str:
|
||||
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
|
||||
port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1)
|
||||
|
||||
# Parse the first hostname of the job
|
||||
# If we are looking for 'node001',
|
||||
# node_list potential formats are 'node001', 'node001,host2',
|
||||
# 'node[001-0015],host2', and 'node[001,007-015],host2'.
|
||||
node_list = os.environ[_NODE_LIST]
|
||||
delims = {',', '['}
|
||||
ind = next((i for i, ch in enumerate(node_list) if ch in delims), len(node_list))
|
||||
if ind == len(node_list) or node_list[ind] == ',': # Formats: 'node001' or 'node001,host2'
|
||||
return f'{node_list[:ind]}:{port}'
|
||||
else: # Formats: 'node[001-0015],host2' or 'node[001,007-015],host2'
|
||||
prefix = node_list[:ind]
|
||||
suffix = node_list[ind+1:]
|
||||
delims2 = {',', '-'}
|
||||
ind2 = next((i for i, ch in enumerate(suffix) if ch in delims2), None)
|
||||
return f'{prefix}{suffix[:ind2]}:{port}'
|
||||
|
||||
@classmethod
|
||||
def get_process_count(cls) -> int:
|
||||
return int(os.environ[_PROCESS_COUNT])
|
||||
|
||||
@classmethod
|
||||
def get_process_id(cls) -> int:
|
||||
return int(os.environ[_PROCESS_ID])
|
||||
|
||||
@classmethod
|
||||
def get_local_process_id(cls) -> Optional[int]:
|
||||
return int(os.environ[_LOCAL_PROCESS_ID])
|
@ -16,14 +16,14 @@ import atexit
|
||||
import os
|
||||
import functools
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union, Sequence
|
||||
|
||||
from absl import logging
|
||||
from jax._src import cloud_tpu_init
|
||||
from jax._src.clusters import ClusterEnv
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
|
||||
class State:
|
||||
process_id: int = 0
|
||||
service: Optional[Any] = None
|
||||
@ -33,24 +33,18 @@ class State:
|
||||
def initialize(self,
|
||||
coordinator_address: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
process_id: Optional[int] = None):
|
||||
process_id: Optional[int] = None,
|
||||
local_device_ids: Optional[Union[int, Sequence[int]]] = None):
|
||||
coordinator_address = (coordinator_address or
|
||||
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
|
||||
if isinstance(local_device_ids, int):
|
||||
local_device_ids = [local_device_ids]
|
||||
|
||||
if cloud_tpu_init.running_in_cloud_tpu_vm:
|
||||
worker_endpoints = cloud_tpu_init.get_metadata(
|
||||
'worker-network-endpoints').split(',')
|
||||
if coordinator_address is None:
|
||||
coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
|
||||
if num_processes is None:
|
||||
num_processes = xla_bridge.process_count()
|
||||
if process_id is None:
|
||||
process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))
|
||||
|
||||
if num_processes != len(worker_endpoints):
|
||||
raise RuntimeError('Number of workers does not equal the number of '
|
||||
'processes. Auto detecting process_id is not possible.'
|
||||
'Please pass process_id manually.')
|
||||
(coordinator_address,
|
||||
num_processes,
|
||||
process_id,
|
||||
local_device_ids) = ClusterEnv.auto_detect_unset_distributed_params(
|
||||
coordinator_address, num_processes, process_id, local_device_ids)
|
||||
|
||||
if coordinator_address is None:
|
||||
raise ValueError('coordinator_address should be defined.')
|
||||
@ -59,6 +53,12 @@ class State:
|
||||
if process_id is None:
|
||||
raise ValueError('The process id of the current process must be defined.')
|
||||
|
||||
if local_device_ids:
|
||||
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
|
||||
logging.info('JAX distributed initialized with visible devices: %s', visible_devices)
|
||||
config.update("jax_cuda_visible_devices", visible_devices)
|
||||
config.update("jax_rocm_visible_devices", visible_devices)
|
||||
|
||||
self.process_id = process_id
|
||||
|
||||
if process_id == 0:
|
||||
@ -104,7 +104,8 @@ global_state = State()
|
||||
|
||||
def initialize(coordinator_address: Optional[str] = None,
|
||||
num_processes: Optional[int] = None,
|
||||
process_id: Optional[int] = None):
|
||||
process_id: Optional[int] = None,
|
||||
local_device_ids: Optional[Union[int, Sequence[int]]] = None):
|
||||
"""Initializes the JAX distributed system.
|
||||
|
||||
Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
|
||||
@ -117,24 +118,26 @@ def initialize(coordinator_address: Optional[str] = None,
|
||||
* it performs health checking, ensuring that all processes shut down if any process dies, and
|
||||
* it is used for distributed checkpointing.
|
||||
|
||||
If you are using GPU, you must provide the ``coordinator_address``,
|
||||
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
|
||||
If you are using TPU or Slurm, all arguments are optional: if omitted, they
|
||||
will be chosen automatically.
|
||||
|
||||
If you are using TPU, all arguments are optional: if omitted, they
|
||||
will be chosen automatically from the Cloud TPU metadata.
|
||||
Otherwise, you must provide the ``coordinator_address``,
|
||||
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
|
||||
|
||||
Args:
|
||||
coordinator_address: the IP address of process `0` and a port on which that
|
||||
process should launch a coordinator service. The choice of
|
||||
port does not matter, so long as the port is available on the coordinator
|
||||
and all processes agree on the port.
|
||||
May be ``None`` only on TPU, in which case it will be chosen automatically.
|
||||
num_processes: Number of processes. May be ``None`` only on TPU, in
|
||||
which case it will be chosen automatically based on the TPU slice.
|
||||
May be ``None`` only on supported environments, in which case it will be chosen automatically.
|
||||
num_processes: Number of processes. May be ``None`` only on supported environments, in
|
||||
which case it will be chosen automatically.
|
||||
process_id: The ID number of the current process. The ``process_id`` values across
|
||||
the cluster must be a dense range ``0``, ``1``, ..., ``num_processes - 1``.
|
||||
May be ``None`` only on TPU; if ``None`` it will be chosen from the TPU slice
|
||||
metadata.
|
||||
May be ``None`` only on supported environments; if ``None`` it will be chosen automatically.
|
||||
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
|
||||
If ``None``, defaults to all local devices being visible to the process except when processes
|
||||
are launched via Slurm on GPUs. In that case, it will default to a single device per process.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
|
||||
@ -153,7 +156,7 @@ def initialize(coordinator_address: Optional[str] = None,
|
||||
|
||||
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1) # doctest: +SKIP
|
||||
"""
|
||||
global_state.initialize(coordinator_address, num_processes, process_id)
|
||||
global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids)
|
||||
atexit.register(shutdown)
|
||||
|
||||
|
||||
|
@ -209,5 +209,20 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
self.assertEqual(y[0], jax.device_count())
|
||||
print(y)
|
||||
|
||||
def test_gpu_multi_node_transparent_initialize_and_psum(self):
|
||||
|
||||
jax.distributed.initialize()
|
||||
|
||||
print(f"Total devices: {jax.device_count()}, "
|
||||
f"Devices per task: {jax.local_device_count()}")
|
||||
|
||||
self.assertEqual(jax.device_count(), int(os.environ['SLURM_NTASKS']))
|
||||
self.assertEqual(jax.local_device_count(), 1)
|
||||
|
||||
x = jnp.ones(jax.local_device_count())
|
||||
y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
|
||||
self.assertEqual(y[0], jax.device_count())
|
||||
print(y)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user