mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #11946 from nvcastet:fix_distributed_timeout
PiperOrigin-RevId: 469721289
This commit is contained in:
commit
0df81ffc16
@@ -72,8 +72,10 @@ class State:
|
||||
if self.client is not None:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
|
||||
# Set init_timeout to 5 min to leave time for all the processes to connect
|
||||
self.client = xla_extension.get_distributed_runtime_client(
|
||||
coordinator_address, process_id, config.jax_coordination_service)
|
||||
coordinator_address, process_id, config.jax_coordination_service,
|
||||
init_timeout=300)
|
||||
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
|
||||
self.client.connect()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user