mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add an option to deactivate automatic cluster detection in jax.distributed.initialize().
This commit is contained in:
parent
c9a5902216
commit
6a8bbcbadf
@ -62,7 +62,8 @@ class State:
|
||||
if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
|
||||
local_device_ids = list(map(int, env_ids.split(",")))
|
||||
|
||||
if None in (coordinator_address, num_processes, process_id, local_device_ids):
|
||||
if (cluster_detection_method != 'deactivate' and
|
||||
None in (coordinator_address, num_processes, process_id, local_device_ids)):
|
||||
(coordinator_address, num_processes, process_id, local_device_ids) = (
|
||||
clusters.ClusterEnv.auto_detect_unset_distributed_params(
|
||||
coordinator_address,
|
||||
@ -217,7 +218,8 @@ def initialize(coordinator_address: str | None = None,
|
||||
cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed
|
||||
run. Note that the "mpi4py" method requires you to have a working ``mpi4py`` installation in your environment,
|
||||
and launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
|
||||
Legacy auto-detect options (OMPI, Slurm) remain enabled.
|
||||
Legacy auto-detect options "ompi" (OMPI) and "slurm" (Slurm) remain enabled. "deactivate" bypasses
|
||||
automatic cluster detection.
|
||||
initialization_timeout: Time period (in seconds) for which connection will
|
||||
be retried. If the initialization takes more than the timeout specified,
|
||||
the initialization will error. Defaults to 300 secs i.e. 5 mins.
|
||||
|
Loading…
x
Reference in New Issue
Block a user