mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[multihost_utils] fix docstring wording
parent f4bb1c0c62
commit 1797ab6c5b
@@ -161,16 +161,16 @@ def assert_equal(in_tree, fail_message: str = ''):
 def reached_preemption_sync_point(step_id: int) -> bool:
   """Determine whether all hosts have reached a preemption sync step.
 
-  When any host receive a preemption notice, the notice will be propagated to
-  all hosts and trigger a synchronization protocol in background. The
+  When any host receives a preemption notice, the notice is propagated to all
+  hosts and triggers a synchronization protocol in the background. The
   synchronization protocol calculates the maximum step ids from all hosts, and
   uses the next step id (i.e., max + 1) as the safe step to save a checkpoint.
   All hosts should continue training more steps until this method returns True,
   indicating that the `step_id` is equal to the safe step and the hosts should
   start saving a checkpoint.
 
-  To use this API, all hosts must start training from the same step and call at
-  every training step. Example usage:
+  To use this API, all hosts must start training from the same step and call it
+  at every training step. Example usage:
 
   ```
   def should_save(step_id: int) -> bool:
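Note on the docstring above: the "max + 1" rule it describes can be illustrated with a small pure-Python toy. This is only a sketch of the arithmetic, not JAX's implementation; it does not call `multihost_utils`, and the names `safe_checkpoint_step` and `simulated_sync_point` are hypothetical helpers invented for this illustration.

```python
from typing import Optional, Sequence


def safe_checkpoint_step(latest_steps: Sequence[int]) -> int:
  """Hypothetical helper: the safe step is max + 1 over all hosts' step ids."""
  if not latest_steps:
    raise ValueError("need at least one host step id")
  return max(latest_steps) + 1


def simulated_sync_point(step_id: int, safe_step: Optional[int]) -> bool:
  """Hypothetical stand-in for the check described in the docstring.

  Returns True only if a preemption notice has produced a safe step
  (safe_step is not None) and this host's step_id equals that safe step.
  """
  return safe_step is not None and step_id == safe_step


# Assumed scenario: three hosts are at different steps when the notice arrives.
host_steps = [41, 42, 40]
safe_step = safe_checkpoint_step(host_steps)

# Each host keeps training, checking at every step, until the check passes.
save_steps = []
for start in host_steps:
  step = start
  while not simulated_sync_point(step, safe_step):
    step += 1  # one more training step
  save_steps.append(step)
```

In this toy every host ends up saving at the same step, which is the property the docstring relies on when it asks all hosts to call the function at every training step.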