Add generic interface for auto initialization of distributed JAX service

* Also add slurm cluster support
This commit is contained in:
Nicolas Castet 2022-08-26 14:23:57 -05:00
parent 640e15fe07
commit 412a5379c1
7 changed files with 297 additions and 33 deletions

View File

@ -64,7 +64,9 @@ The API {func}`jax.distributed.initialize` takes several arguments, namely:
cluster will connect.
* `num_processes`: the number of processes in the cluster
* `process_id`: the ID number of this process, in the range `[0 ..
num_processes)`.
num_processes)`.
* `local_device_ids`: Restricts the visible devices of the current process to
``local_device_ids``.
For example on GPU, a typical usage is:
@ -76,9 +78,11 @@ jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
process_id=0)
```
On Cloud TPU, you can simply call {func}`jax.distributed.initialize()` with no
arguments. Default values for the arguments will be chosen automatically using
the TPU pod metadata:
On Cloud TPU and Slurm environments, you can simply call {func}`jax.distributed.initialize()` with no
arguments. Default values for the arguments will be chosen automatically.
When running on GPUs with Slurm, it is assumed that one process is started per GPU, i.e. each process will
be assigned only one visible local device. Otherwise it is assumed that one process is started per host,
i.e. each process will be assigned all local devices.
```python
import jax

View File

@ -0,0 +1,24 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cluster import ClusterEnv
# Order of declaration of the cluster environments
# will dictate the order in which they will be checked.
# Therefore, if multiple environments are available and
# the user did not explicitly provide the arguments
# to :func:`jax.distributed.initialize`, the first
# available one from the list will be picked.
from .slurm_cluster import SlurmCluster
from .cloud_tpu_cluster import TpuCluster

View File

@ -0,0 +1,47 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from jax._src.clusters import ClusterEnv
from jax._src.lib import xla_bridge
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm, get_metadata
class TpuCluster(ClusterEnv):
@classmethod
def is_env_present(cls) -> bool:
return running_in_cloud_tpu_vm
@classmethod
def get_coordinator_address(cls) -> str:
return cls._get_worker_endpoints()[0].split(':')[2] + ':8476'
@classmethod
def get_process_count(cls) -> int:
return xla_bridge.process_count()
@classmethod
def get_process_id(cls) -> int:
if cls.get_process_count() != len(cls._get_worker_endpoints()):
raise RuntimeError('Number of workers does not equal the number of '
'processes. Auto detecting process_id is not possible.'
'Please pass process_id to jax.distributed.initialize() manually.')
return int(get_metadata('agent-worker-number'))
@classmethod
def get_local_process_id(cls) -> Optional[int]:
return None
@staticmethod
def _get_worker_endpoints() -> str:
return get_metadata('worker-network-endpoints').split(',')

View File

@ -0,0 +1,109 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Type, Sequence, Tuple
from absl import logging
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
class ClusterEnv:
"""Interface for defining a cluster environment.
To enable auto bootrapping (aka :func:`jax.distributed.initialize()`),
cluster environments need to derive from :class:`ClusterEnv` and implement
:func:`is_env_present`, :func:`get_coordinator_address`,
:func:`get_process_count`, and :func:`get_process_id`.
:class:`ClusterEnv` subclasses are automatically detected when imported.
"""
_cluster_types: List[Type['ClusterEnv']] = []
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._cluster_types.append(cls)
@classmethod
# pytype: disable=bad-return-type
def auto_detect_unset_distributed_params(cls,
coordinator_address: Optional[str],
num_processes: Optional[int],
process_id: Optional[int],
local_device_ids: Optional[Sequence[int]]
) -> Tuple[Optional[str], Optional[int], Optional[int],
Optional[Sequence[int]]]:
if all(p is not None for p in (coordinator_address, num_processes,
process_id, local_device_ids)):
return (coordinator_address, num_processes, process_id,
local_device_ids)
env = next((env for env in cls._cluster_types if env.is_env_present()), None)
if env:
logging.vlog(1, 'Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None:
coordinator_address = env.get_coordinator_address()
if num_processes is None:
num_processes = env.get_process_count()
if process_id is None:
process_id = env.get_process_id()
# Never automatically set local_device_ids on TPUs
# Defaults to single process per device if local_process_id is available.
# This only runs if we're in a managed distributed environment.
# Otherwise local_device_ids will remain unset,
# which will default to all devices being visible.
if (local_device_ids is None and not running_in_cloud_tpu_vm and
env.get_local_process_id() is not None):
local_device_ids = [env.get_local_process_id()] # type: ignore[list-item]
else:
logging.vlog(1, 'Could not find a known environment for initializing distributed JAX. '
'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
return (coordinator_address, num_processes, process_id, local_device_ids)
# pytype: enable=bad-return-type
@classmethod
def is_env_present(cls) -> bool:
"""Returns True if process is running in this cluster environment.
"""
raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")
@classmethod
def get_coordinator_address(cls) -> str:
"""Returns address and port used by JAX to bootstrap.
Process id 0 will open a tcp socket at "hostname:port" where
all the proccesses will connect to initialize the distributed JAX service.
The selected port needs to be free.
:func:`get_coordinator_address` needs to return the same hostname and port on all the processes.
Returns:
"hostname:port"
"""
raise NotImplementedError("ClusterEnv subclasses must implement get_coordinator_address")
@classmethod
def get_process_count(cls) -> int:
raise NotImplementedError("ClusterEnv subclasses must implement get_process_count")
@classmethod
def get_process_id(cls) -> int:
raise NotImplementedError("ClusterEnv subclasses must implement get_process_id")
@classmethod
def get_local_process_id(cls) -> Optional[int]:
""" Get index of current process inside a host.
The method is only useful to support single device per process.
In that case, each process will see a local device whose ID is
the same as its local process ID.
If None, JAX will not restrict the visible devices.
"""
return None

View File

@ -0,0 +1,62 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
from jax._src.clusters import ClusterEnv
_JOBID_PARAM = 'SLURM_JOB_ID'
_NODE_LIST = 'SLURM_STEP_NODELIST'
_PROCESS_COUNT = 'SLURM_NTASKS'
_PROCESS_ID = 'SLURM_PROCID'
_LOCAL_PROCESS_ID = 'SLURM_LOCALID'
_NUM_NODES = 'SLURM_STEP_NUM_NODES'
class SlurmCluster(ClusterEnv):
@classmethod
def is_env_present(cls) -> bool:
return _JOBID_PARAM in os.environ
@classmethod
def get_coordinator_address(cls) -> str:
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1)
# Parse the first hostname of the job
# If we are looking for 'node001',
# node_list potential formats are 'node001', 'node001,host2',
# 'node[001-0015],host2', and 'node[001,007-015],host2'.
node_list = os.environ[_NODE_LIST]
delims = {',', '['}
ind = next((i for i, ch in enumerate(node_list) if ch in delims), len(node_list))
if ind == len(node_list) or node_list[ind] == ',': # Formats: 'node001' or 'node001,host2'
return f'{node_list[:ind]}:{port}'
else: # Formats: 'node[001-0015],host2' or 'node[001,007-015],host2'
prefix = node_list[:ind]
suffix = node_list[ind+1:]
delims2 = {',', '-'}
ind2 = next((i for i, ch in enumerate(suffix) if ch in delims2), None)
return f'{prefix}{suffix[:ind2]}:{port}'
@classmethod
def get_process_count(cls) -> int:
return int(os.environ[_PROCESS_COUNT])
@classmethod
def get_process_id(cls) -> int:
return int(os.environ[_PROCESS_ID])
@classmethod
def get_local_process_id(cls) -> Optional[int]:
return int(os.environ[_LOCAL_PROCESS_ID])

View File

@ -16,14 +16,14 @@ import atexit
import os
import functools
from typing import Any, Optional
from typing import Any, Optional, Union, Sequence
from absl import logging
from jax._src import cloud_tpu_init
from jax._src.clusters import ClusterEnv
from jax._src.config import config
from jax._src.lib import xla_bridge
from jax._src.lib import xla_extension
class State:
process_id: int = 0
service: Optional[Any] = None
@ -33,24 +33,18 @@ class State:
def initialize(self,
coordinator_address: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None):
process_id: Optional[int] = None,
local_device_ids: Optional[Union[int, Sequence[int]]] = None):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):
local_device_ids = [local_device_ids]
if cloud_tpu_init.running_in_cloud_tpu_vm:
worker_endpoints = cloud_tpu_init.get_metadata(
'worker-network-endpoints').split(',')
if coordinator_address is None:
coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
if num_processes is None:
num_processes = xla_bridge.process_count()
if process_id is None:
process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))
if num_processes != len(worker_endpoints):
raise RuntimeError('Number of workers does not equal the number of '
'processes. Auto detecting process_id is not possible.'
'Please pass process_id manually.')
(coordinator_address,
num_processes,
process_id,
local_device_ids) = ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address, num_processes, process_id, local_device_ids)
if coordinator_address is None:
raise ValueError('coordinator_address should be defined.')
@ -59,6 +53,12 @@ class State:
if process_id is None:
raise ValueError('The process id of the current process must be defined.')
if local_device_ids:
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
logging.info('JAX distributed initialized with visible devices: %s', visible_devices)
config.update("jax_cuda_visible_devices", visible_devices)
config.update("jax_rocm_visible_devices", visible_devices)
self.process_id = process_id
if process_id == 0:
@ -104,7 +104,8 @@ global_state = State()
def initialize(coordinator_address: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None):
process_id: Optional[int] = None,
local_device_ids: Optional[Union[int, Sequence[int]]] = None):
"""Initializes the JAX distributed system.
Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
@ -117,24 +118,26 @@ def initialize(coordinator_address: Optional[str] = None,
* it performs health checking, ensuring that all processes shut down if any process dies, and
* it is used for distributed checkpointing.
If you are using GPU, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
If you are using TPU or Slurm, all arguments are optional: if omitted, they
will be chosen automatically.
If you are using TPU, all arguments are optional: if omitted, they
will be chosen automatically from the Cloud TPU metadata.
Otherwise, you must provide the ``coordinator_address``,
``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`.
Args:
coordinator_address: the IP address of process `0` and a port on which that
process should launch a coordinator service. The choice of
port does not matter, so long as the port is available on the coordinator
and all processes agree on the port.
May be ``None`` only on TPU, in which case it will be chosen automatically.
num_processes: Number of processes. May be ``None`` only on TPU, in
which case it will be chosen automatically based on the TPU slice.
May be ``None`` only on supported environments, in which case it will be chosen automatically.
num_processes: Number of processes. May be ``None`` only on supported environments, in
which case it will be chosen automatically.
process_id: The ID number of the current process. The ``process_id`` values across
the cluster must be a dense range ``0``, ``1``, ..., ``num_processes - 1``.
May be ``None`` only on TPU; if ``None`` it will be chosen from the TPU slice
metadata.
May be ``None`` only on supported environments; if ``None`` it will be chosen automatically.
local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
If ``None``, defaults to all local devices being visible to the process except when processes
are launched via Slurm on GPUs. In that case, it will default to a single device per process.
Raises:
RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
@ -153,7 +156,7 @@ def initialize(coordinator_address: Optional[str] = None,
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1) # doctest: +SKIP
"""
global_state.initialize(coordinator_address, num_processes, process_id)
global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids)
atexit.register(shutdown)

View File

@ -209,5 +209,20 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
self.assertEqual(y[0], jax.device_count())
print(y)
def test_gpu_multi_node_transparent_initialize_and_psum(self):
jax.distributed.initialize()
print(f"Total devices: {jax.device_count()}, "
f"Devices per task: {jax.local_device_count()}")
self.assertEqual(jax.device_count(), int(os.environ['SLURM_NTASKS']))
self.assertEqual(jax.local_device_count(), 1)
x = jnp.ones(jax.local_device_count())
y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
self.assertEqual(y[0], jax.device_count())
print(y)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())