mirror of
https://github.com/ROCm/jax.git
synced 2025-04-22 21:06:04 +00:00

Array serialization in array_serialization.py contains a mixture of JAX specific serialization logic and tensorstore driver. This change separates JAX and tensorstore methods (a) making serialization more modular and (b) potentially allowing for alternative array serialization backends in the future. Additional clean-up changes include: - making ocdbt kvstore driver default in tensorstore - robustified array serialization tests especially on multi-host - explicit tensorstore array chunking to ensure chunk file size does not blow up PiperOrigin-RevId: 749175295
287 lines
13 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
import logging
|
|
import os
|
|
from typing import Any
|
|
|
|
from jax._src import clusters
|
|
from jax._src import config
|
|
from jax._src import xla_bridge
|
|
from jax._src.lib import xla_extension
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Module-level flag: when enabled, State.initialize() scans os.environ for
# "*_proxy" variables and logs a warning, since proxy settings can hang the
# distributed-setup connection on some systems.
_CHECK_PROXY_ENVS = config.bool_flag(
    name="jax_check_proxy_envs",
    default=True,
    help="Checks proxy vars in user envs and emit warnings.",
)
|
|
|
|
|
|
class State:
  """Process-wide state of the JAX distributed system.

  A single instance (``global_state`` in this module) holds the coordinator
  service (on process 0), the client connection to the coordinator, and the
  preemption sync manager for the current process.
  """
  # Index of this process within the cluster (0 <= process_id < num_processes).
  process_id: int = 0
  # Total number of participating processes.
  num_processes: int = 1
  # Coordinator service; only created on process 0.
  service: xla_extension.DistributedRuntimeService | Any | None = None
  # Client connection to the coordinator; created on every process.
  client: xla_extension.DistributedRuntimeClient | Any | None = None
  preemption_sync_manager: Any | None = None
  coordinator_address: str | None = None
  slice_index: int | None = None

  def initialize(self,
                 coordinator_address: str | None = None,
                 num_processes: int | None = None,
                 process_id: int | None = None,
                 local_device_ids: int | Sequence[int] | None = None,
                 cluster_detection_method: str | None = None,
                 initialization_timeout: int = 300,
                 coordinator_bind_address: str | None = None,
                 service_heartbeat_interval_seconds: int = 10,
                 service_max_missing_heartbeats: int = 10,
                 client_heartbeat_interval_seconds: int = 10,
                 client_max_missing_heartbeats: int = 10,
                 slice_index: int | None = None):
    """Starts the coordinator service (process 0 only) and connects the client.

    Unset arguments are filled from environment variables and, unless
    ``cluster_detection_method == 'deactivate'``, from cluster auto-detection.

    Raises:
      ValueError: if coordinator_address/num_processes/process_id are missing
        after auto-detection, or process_id is out of range.
      TypeError: if process_id or num_processes is not an int.
      RuntimeError: if called more than once.
    """
    coordinator_address = (coordinator_address or
                           os.environ.get('JAX_COORDINATOR_ADDRESS'))
    if isinstance(local_device_ids, int):
      local_device_ids = [local_device_ids]

    if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
      local_device_ids = list(map(int, env_ids.split(",")))

    # Auto-detect any still-unset parameters from the cluster environment,
    # unless detection is explicitly deactivated.
    if (cluster_detection_method != 'deactivate' and
        None in (coordinator_address, num_processes, process_id, local_device_ids)):
      (coordinator_address, num_processes, process_id, local_device_ids) = (
          clusters.ClusterEnv.auto_detect_unset_distributed_params(
              coordinator_address,
              num_processes,
              process_id,
              local_device_ids,
              cluster_detection_method,
              initialization_timeout,
          )
      )

    # Validate the (possibly auto-detected) parameters before any side effects.
    if coordinator_address is None:
      raise ValueError('coordinator_address should be defined.')
    if num_processes is None:
      raise ValueError('Number of processes must be defined.')
    if process_id is None:
      raise ValueError('The process id of the current process must be defined.')
    if not isinstance(process_id, int):
      raise TypeError("process_id must be a nonnegative int. "
                      f"Got process_id={process_id} of type {type(process_id)}.")
    if not isinstance(num_processes, int):
      raise TypeError("num_processes must be a positive int. "
                      f"Got num_processes={num_processes} of type {type(num_processes)}.")
    if not (0 <= process_id < num_processes):
      raise ValueError("process_id and num_processes must be nonnegative, with process_id < num_processes. "
                       f"Got process_id={process_id}, num_processes={num_processes}.")

    self.coordinator_address = coordinator_address

    # The default value of [::]:port tells the coordinator to bind to all
    # available addresses on the same port as coordinator_address.
    default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1]
    coordinator_bind_address = (coordinator_bind_address or
                                os.environ.get('JAX_COORDINATOR_BIND_ADDRESS',
                                               default_coordinator_bind_address))
    if coordinator_bind_address is None:
      raise ValueError('coordinator_bind_address should be defined.')

    if local_device_ids:
      visible_devices = ','.join(str(x) for x in local_device_ids)
      logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
      config.update("jax_cuda_visible_devices", visible_devices)
      config.update("jax_rocm_visible_devices", visible_devices)

    self.process_id = process_id

    # Proxy env vars can make the coordinator connection hang on some systems;
    # warn the user about any that are set. (Renamed from `vars`, which
    # shadowed the builtin.)
    proxy_vars = []
    if _CHECK_PROXY_ENVS.value:
      proxy_vars = [key for key in os.environ.keys()
                    if '_proxy' in key.lower()]

    if proxy_vars:
      proxy_var_names = " ".join(proxy_vars) + ". "
      warning = (
          f'JAX detected proxy variable(s) in the environment as distributed setup: {proxy_var_names}'
          'On some systems, this may cause a hang of distributed.initialize and '
          'you may need to unset these ENV variable(s)'
      )
      logger.warning(warning)

    if process_id == 0:
      if self.service is not None:
        raise RuntimeError('distributed.initialize should only be called once.')
      logger.info(
          'Starting JAX distributed service on %s', coordinator_bind_address
      )
      self.service = xla_extension.get_distributed_runtime_service(
          coordinator_bind_address, num_processes,
          heartbeat_interval=service_heartbeat_interval_seconds,
          max_missing_heartbeats=service_max_missing_heartbeats)

    self.num_processes = num_processes

    if self.client is not None:
      raise RuntimeError('distributed.initialize should only be called once.')

    self.client = xla_extension.get_distributed_runtime_client(
        coordinator_address, process_id, init_timeout=initialization_timeout,
        heartbeat_interval=client_heartbeat_interval_seconds,
        max_missing_heartbeats=client_max_missing_heartbeats, use_compression=True)
    logger.info('Connecting to JAX distributed service on %s', coordinator_address)
    self.client.connect()

    self.initialize_preemption_sync_manager()

    # Key presence is checked above, so index directly instead of
    # int(os.environ.get(...)) which needed a type-ignore for Optional.
    if slice_index is None and 'JAX_SLICE_INDEX' in os.environ:
      slice_index = int(os.environ['JAX_SLICE_INDEX'])
    self.slice_index = slice_index

  def shutdown(self):
    """Tears down the client, the service (if any), and the sync manager."""
    if self.client:
      self.client.shutdown()
      self.client = None
    if self.service:
      self.service.shutdown()
      self.service = None
    if self.preemption_sync_manager:
      self.preemption_sync_manager = None

  def initialize_preemption_sync_manager(self):
    """Creates the preemption sync manager bound to the connected client.

    Raises:
      RuntimeError: if the sync manager was already initialized.
    """
    if self.preemption_sync_manager is not None:
      raise RuntimeError(
          'Preemption sync manager should only be initialized once.')
    self.preemption_sync_manager = (
        xla_extension.create_preemption_sync_manager())
    self.preemption_sync_manager.initialize(self.client)
