Use JAX's distributed system for fully asynchronous checkpointing.

PiperOrigin-RevId: 449380175
Yash Katariya 2022-05-17 20:17:19 -07:00 committed by jax authors
parent 720d09c7df
commit 0574eb2141
3 changed files with 72 additions and 79 deletions
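
In short: checkpoint-commit coordination previously went through lockfiles on the checkpoint filesystem; this change routes it through the key-value store of the distributed system that `jax.distributed.initialize()` brings up. A minimal before/after sketch of the per-process completion check, using names from the diff below (`client` stands in for `distributed.distributed_client`):

```
# Before: process p signaled a finished commit by deleting its lockfile.
done = not tf.io.gfile.exists(os.path.join(lockfiles_dir, f'lockfile_{p}'))

# After: process p signals by overwriting its key with a sentinel value.
done = client.blocking_key_value_get(
    f'checkpoint_{p}', timeout_in_ms) == 'Value removed'
```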

View File

@@ -32,8 +32,14 @@ def initialize(coordinator_address: Optional[str] = None,
process_id: Optional[int] = None):
"""Initialize distributed system for topology discovery.
Currently, calling ``initialize`` sets up the multi-host GPU backend, and
is not required for CPU or TPU backends.
Currently, calling ``initialize`` sets up the multi-host GPU backend and Cloud
TPU backend.
If you are on a GPU platform, you will have to provide the coordinator_address
and other arguments to the `initialize` API.
If you are on a TPU platform, the coordinator_address and other arguments will
be auto-detected, but you have the option to provide them too.
Args:
coordinator_address: IP address and port of the coordinator. The choice of
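
A hedged usage sketch of the two modes the revised docstring describes. The GPU call is spelled out per the signature shown in this hunk; the concrete address and the `num_processes` parameter name are assumptions, since the middle of the signature is elided here:

```
import jax

# GPU: each process must be told where the coordinator is and who it is.
jax.distributed.initialize(coordinator_address='192.168.0.1:1234',  # placeholder
                           num_processes=2,   # assumed parameter name
                           process_id=0)      # 1 on the other host

# Cloud TPU (alternative): the same values are auto-detected, so a bare
# call suffices.
jax.distributed.initialize()
```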

View File

@@ -14,7 +14,6 @@
"""GlobalDeviceArray serialization and deserialization."""
import asyncio
import os
import re
import threading
import time
@@ -22,6 +21,7 @@ from typing import Callable
from absl import logging
import jax
from jax._src import distributed
from jax._src.util import prod
from jax.experimental import global_device_array as gda
from jax.experimental.maps import Mesh
@@ -29,10 +29,12 @@ import jax.numpy as jnp
import numpy as np
import tensorstore as ts
import tensorflow.compat.v2 as tf
# internal import
TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
_REMOVED_VALUE = 'Value removed'
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
_CHECKPOINT_FAILURE = 'checkpoint_write_failed'
async def create_async_gda_from_callback(
@@ -170,9 +172,6 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
return asyncio.run(_run_deserializer())
no_lockfiles_exists = lambda paths: all(not tf.io.gfile.exists(f) for f in paths)
class _RetryWithTimeout:
def __init__(self, secs):
self.secs = secs
@@ -189,6 +188,10 @@ class _RetryWithTimeout:
return time.time() > self.timeout_after
def _get_key(key: str):
return f'checkpoint_{key}'
class GlobalAsyncCheckpointManager:
"""Responsible for serializing GDAs via TensorStore.
@@ -205,6 +208,9 @@ class GlobalAsyncCheckpointManager:
Below is a simplified training loop:
```
# Call this at the start of your program.
jax.distributed.initialize()
manager = GlobalAsyncCheckpointManager()
# Restore checkpoint if available or initialize the train_state from
@@ -227,13 +233,21 @@ class GlobalAsyncCheckpointManager:
```
"""
def __init__(self, timeout_secs=600):
def __init__(self, timeout_secs=300):
self._timeout_secs = timeout_secs
self._timeout_in_ms = self._timeout_secs * 1000
self._commit_futures = None
self._thread = None
self._exception = None
if distributed.distributed_client is None:
raise ValueError('Please initialize the distributed system via '
'`jax.distributed.initialize()` at the start of your '
'program.')
self._client = distributed.distributed_client
self._final_ckpt_dir = None
def __del__(self):
if self._thread is not None and self._thread.is_alive():
logging.warning('Please add `.wait_until_finished()` in the main thread '
@@ -250,48 +264,47 @@ class GlobalAsyncCheckpointManager:
logging.info('Commit to storage layer has completed.')
current_process = jax.process_index()
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
all_lockfile_paths = [
os.path.join(lockfiles_dir, f'lockfile_{p}')
for p in range(jax.process_count())
]
current_process_lockfile = os.path.join(lockfiles_dir,
f'lockfile_{current_process}')
# TODO(yashkatariya): Add a method to the distributed system to wait for
# the key's value to change.
# Value already exists -- wait until value is NOT _REMOVED_VALUE.
with _RetryWithTimeout(self._timeout_secs) as t:
while not tf.io.gfile.exists(current_process_lockfile):
while self._client.blocking_key_value_get(
_get_key(str(current_process)), self._timeout_in_ms) == _REMOVED_VALUE:
if t.timed_out:
raise RuntimeError('Terminating after waiting for '
f'{self._timeout_secs} secs for lockfile to appear')
logging.info('Waiting for current process %s lockfile to appear.',
raise TimeoutError('Terminating after waiting for '
f'{self._timeout_secs} secs for lock value to appear.')
logging.info('Waiting for current process %s lock value to appear.',
current_process)
time.sleep(60)
tf.io.gfile.remove(current_process_lockfile)
logging.info('Lockfile removed for process %s', current_process)
self._client.key_value_set(_get_key(str(current_process)), _REMOVED_VALUE)
logging.info('Lock value removed for process %s', current_process)
# This while loop will not trigger until all commits have finished.
if current_process == 0:
with _RetryWithTimeout(self._timeout_secs) as t:
while True:
if t.timed_out:
raise RuntimeError('Terminating after waiting for '
raise TimeoutError('Terminating after waiting for '
f'{self._timeout_secs} secs for '
'finishing the serialization.')
# Mark as done when no lockfiles exist.
if no_lockfiles_exists(all_lockfile_paths):
tf.io.gfile.rmtree(lockfiles_dir)
logging.info('Lockfiles directory removed.')
# Mark as done when no lock values exist.
if all(
self._client.blocking_key_value_get(
_get_key(str(p)), self._timeout_in_ms) == _REMOVED_VALUE
for p in range(jax.process_count())):
logging.info('Renaming %s to %s', temp_checkpoint_dir, final_checkpoint_dir)
tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
logging.info('Finished saving GDA checkpoint to `%s`.', final_checkpoint_dir)
self._client.key_value_set(_get_key(self._final_ckpt_dir), _CHECKPOINT_SUCCESS)
break
else:
logging.info('Thread sleeping for 60 seconds.')
time.sleep(60)
except Exception as e:
self._client.key_value_set(_get_key(self._final_ckpt_dir), _CHECKPOINT_FAILURE)
self._exception = e
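
To make the control flow above easier to follow, here is a toy, single-process sketch of the lock-value protocol, with a plain dict standing in for the distributed key-value store. The names mirror the diff, but the sequential loops are illustrative only; in reality each step runs on a different host:

```
_REMOVED_VALUE = 'Value removed'
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'

def _get_key(key):
  return f'checkpoint_{key}'

store = {}  # stand-in for the distributed key-value store
process_count = 4

# Step 1: process 0 publishes a fresh lock value per process before writes.
for p in range(process_count):
  store[_get_key(str(p))] = f'Lock value for process {p}'

# Step 2: each process, after its TensorStore commit completes, waits until
# its lock value is present, then overwrites it with the removed sentinel.
for p in range(process_count):
  assert store[_get_key(str(p))] != _REMOVED_VALUE
  store[_get_key(str(p))] = _REMOVED_VALUE

# Step 3: once every lock reads as removed, process 0 renames the temporary
# directory to the final one and publishes success under the final dir's key.
assert all(store[_get_key(str(p))] == _REMOVED_VALUE
           for p in range(process_count))
store[_get_key('/path/to/final_ckpt')] = _CHECKPOINT_SUCCESS  # placeholder path
```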
def _start_commit_thread(self, temp_checkpoint_dir, final_checkpoint_dir):
@@ -300,21 +313,16 @@ class GlobalAsyncCheckpointManager:
args=(temp_checkpoint_dir, final_checkpoint_dir))
self._thread.start()
def _write_lockfiles(self, temp_checkpoint_dir):
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
tf.io.gfile.mkdir(lockfiles_dir)
def _write_lock_values(self):
write_count = 0
for p in range(jax.process_count()):
with tf.io.gfile.GFile(
os.path.join(lockfiles_dir, f'lockfile_{p}'), mode='w') as f:
f.write('File to track if all chunks have been written.')
# TODO(yashkatariya): Make the key value store writes safe if checkpoint
# managers are created concurrently.
self._client.key_value_set(_get_key(str(p)), f'Lock value for process {str(p)}')
write_count += 1
if write_count != jax.process_count():
raise ValueError("Process 0 couldn't write all the lockfiles.")
logging.info('Lock files for all processes have been written by process 0.')
raise ValueError("Process 0 couldn't write all the lock values.")
logging.info('Lock values for all processes have been written by process 0.')
def check_for_errors(self):
if self._exception is not None:
@@ -330,6 +338,22 @@ class GlobalAsyncCheckpointManager:
self.check_for_errors()
if self._final_ckpt_dir is not None:
with _RetryWithTimeout(self._timeout_secs) as t:
if t.timed_out:
raise TimeoutError("Process 0 didn't finish the rename from "
"temporary to final checkpoint directory")
val = self._client.blocking_key_value_get(
_get_key(self._final_ckpt_dir), self._timeout_in_ms)
while val != _CHECKPOINT_SUCCESS:
if val == _CHECKPOINT_FAILURE:
raise ValueError(
'Checkpoint write failed. Please check the error message on '
'all hosts to see the real failure.')
logging.info('Waiting for final checkpoint directory to exist on '
'process %s', jax.process_index())
time.sleep(15)
def serialize(self, gdas, tensorstore_specs, *, temp_checkpoint_dir,
final_checkpoint_dir):
"""Serializes GlobalDeviceArrays via TensorStore asynchronously.
@@ -356,7 +380,7 @@ class GlobalAsyncCheckpointManager:
# Process 0 writes lock files for all processes.
if jax.process_index() == 0:
self._write_lockfiles(temp_checkpoint_dir)
self._write_lock_values()
self._commit_futures = [[] for _ in range(len(tensorstore_specs))]
@@ -366,6 +390,9 @@ class GlobalAsyncCheckpointManager:
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())
# Used in wait_until_finished on processes != 0 to check whether the
# checkpoint write has finished.
self._final_ckpt_dir = final_checkpoint_dir
self._start_commit_thread(temp_checkpoint_dir, final_checkpoint_dir)
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
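
A hedged end-to-end sketch of the new usage contract: `jax.distributed.initialize()` must now run before constructing the manager, and the commit outcome is observed via the key-value store rather than the filesystem. Paths are placeholders, the import path is an assumption based on the module name, and `gda` is assumed to be an existing `GlobalDeviceArray` (built e.g. with `GlobalDeviceArray.from_callback`, as in the test below):

```
import jax
from jax.experimental.gda_serialization import serialization

jax.distributed.initialize()  # now required before creating the manager

manager = serialization.GlobalAsyncCheckpointManager(timeout_secs=300)
tspecs = [serialization.get_tensorstore_spec('/tmp/ckpt_tmp')]  # placeholder
manager.serialize([gda], tspecs,
                  temp_checkpoint_dir='/tmp/ckpt_tmp',
                  final_checkpoint_dir='/tmp/ckpt')
# Training can continue here while the commit thread runs in the background.
manager.wait_until_finished()  # re-raises any failure reported via the store
```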

View File

@@ -129,46 +129,6 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
def test_async_checkpointing(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
num = util.prod(global_input_shape)
# First GDA
global_input_data1 = np.arange(num).reshape(global_input_shape)
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb1)
temp_ckpt_dir1 = pathlib.Path(self.create_tempdir('temp_first').full_path)
ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')
s_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(temp_ckpt_dir1)])
manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
final_checkpoint_dir=ckpt_dir1)
manager.wait_until_finished()
d_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir1)])
m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
np.array([[0], [2]]))
self.assertArraysEqual(m1.local_shards[1].data.to_py(),
np.array([[1], [3]]))
self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)
# Will throw a `file already exists` error when `tf.io.gfile.rename` is called.
# `wait_until_finished` will raise the error.
with self.assertRaises(Exception):
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
manager1 = serialization.GlobalAsyncCheckpointManager()
manager1.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
final_checkpoint_dir=ckpt_dir1)
manager1.wait_until_finished()
def test_spec_has_metadata(self):
spec = {
'a': {