mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jax.distributed.initialize: specify bind address.
By default, the coordinator process listens on all interfaces.
This commit is contained in:
parent
dcd45c8d20
commit
2dd1b3d6c8
@ -62,6 +62,9 @@ The API {func}`jax.distributed.initialize` takes several arguments, namely:
|
||||
with a port available on that process. Process 0 will start a JAX service
|
||||
exposed via that IP address and port, to which the other processes in the
|
||||
cluster will connect.
|
||||
* `coordinator_bind_address`: the IP address and port to which the JAX service
|
||||
on process 0 in your cluster will bind. By default, it will bind to all
|
||||
available interfaces using the same port as `coordinator_address`.
|
||||
* `num_processes`: the number of processes in the cluster
|
||||
* `process_id`: the ID number of this process, in the range `[0 ..
|
||||
num_processes)`.
|
||||
|
@ -41,7 +41,8 @@ class State:
|
||||
num_processes: int | None = None,
|
||||
process_id: int | None = None,
|
||||
local_device_ids: int | Sequence[int] | None = None,
|
||||
initialization_timeout: int = 300):
|
||||
initialization_timeout: int = 300,
|
||||
coordinator_bind_address: str | None = None):
|
||||
coordinator_address = (coordinator_address or
|
||||
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
|
||||
if isinstance(local_device_ids, int):
|
||||
@ -66,6 +67,15 @@ class State:
|
||||
|
||||
self.coordinator_address = coordinator_address
|
||||
|
||||
# The default value of [::]:port tells the coordinator to bind to all
|
||||
# available addresses on the same port as coordinator_address.
|
||||
default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1]
|
||||
coordinator_bind_address = (coordinator_bind_address or
|
||||
os.environ.get('JAX_COORDINATOR_BIND_ADDRESS',
|
||||
default_coordinator_bind_address))
|
||||
if coordinator_bind_address is None:
|
||||
raise ValueError('coordinator_bind_address should be defined.')
|
||||
|
||||
if local_device_ids:
|
||||
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
|
||||
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
|
||||
@ -79,7 +89,7 @@ class State:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
logger.info('Starting JAX distributed service on %s', coordinator_address)
|
||||
self.service = xla_extension.get_distributed_runtime_service(
|
||||
coordinator_address, num_processes)
|
||||
coordinator_bind_address, num_processes)
|
||||
|
||||
self.num_processes = num_processes
|
||||
|
||||
@ -118,7 +128,8 @@ def initialize(coordinator_address: str | None = None,
|
||||
num_processes: int | None = None,
|
||||
process_id: int | None = None,
|
||||
local_device_ids: int | Sequence[int] | None = None,
|
||||
initialization_timeout: int = 300):
|
||||
initialization_timeout: int = 300,
|
||||
coordinator_bind_address: str | None = None):
|
||||
"""Initializes the JAX distributed system.
|
||||
|
||||
Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
|
||||
@ -156,6 +167,11 @@ def initialize(coordinator_address: str | None = None,
|
||||
initialization_timeout: Time period (in seconds) for which connection will
|
||||
be retried. If the initialization takes more than the timeout specified,
|
||||
the initialization will error. Defaults to 300 secs i.e. 5 mins.
|
||||
coordinator_bind_address: the address and port to which the coordinator service
|
||||
on process `0` should bind. If this is not specified, the default is to bind to
|
||||
all available addresses on the same port as ``coordinator_address``. On systems
|
||||
that have multiple network interfaces per node it may be insufficient to only
|
||||
have the coordinator service listen on one address/interface.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
|
||||
@ -178,7 +194,7 @@ def initialize(coordinator_address: str | None = None,
|
||||
raise RuntimeError("jax.distributed.initialize() must be called before "
|
||||
"any JAX computations are executed.")
|
||||
global_state.initialize(coordinator_address, num_processes, process_id,
|
||||
local_device_ids, initialization_timeout)
|
||||
local_device_ids, initialization_timeout, coordinator_bind_address)
|
||||
atexit.register(shutdown)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user