From c879f65aa681db556ec9a90fed3e1af6e62adce2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Aug 2023 15:16:50 -0700 Subject: [PATCH] [JAX] Remove the non-coordination service distributed service implementation from JAX. The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code. PiperOrigin-RevId: 554608165 --- CHANGELOG.md | 2 ++ jax/_src/config.py | 10 ---------- jax/_src/distributed.py | 22 +++++++++++++++------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70dde5574..326a707cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list * jax2tf now uses native serialization by default. See the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) for details and for mechanisms to override the default. + * The option `--jax_coordination_service` has been removed. It is now always + `True`. ## jaxlib 0.4.15 diff --git a/jax/_src/config.py b/jax/_src/config.py index 50f235d18..2f591c389 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1068,16 +1068,6 @@ config.define_bool_state( default=(lib.version >= (0, 3, 6)), help=('Enables using optimization-barrier op for lowering remat.')) -# TODO(b/205307544): Remove flag once coordination service has rolled out. -config.define_bool_state( - name='jax_coordination_service', - default=True, - help=( - 'Use coordination service (experimental) instead of the default PjRT ' - 'distributed runtime.' - ) -) - # TODO(sharadmv,mattjj): set default to True, then remove config.define_bool_state( name='jax_eager_pmap', diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index a89e3030d..d8796e376 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -21,6 +21,7 @@ from typing import Any, Optional, Union from jax._src import clusters from jax._src.config import config from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version logger = logging.getLogger(__name__) @@ -71,22 +72,29 @@ class State: if self.service is not None: raise RuntimeError('distributed.initialize should only be called once.') logger.info('Starting JAX distributed service on %s', coordinator_address) - self.service = xla_extension.get_distributed_runtime_service( - coordinator_address, num_processes, config.jax_coordination_service) + if xla_extension_version >= 179: + self.service = xla_extension.get_distributed_runtime_service( + coordinator_address, num_processes) + else: + self.service = xla_extension.get_distributed_runtime_service( + coordinator_address, num_processes, config.jax_coordination_service) self.num_processes = num_processes if self.client is not None: raise RuntimeError('distributed.initialize should only be called once.') - self.client = xla_extension.get_distributed_runtime_client( - coordinator_address, process_id, config.jax_coordination_service, - init_timeout=initialization_timeout) + if xla_extension_version >= 179: + self.client = xla_extension.get_distributed_runtime_client( + coordinator_address, process_id, init_timeout=initialization_timeout) + else: + self.client = xla_extension.get_distributed_runtime_client( + coordinator_address, process_id, config.jax_coordination_service, + init_timeout=initialization_timeout) logger.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() - if config.jax_coordination_service: - self.initialize_preemption_sync_manager() + self.initialize_preemption_sync_manager() def shutdown(self): if self.client: