Save an on-demand checkpoint when any worker receives a preemption signal.

PiperOrigin-RevId: 458525108
This commit is contained in:
Haoyu Zhang 2022-07-01 12:44:59 -07:00 committed by jax authors
parent 7c49864fdf
commit 3fc24ceb35
2 changed files with 58 additions and 0 deletions

View File

@ -28,6 +28,7 @@ from jax._src.lib import xla_extension
class State:
service: Optional[Any] = None
client: Optional[Any] = None
preemption_sync_manager: Optional[Any] = None
def initialize(self,
coordinator_address: Optional[str] = None,
@ -81,6 +82,9 @@ class State:
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
self.client.connect()
if xla_client._version >= 77 and config.jax_coordination_service:
self.initialize_preemption_sync_manager()
def shutdown(self):
if self.client:
self.client.shutdown()
@ -89,6 +93,14 @@ class State:
self.service.shutdown()
self.service = None
def initialize_preemption_sync_manager(self):
if self.preemption_sync_manager is not None:
raise RuntimeError(
'Preemption sync manager should only be initialized once.')
self.preemption_sync_manager = (
xla_extension.create_preemption_sync_manager())
self.preemption_sync_manager.initialize(self.client)
global_state = State()

View File

@ -23,6 +23,7 @@ from jax.experimental import maps
from jax.experimental.pjit import pjit, FROM_GDA
from jax.interpreters.pxla import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax._src import distributed
import numpy as np
@ -130,3 +131,48 @@ def assert_equal(in_tree, fail_message: str = ''):
jax.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)):
raise AssertionError(
f'{fail_message} Expected: {expected}; got: {in_tree}.')
def reached_preemption_sync_point(step_id: int) -> bool:
"""Determine whether all hosts have reached a preemption sync step.
When any host receive a preemption notice, the notice will be propagated to
all hosts and trigger a synchronization protocol in background. The
synchronization protocol calculates the maximum step ids from all hosts, and
uses the next step id (i.e., max + 1) as the safe step to save a checkpoint.
All hosts should continue training more steps until this method returns True,
indicating that the `step_id` is equal to the safe step and the hosts should
start saving a checkpoint. This feature requires enabling
`jax.config.jax_coordination_service`.
To use this API, all hosts must start training from the same step and call at
every training step. Example usage:
```
def should_save(step_id: int) -> bool:
# Should save an on-demand checkpoint for preemption
if multihost_utils.reached_preemption_sync_point(step_id):
return True
# Should save a regular checkpoint
return step_id - last_saved_checkpoint_step >= save_interval_steps
```
Preemption notice is provided by the cluster scheduler to notify the
application in advance before it gets evicted. By default, we use SIGTERM as
the signal for preemption notice.
TODO(b/230630494): Add instructions for customized preemption notice.
Returns:
A boolean indicating whether all hosts have reached a synchronization step
after some hosts are preempted.
Raises:
RuntimeError: if preemption sync manager has not been inititialized.
"""
sync_manager = distributed.global_state.preemption_sync_manager
if sync_manager is None:
raise RuntimeError("Preemption sync manager has not been initialized.")
return sync_manager.reached_sync_point(step_id)