mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Save an on-demand checkpoint when any worker receives a preemption signal.
PiperOrigin-RevId: 458525108
This commit is contained in:
parent
7c49864fdf
commit
3fc24ceb35
@ -28,6 +28,7 @@ from jax._src.lib import xla_extension
|
||||
class State:
|
||||
service: Optional[Any] = None
|
||||
client: Optional[Any] = None
|
||||
preemption_sync_manager: Optional[Any] = None
|
||||
|
||||
def initialize(self,
|
||||
coordinator_address: Optional[str] = None,
|
||||
@ -81,6 +82,9 @@ class State:
|
||||
logging.info('Connecting to JAX distributed service on %s', coordinator_address)
|
||||
self.client.connect()
|
||||
|
||||
if xla_client._version >= 77 and config.jax_coordination_service:
|
||||
self.initialize_preemption_sync_manager()
|
||||
|
||||
def shutdown(self):
|
||||
if self.client:
|
||||
self.client.shutdown()
|
||||
@ -89,6 +93,14 @@ class State:
|
||||
self.service.shutdown()
|
||||
self.service = None
|
||||
|
||||
def initialize_preemption_sync_manager(self):
|
||||
if self.preemption_sync_manager is not None:
|
||||
raise RuntimeError(
|
||||
'Preemption sync manager should only be initialized once.')
|
||||
self.preemption_sync_manager = (
|
||||
xla_extension.create_preemption_sync_manager())
|
||||
self.preemption_sync_manager.initialize(self.client)
|
||||
|
||||
global_state = State()
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@ from jax.experimental import maps
|
||||
from jax.experimental.pjit import pjit, FROM_GDA
|
||||
from jax.interpreters.pxla import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax._src import distributed
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -130,3 +131,48 @@ def assert_equal(in_tree, fail_message: str = ''):
|
||||
jax.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)):
|
||||
raise AssertionError(
|
||||
f'{fail_message} Expected: {expected}; got: {in_tree}.')
|
||||
|
||||
|
||||
def reached_preemption_sync_point(step_id: int) -> bool:
|
||||
"""Determine whether all hosts have reached a preemption sync step.
|
||||
|
||||
When any host receive a preemption notice, the notice will be propagated to
|
||||
all hosts and trigger a synchronization protocol in background. The
|
||||
synchronization protocol calculates the maximum step ids from all hosts, and
|
||||
uses the next step id (i.e., max + 1) as the safe step to save a checkpoint.
|
||||
All hosts should continue training more steps until this method returns True,
|
||||
indicating that the `step_id` is equal to the safe step and the hosts should
|
||||
start saving a checkpoint. This feature requires enabling
|
||||
`jax.config.jax_coordination_service`.
|
||||
|
||||
To use this API, all hosts must start training from the same step and call at
|
||||
every training step. Example usage:
|
||||
|
||||
```
|
||||
def should_save(step_id: int) -> bool:
|
||||
|
||||
# Should save an on-demand checkpoint for preemption
|
||||
if multihost_utils.reached_preemption_sync_point(step_id):
|
||||
return True
|
||||
|
||||
# Should save a regular checkpoint
|
||||
return step_id - last_saved_checkpoint_step >= save_interval_steps
|
||||
```
|
||||
|
||||
Preemption notice is provided by the cluster scheduler to notify the
|
||||
application in advance before it gets evicted. By default, we use SIGTERM as
|
||||
the signal for preemption notice.
|
||||
|
||||
TODO(b/230630494): Add instructions for customized preemption notice.
|
||||
|
||||
Returns:
|
||||
A boolean indicating whether all hosts have reached a synchronization step
|
||||
after some hosts are preempted.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if preemption sync manager has not been inititialized.
|
||||
"""
|
||||
sync_manager = distributed.global_state.preemption_sync_manager
|
||||
if sync_manager is None:
|
||||
raise RuntimeError("Preemption sync manager has not been initialized.")
|
||||
return sync_manager.reached_sync_point(step_id)
|
||||
|
Loading…
x
Reference in New Issue
Block a user