Add an option to deactivate automatic cluster detection in jax.distributed.initialize().

This commit is contained in:
Emily Fertig 2024-11-18 17:06:28 -08:00
parent c9a5902216
commit 6a8bbcbadf

View File

@ -62,7 +62,8 @@ class State:
if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
local_device_ids = list(map(int, env_ids.split(",")))
if None in (coordinator_address, num_processes, process_id, local_device_ids):
if (cluster_detection_method != 'deactivate' and
None in (coordinator_address, num_processes, process_id, local_device_ids)):
(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address,
@ -217,7 +218,8 @@ def initialize(coordinator_address: str | None = None,
cluster_detection_method: An optional string naming the method used to autodetect the configuration of the distributed
run. Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
and launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
Legacy auto-detect options (OMPI, Slurm) remain enabled.
Legacy auto-detect options "ompi" (OMPI) and "slurm" (Slurm) remain enabled. "deactivate" bypasses
automatic cluster detection.
initialization_timeout: Time period (in seconds) for which connection will
be retried. If the initialization takes more than the timeout specified,
the initialization will error. Defaults to 300 secs i.e. 5 mins.