|
|
|
|
# Process-global singleton; the module-level initialize()/shutdown()
# functions below delegate to this instance.
global_state = State()
|
|
|
|
def initialize(coordinator_address: str | None = None,
               num_processes: int | None = None,
               process_id: int | None = None,
               local_device_ids: int | Sequence[int] | None = None,
               cluster_detection_method: str | None = None,
               initialization_timeout: int = 300,
               coordinator_bind_address: str | None = None,
               slice_index: int | None = None):
  """Initializes the JAX distributed system.

  Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
  multi-host GPU and Cloud TPU. :func:`~jax.distributed.initialize` must be
  called before performing any JAX computations.

  The JAX distributed system serves a number of roles:

    * It allows JAX processes to discover each other and share topology information,
    * It performs health checking, ensuring that all processes shut down if any process dies, and
    * It is used for distributed checkpointing.

  If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they
  will be chosen automatically.

  The ``cluster_detection_method`` may be used to choose a specific method for detecting those
  distributed arguments. You may pass any of the automatic ``spec_detect_methods`` to this
  argument though it is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI
  installations, if you have a functional ``mpi4py`` installed, you may pass
  ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments.

  Otherwise, you must provide the ``coordinator_address``,
  ``num_processes``, ``process_id``, and ``local_device_ids`` arguments
  to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster
  environment auto detection will be skipped.

  Please note: on some systems, particularly HPC clusters that only access external networks
  through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to
  :func:`~jax.distributed.initialize` may timeout. You may need to unset these variables
  prior to application launch.

  Args:
    coordinator_address: the IP address of process `0` and a port on which that
      process should launch a coordinator service. The choice of
      port does not matter, so long as the port is available on the coordinator
      and all processes agree on the port.
      May be ``None`` only on supported environments, in which case it will be chosen automatically.
      Note that special addresses like ``localhost`` or ``127.0.0.1`` usually mean that the program
      will bind to a local interface and are not suitable when running in a multi-host environment.
    num_processes: Number of processes. May be ``None`` only on supported environments, in
      which case it will be chosen automatically.
    process_id: The ID number of the current process. The ``process_id`` values across
      the cluster must be a dense range ``0``, ``1``, ..., ``num_processes - 1``.
      May be ``None`` only on supported environments; if ``None`` it will be chosen automatically.
    local_device_ids: Restricts the visible devices of the current process to ``local_device_ids``.
      If ``None``, defaults to all local devices being visible to the process except when processes
      are launched via Slurm and Open MPI on GPUs. In that case, it will default to a single device per process.
    cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed
      run. Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
      and launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
      Legacy auto-detect options "ompi" (OMPI) and "slurm" (Slurm) remain enabled. "deactivate" bypasses
      automatic cluster detection.
    initialization_timeout: Time period (in seconds) for which connection will
      be retried. If the initialization takes more than the timeout specified,
      the initialization will error. Defaults to 300 secs i.e. 5 mins.
    coordinator_bind_address: the address and port to which the coordinator service
      on process `0` should bind. If this is not specified, the default is to bind to
      all available addresses on the same port as ``coordinator_address``. On systems
      that have multiple network interfaces per node it may be insufficient to only
      have the coordinator service listen on one address/interface.
    slice_index: The slice index assigned to this process' local devices. If any process sets ``slice_index``,
      then all processes must do so. If ``None`` the slice indices will be chosen automatically.

  Raises:
    RuntimeError: If :func:`~jax.distributed.initialize` is called more than once
      or if called after the backend is already initialized.

  Examples:

    Suppose there are two GPU processes, and process 0 is the designated coordinator
    with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the
    following commands before anything else.

    On process 0:

    >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  # doctest: +SKIP

    On process 1:

    >>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)  # doctest: +SKIP
  """
  if xla_bridge.backends_are_initialized():
    raise RuntimeError("jax.distributed.initialize() must be called before "
                       "any JAX calls that might initialise the XLA backend. "
                       "This includes any computation, but also calls to jax.devices, jax.device_put, and others.")
  global_state.initialize(coordinator_address, num_processes, process_id,
                          local_device_ids, cluster_detection_method,
                          initialization_timeout, coordinator_bind_address,
                          slice_index=slice_index)
|
|
|
|
|
|
def is_initialized() -> bool:
  """Return ``True`` iff the distributed client connection exists.

  The client is created by :func:`initialize` and cleared by
  :func:`shutdown`, so this reports whether the system is currently up.
  """
  client = global_state.client
  return client is not None
|
|
|
|
def shutdown():
  """Shut down the JAX distributed system.

  Safe to call when the system was never initialized; in that case this is
  a no-op.
  """
  global_state.shutdown()