[JAX] Remove code that sets or tests --jax_coordination_service.

--jax_coordination_service defaults to True and has for some time, and support for the non-coordination service case will be removed shortly.

PiperOrigin-RevId: 551932242
This commit is contained in:
Peter Hawkins 2023-07-28 13:12:40 -07:00 committed by jax authors
parent 9a21ff0780
commit ddfdb7a00e

View File

@ -169,8 +169,7 @@ def reached_preemption_sync_point(step_id: int) -> bool:
uses the next step id (i.e., max + 1) as the safe step to save a checkpoint.
All hosts should continue training more steps until this method returns True,
indicating that the `step_id` is equal to the safe step and the hosts should
start saving a checkpoint. This feature requires enabling
`jax.config.jax_coordination_service`.
start saving a checkpoint.
To use this API, all hosts must start training from the same step and call at
every training step. Example usage: