# rocm_jax/jax/_src/distributed.py
#
# 165 lines
# 6.3 KiB
# Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import os
import functools
from typing import Any, Optional
from absl import logging
from jax._src import cloud_tpu_init
from jax._src.config import config
from jax._src.lib import xla_bridge
from jax._src.lib import xla_extension
class State:
  """Holds this process's handles to the JAX distributed runtime.

  An instance tracks the distributed runtime service (created only by the
  process whose id is 0), the distributed runtime client, and the optional
  preemption sync manager.
  """
  # Id of the current process, as resolved by `initialize`.
  process_id: int = 0
  # Distributed runtime service; only started when process_id == 0.
  service: Optional[Any] = None
  # Distributed runtime client connected to the coordinator address.
  client: Optional[Any] = None
  # Only created when `config.jax_coordination_service` is truthy.
  preemption_sync_manager: Optional[Any] = None

  def initialize(self,
                 coordinator_address: Optional[str] = None,
                 num_processes: Optional[int] = None,
                 process_id: Optional[int] = None):
    """Starts the service (on process 0) and connects the client.

    Args:
      coordinator_address: address of the coordinator service. Falls back to
        the ``JAX_COORDINATOR_ADDRESS`` environment variable, and then, on a
        Cloud TPU VM, to a value derived from the TPU metadata.
      num_processes: total number of processes. On a Cloud TPU VM it defaults
        to ``xla_bridge.process_count()``.
      process_id: id of the current process. On a Cloud TPU VM it defaults to
        the ``agent-worker-number`` metadata entry.

    Raises:
      RuntimeError: if this state was already initialized, or if, on a Cloud
        TPU VM, ``num_processes`` differs from the number of worker endpoints.
      ValueError: if any of the three values is still undefined after the
        fallbacks above.
    """
    # Reject re-initialization up front, before any metadata lookup or state
    # mutation, so a rejected call leaves this object untouched.
    if self.client is not None:
      raise RuntimeError('distributed.initialize should only be called once.')

    coordinator_address = (coordinator_address or
                           os.environ.get('JAX_COORDINATOR_ADDRESS', None))

    if cloud_tpu_init.running_in_cloud_tpu_vm:
      worker_endpoints = cloud_tpu_init.get_metadata(
          'worker-network-endpoints').split(',')
      if coordinator_address is None:
        # Presumably each endpoint has at least three ':'-separated fields
        # with the host address in the third one -- TODO confirm format.
        coordinator_address = worker_endpoints[0].split(':')[2] + ':8476'
      if num_processes is None:
        num_processes = xla_bridge.process_count()
      if process_id is None:
        process_id = int(cloud_tpu_init.get_metadata('agent-worker-number'))

      # NOTE(review): this check applies to every Cloud TPU VM call, even when
      # process_id was passed explicitly -- confirm that is intended given the
      # wording of the error message.
      if num_processes != len(worker_endpoints):
        raise RuntimeError('Number of workers does not equal the number of '
                           'processes. Auto detecting process_id is not '
                           'possible. Please pass process_id manually.')

    if coordinator_address is None:
      raise ValueError('coordinator_address should be defined.')
    if num_processes is None:
      raise ValueError('Number of processes must be defined.')
    if process_id is None:
      raise ValueError('The process id of the current process must be defined.')

    # Checked before self.process_id is overwritten and before a new service
    # is started, so a rejected call has no side effects.
    if process_id == 0 and self.service is not None:
      raise RuntimeError('distributed.initialize should only be called once.')

    self.process_id = process_id

    if process_id == 0:
      logging.info('Starting JAX distributed service on %s', coordinator_address)
      self.service = xla_extension.get_distributed_runtime_service(
          coordinator_address, num_processes, config.jax_coordination_service)

    # Set init_timeout to 5 min to leave time for all the processes to connect
    self.client = xla_extension.get_distributed_runtime_client(
        coordinator_address, process_id, config.jax_coordination_service,
        init_timeout=300)
    logging.info('Connecting to JAX distributed service on %s', coordinator_address)
    self.client.connect()

    if config.jax_coordination_service:
      self.initialize_preemption_sync_manager()

  def shutdown(self):
    """Shuts down the client and service, if present, and clears all handles.

    Does nothing for handles that are not set, so it is safe to call when
    `initialize` was never run.
    """
    if self.client:
      self.client.shutdown()
      self.client = None
    if self.service:
      self.service.shutdown()
      self.service = None
    if self.preemption_sync_manager:
      # The manager has no explicit shutdown call here; the reference is
      # simply dropped.
      self.preemption_sync_manager = None

  def initialize_preemption_sync_manager(self):
    """Creates the preemption sync manager and initializes it with the client.

    Raises:
      RuntimeError: if a preemption sync manager was already created.
    """
    if self.preemption_sync_manager is not None:
      raise RuntimeError(
          'Preemption sync manager should only be initialized once.')
    self.preemption_sync_manager = (
        xla_extension.create_preemption_sync_manager())
    self.preemption_sync_manager.initialize(self.client)
# Module-level State instance shared by the `initialize` and `shutdown`
# functions defined below.
global_state = State()
def initialize(coordinator_address: Optional[str] = None,
               num_processes: Optional[int] = None,
               process_id: Optional[int] = None):
  """Initializes the JAX distributed system.

  Call :func:`~jax.distributed.initialize` to prepare JAX for execution on
  multi-host GPU and Cloud TPU. It must be called before any JAX computation
  is performed.

  The JAX distributed system serves a number of roles:

    * it allows JAX processes to discover each other and share topology
      information,
    * it performs health checking, ensuring that all processes shut down if
      any process dies, and
    * it is used for distributed checkpointing.

  On GPU, ``coordinator_address``, ``num_processes`` and ``process_id`` must
  all be provided (the coordinator address may alternatively be supplied via
  the ``JAX_COORDINATOR_ADDRESS`` environment variable). On TPU every argument
  is optional: anything omitted is chosen automatically from the Cloud TPU
  metadata.

  A successful call also registers :func:`~jax.distributed.shutdown` to run at
  interpreter exit.

  Args:
    coordinator_address: the IP address of process `0` and a port on which
      that process should launch a coordinator service. The choice of port
      does not matter, so long as the port is available on the coordinator
      and all processes agree on the port. May be ``None`` only on TPU, in
      which case it will be chosen automatically.
    num_processes: Number of processes. May be ``None`` only on TPU, in which
      case it will be chosen automatically based on the TPU slice.
    process_id: The ID number of the current process. The ``process_id``
      values across the cluster must be a dense range ``0``, ``1``, ...,
      ``num_processes - 1``. May be ``None`` only on TPU; if ``None`` it will
      be chosen from the TPU slice metadata.

  Raises:
    RuntimeError: If :func:`~jax.distributed.initialize` is called more than
      once.

  Example:

  Suppose there are two GPU processes, and process 0 is the designated
  coordinator with address ``10.0.0.1:1234``. To initialize the GPU cluster,
  run the following commands before anything else.

  On process 0:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  # doctest: +SKIP

  On process 1:

  >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)  # doctest: +SKIP
  """
  global_state.initialize(
      coordinator_address=coordinator_address,
      num_processes=num_processes,
      process_id=process_id,
  )
  # Only reached when initialization succeeded: tear the runtime down at exit.
  atexit.register(shutdown)
def shutdown():
  """Shuts down the distributed system.

  Delegates to the module-level ``global_state``. Does nothing if the
  distributed system is not running.
  """
  global_state.shutdown()