mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Remove the non-coordination-service implementation of the distributed service from JAX.
The coordination service has been the default for a long time and has significant additional functionality. Remove the older code path to simplify the code.

PiperOrigin-RevId: 554608165
This commit is contained in:
parent
22285e69fb
commit
c879f65aa6
@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* jax2tf now uses native serialization by default. See
|
||||
the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
|
||||
for details and for mechanisms to override the default.
|
||||
* The option `--jax_coordination_service` has been removed. It is now always
|
||||
`True`.
|
||||
|
||||
## jaxlib 0.4.15
|
||||
|
||||
|
@ -1068,16 +1068,6 @@ config.define_bool_state(
|
||||
default=(lib.version >= (0, 3, 6)),
|
||||
help=('Enables using optimization-barrier op for lowering remat.'))
|
||||
|
||||
# TODO(b/205307544): Remove flag once coordination service has rolled out.
|
||||
config.define_bool_state(
|
||||
name='jax_coordination_service',
|
||||
default=True,
|
||||
help=(
|
||||
'Use coordination service (experimental) instead of the default PjRT '
|
||||
'distributed runtime.'
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(sharadmv,mattjj): set default to True, then remove
|
||||
config.define_bool_state(
|
||||
name='jax_eager_pmap',
|
||||
|
@ -21,6 +21,7 @@ from typing import Any, Optional, Union
|
||||
from jax._src import clusters
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -71,22 +72,29 @@ class State:
|
||||
if self.service is not None:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
logger.info('Starting JAX distributed service on %s', coordinator_address)
|
||||
self.service = xla_extension.get_distributed_runtime_service(
|
||||
coordinator_address, num_processes, config.jax_coordination_service)
|
||||
if xla_extension_version >= 179:
|
||||
self.service = xla_extension.get_distributed_runtime_service(
|
||||
coordinator_address, num_processes)
|
||||
else:
|
||||
self.service = xla_extension.get_distributed_runtime_service(
|
||||
coordinator_address, num_processes, config.jax_coordination_service)
|
||||
|
||||
self.num_processes = num_processes
|
||||
|
||||
if self.client is not None:
|
||||
raise RuntimeError('distributed.initialize should only be called once.')
|
||||
|
||||
self.client = xla_extension.get_distributed_runtime_client(
|
||||
coordinator_address, process_id, config.jax_coordination_service,
|
||||
init_timeout=initialization_timeout)
|
||||
if xla_extension_version >= 179:
|
||||
self.client = xla_extension.get_distributed_runtime_client(
|
||||
coordinator_address, process_id, init_timeout=initialization_timeout)
|
||||
else:
|
||||
self.client = xla_extension.get_distributed_runtime_client(
|
||||
coordinator_address, process_id, config.jax_coordination_service,
|
||||
init_timeout=initialization_timeout)
|
||||
logger.info('Connecting to JAX distributed service on %s', coordinator_address)
|
||||
self.client.connect()
|
||||
|
||||
if config.jax_coordination_service:
|
||||
self.initialize_preemption_sync_manager()
|
||||
self.initialize_preemption_sync_manager()
|
||||
|
||||
def shutdown(self):
|
||||
if self.client:
|
||||
|
Loading…
x
Reference in New Issue
Block a user