[JAX] Remove the non-coordination-service distributed service implementation from JAX.

The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.

PiperOrigin-RevId: 554608165
This commit is contained in:
Peter Hawkins 2023-08-07 15:16:50 -07:00 committed by jax authors
parent 22285e69fb
commit c879f65aa6
3 changed files with 17 additions and 17 deletions

View File

@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
* jax2tf now uses native serialization by default. See
the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. The coordination
service is now always used, as if the option were always `True`.
## jaxlib 0.4.15

View File

@ -1068,16 +1068,6 @@ config.define_bool_state(
default=(lib.version >= (0, 3, 6)),
help=('Enables using optimization-barrier op for lowering remat.'))
# TODO(b/205307544): Remove flag once coordination service has rolled out.
config.define_bool_state(
name='jax_coordination_service',
default=True,
help=(
'Use coordination service (experimental) instead of the default PjRT '
'distributed runtime.'
)
)
# TODO(sharadmv,mattjj): set default to True, then remove
config.define_bool_state(
name='jax_eager_pmap',

View File

@ -21,6 +21,7 @@ from typing import Any, Optional, Union
from jax._src import clusters
from jax._src.config import config
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
logger = logging.getLogger(__name__)
@ -71,22 +72,29 @@ class State:
if self.service is not None:
raise RuntimeError('distributed.initialize should only be called once.')
logger.info('Starting JAX distributed service on %s', coordinator_address)
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes, config.jax_coordination_service)
if xla_extension_version >= 179:
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes)
else:
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes, config.jax_coordination_service)
self.num_processes = num_processes
if self.client is not None:
raise RuntimeError('distributed.initialize should only be called once.')
self.client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id, config.jax_coordination_service,
init_timeout=initialization_timeout)
if xla_extension_version >= 179:
self.client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id, init_timeout=initialization_timeout)
else:
self.client = xla_extension.get_distributed_runtime_client(
coordinator_address, process_id, config.jax_coordination_service,
init_timeout=initialization_timeout)
logger.info('Connecting to JAX distributed service on %s', coordinator_address)
self.client.connect()
if config.jax_coordination_service:
self.initialize_preemption_sync_manager()
self.initialize_preemption_sync_manager()
def shutdown(self):
if self.client: