jax.distributed.initialize: specify bind address.

By default, the coordinator process listens on all interfaces.
This commit is contained in:
Olli Lupton 2024-04-03 09:28:54 +00:00
parent dcd45c8d20
commit 2dd1b3d6c8
2 changed files with 23 additions and 4 deletions

View File

@ -62,6 +62,9 @@ The API {func}`jax.distributed.initialize` takes several arguments, namely:
with a port available on that process. Process 0 will start a JAX service
exposed via that IP address and port, to which the other processes in the
cluster will connect.
* `coordinator_bind_address`: the IP address and port to which the JAX service
on process 0 in your cluster will bind. By default, it will bind to all
available interfaces using the same port as `coordinator_address`.
* `num_processes`: the number of processes in the cluster
* `process_id`: the ID number of this process, in the range `[0 ..
num_processes)`.

View File

@ -41,7 +41,8 @@ class State:
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
initialization_timeout: int = 300):
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):
@ -66,6 +67,15 @@ class State:
self.coordinator_address = coordinator_address
# The default value of [::]:port tells the coordinator to bind to all
# available addresses on the same port as coordinator_address.
default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1]
coordinator_bind_address = (coordinator_bind_address or
os.environ.get('JAX_COORDINATOR_BIND_ADDRESS',
default_coordinator_bind_address))
if coordinator_bind_address is None:
raise ValueError('coordinator_bind_address should be defined.')
if local_device_ids:
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
@ -79,7 +89,7 @@ class State:
raise RuntimeError('distributed.initialize should only be called once.')
logger.info('Starting JAX distributed service on %s', coordinator_address)
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes)
coordinator_bind_address, num_processes)
self.num_processes = num_processes
@ -118,7 +128,8 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
initialization_timeout: int = 300):
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
"""Initializes the JAX distributed system.
Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
@ -156,6 +167,11 @@ def initialize(coordinator_address: str | None = None,
initialization_timeout: Time period (in seconds) for which connection will
be retried. If the initialization takes more than the timeout specified,
the initialization will error. Defaults to 300 secs i.e. 5 mins.
coordinator_bind_address: the address and port to which the coordinator service
on process `0` should bind. If this is not specified, the default is to bind to
all available addresses on the same port as ``coordinator_address``. On systems
that have multiple network interfaces per node it may be insufficient to only
have the coordinator service listen on one address/interface.
Raises:
RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
@ -178,7 +194,7 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout)
local_device_ids, initialization_timeout, coordinator_bind_address)
atexit.register(shutdown)