Merge pull request #11946 from nvcastet:fix_distributed_timeout

PiperOrigin-RevId: 469721289
This commit is contained in:
jax authors 2022-08-24 07:28:44 -07:00
commit 0df81ffc16

View File

@ -72,8 +72,10 @@ class State:
if self.client is not None:
raise RuntimeError('distributed.initialize should only be called once.')
# Set init_timeout to 5 min to leave time for all the processes to connect
self.client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id, config.jax_coordination_service)
coordinator_address, process_id, config.jax_coordination_service,
init_timeout=300)
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
self.client.connect()