Add fully asynchronous checkpointing. This allows training to proceed while the checkpoint is being committed to the storage layer (see the sketch below).

PiperOrigin-RevId: 446083057
This commit is contained in:
Yash Katariya 2022-05-02 18:43:27 -07:00 committed by jax authors
parent 939233e769
commit b7293d5683
2 changed files with 261 additions and 25 deletions
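At its core, the change hands the storage commit to a background thread and only blocks the next checkpoint (or program exit) on it. A minimal, self-contained sketch of that pattern, using plain Python threads and `concurrent.futures` in place of the TensorStore commit futures from the diff; every name below is illustrative and not part of the commit:

```
import threading
from concurrent.futures import Future


class ToyAsyncCheckpointer:
    """Toy illustration: serialize() returns quickly; commits finish on a thread."""

    def __init__(self):
        self._thread = None

    def serialize(self, commit_futures):
        # Block only if a previous checkpoint is still being committed.
        self.wait_until_finished()
        self._thread = threading.Thread(
            target=lambda: [f.result() for f in commit_futures])
        self._thread.start()  # the training loop can proceed immediately

    def wait_until_finished(self):
        if self._thread is not None:
            self._thread.join()
            self._thread = None


# Usage: the commit is awaited by the background thread, not by the training step.
done = Future()
done.set_result(None)
ckpt = ToyAsyncCheckpointer()
ckpt.serialize([done])
ckpt.wait_until_finished()
```

The `GlobalAsyncCheckpointManager` added below follows the same shape, with multi-process coordination (lockfiles and a final directory rename) layered on top.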

jax/experimental/gda_serialization/serialization.py

@@ -14,8 +14,12 @@
"""GlobalDeviceArray serialization and deserialization."""
import asyncio
import os
import re
import threading
import time
from typing import Callable
from absl import logging
import jax
from jax._src.util import prod
@@ -24,6 +28,11 @@ from jax.experimental.maps import Mesh
import jax.numpy as jnp
import numpy as np
import tensorstore as ts
import tensorflow.compat.v2 as tf
# internal import
TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
async def create_async_gda_from_callback(
@@ -88,26 +97,28 @@ def get_tensorstore_spec(ckpt_path: str):
return spec
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
commit_future=None):
# 'metadata' may not be present at the top level (for example, if we are using
# a 'cast' driver).
if not _spec_has_metadata(tensorstore_spec):
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
t = await ts.open(
ts.Spec(tensorstore_spec),
create=True,
open=True,
context=ts.Context({'file_io_concurrency': {
'limit': 128
}}))
ts.Spec(tensorstore_spec), create=True, open=True, context=TS_CONTEXT)
async def _write_array(shard):
if shard.replica_id == 0:
await t[shard.index].write(shard.data)
write_future = t[shard.index].write(shard.data)
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(write_future.commit)
await write_future.copy
else:
await write_future.commit
future_write_state = jax.tree_util.tree_map(_write_array,
tuple(gda_inp.local_shards))
gda_inp.local_shards)
return await asyncio.gather(*future_write_state)
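The "metadata may not be present at the top level" case mentioned in the comment above arises with wrapper drivers such as `cast`, whose specs nest the real storage spec under `base`. A hypothetical spec of that shape (illustrative values only, not taken from this diff):

```
# With a wrapper driver, `metadata` lives inside the nested `base` spec, so a
# top-level check alone would miss it and `_spec_has_metadata` has to recurse.
spec_with_nested_metadata = {
    'driver': 'cast',
    'dtype': 'float32',
    'base': {
        'driver': 'zarr',
        'kvstore': {'driver': 'memory'},
        'metadata': {'shape': [8, 2], 'chunks': [8, 2], 'dtype': '<f4'},
    },
}
```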
@@ -120,7 +131,7 @@ def run_serialization(gdas, tensorstore_specs):
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
global_shape=None, dtype=None):
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
t = ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT).result()
shape = t.shape if global_shape is None else global_shape
requires_padding = prod(shape) > prod(t.shape)
@@ -157,3 +168,198 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
[None] * len(tensorstore_specs) if dtypes is None else dtypes)
return await asyncio.gather(*future_gdas)
return asyncio.run(_run_deserializer())
no_lockfiles_exists = lambda paths: all(not tf.io.gfile.exists(f) for f in paths)
class _RetryWithTimeout:
def __init__(self, secs):
self.secs = secs
def __enter__(self):
self.timeout_after = time.time() + self.secs
return self
def __exit__(self, type, value, traceback):
pass
@property
def timed_out(self):
return time.time() > self.timeout_after
class GlobalAsyncCheckpointManager:
"""Responsible for serializing GDAs via TensorStore.
This class manages the state of an ongoing asynchronous checkpoint.
For example, say a checkpoint is taken on every step. The checkpoint for
step 1 is kicked off and, after some computation, the model reaches step 2,
but step 1's checkpoint hasn't finished committing to the storage layer yet.
Until it does, the checkpoint for step 2 has to be blocked. Keeping this
logic in a class makes it possible to track that state.
Example:
Below is a simplified training loop:
```
manager = GlobalAsyncCheckpointManager()
# Restore checkpoint if available or initialize the train_state from
# init_fn().
train_state = manager.deserialize(...)
while ...:
if step % num_steps_between_checkpoints == 0:
manager.serialize(train_state)
train_state = train_step(train_state, input)
# This is a non-blocking call.
manager.check_for_errors()
manager.serialize(train_state)
# Wait before the end of the program for the checkpoint to finish. This is a
# blocking call.
manager.wait_until_finished()
```
"""
def __init__(self, temp_checkpoint_dir, final_checkpoint_dir, timeout_secs=300):
self.temp_checkpoint_dir = temp_checkpoint_dir
self.final_checkpoint_dir = final_checkpoint_dir
self._timeout_secs = timeout_secs
self._commit_futures = None
self._thread = None
self._exception = None
def __del__(self):
if self._thread is not None and self._thread.is_alive():
logging.warning('Please add `.wait_until_finished()` in the main thread '
'before your program finishes; otherwise errors raised '
'during serialization may be lost if the '
'GlobalAsyncCheckpointManager is deleted before '
'serialization completes.')
def _thread_func(self):
try:
for future in self._commit_futures:
for f in future:
f.result()
logging.info('Commit to storage layer has completed.')
current_process = jax.process_index()
lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
all_lockfile_paths = [
os.path.join(lockfiles_dir, f'lockfile_{p}')
for p in range(jax.process_count())
]
current_process_lockfile = os.path.join(lockfiles_dir,
f'lockfile_{current_process}')
with _RetryWithTimeout(self._timeout_secs) as t:
while not tf.io.gfile.exists(current_process_lockfile):
if t.timed_out:
raise RuntimeError('Terminating after waiting for '
f'{self._timeout_secs} secs for lockfile to appear')
logging.info('Waiting for current process %s lockfile to appear.',
current_process)
time.sleep(60)
tf.io.gfile.remove(current_process_lockfile)
logging.info('Lockfile removed for process %s', current_process)
# The loop below exits only after every process has removed its lockfile,
# i.e. after all commits have finished.
if current_process == 0:
with _RetryWithTimeout(self._timeout_secs) as t:
while True:
if t.timed_out:
raise RuntimeError('Terminating after waiting for '
f'{self._timeout_secs} secs for '
'finishing the serialization.')
# Mark as done when no lockfiles exist.
if no_lockfiles_exists(all_lockfile_paths):
tf.io.gfile.remove(lockfiles_dir)
logging.info('Lockfiles directory removed.')
logging.info('Renaming %s to %s', self.temp_checkpoint_dir, self.final_checkpoint_dir)
tf.io.gfile.rename(self.temp_checkpoint_dir, self.final_checkpoint_dir)
logging.info('Finished saving GDA checkpoint to `%s`.', self.final_checkpoint_dir)
break
else:
logging.info('Thread sleeping for 60 seconds.')
time.sleep(60)
except Exception as e:
self._exception = e
def _start_commit_thread(self):
self._thread = threading.Thread(target=self._thread_func)
self._thread.start()
def _write_lockfiles(self):
lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
tf.io.gfile.mkdir(lockfiles_dir)
for p in range(jax.process_count()):
with tf.io.gfile.GFile(
os.path.join(lockfiles_dir, f'lockfile_{p}'), mode='w') as f:
f.write('File to track if all chunks have been written.')
if len(tf.io.gfile.listdir(lockfiles_dir)) != jax.process_count():
raise ValueError("Process 0 couldn't write all the lockfiles.")
logging.info('Lock files for all processes have been written by process 0.')
def check_for_errors(self):
if self._exception is not None:
# Clears self._exception so it is only raised once.
exception = self._exception
self._exception = None
raise exception # pylint: disable=raising-bad-type
def wait_until_finished(self):
if self._thread is not None:
self._thread.join()
self._thread = None
self.check_for_errors()
def serialize(self, gdas, tensorstore_specs):
"""Serializes GlobalDeviceArrays via TensorStore asynchronously.
TensorStore writes to a storage layer in 2 steps:
* Reading/copying from the source after which the source can be modified.
* Returns a copy future.
* Writing/committing to the storage layer.
* Returns a commit future.
In asynchronous mode, the serialization waits for the commit future to
finish in a separate thread allowing other computation to proceed.
Args:
gdas: GlobalDeviceArrays that should be serialized.
tensorstore_specs: TensorStore specs that are used to serialize GDAs.
"""
logging.info('Waiting for thread to finish serialization.')
self.wait_until_finished()
# Process 0 writes lock files for all processes.
if jax.process_index() == 0:
self._write_lockfiles()
self._commit_futures = [[] for _ in range(len(tensorstore_specs))]
async def _run_serializer():
future_writer = jax.tree_map(async_serialize, gdas,
tensorstore_specs, self._commit_futures)
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())
self._start_commit_thread()
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
global_shapes=None, dtypes=None):
return run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
global_shapes, dtypes)
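The two-step write described in `serialize`'s docstring above (a copy future followed by a commit future) is how TensorStore itself exposes writes. A minimal sketch of that behavior in isolation, using a small in-memory zarr store purely for illustration:

```
import numpy as np
import tensorstore as ts

# Illustrative in-memory spec, unrelated to the checkpoint specs used above.
t = ts.open({
    'driver': 'zarr',
    'kvstore': {'driver': 'memory'},
    'metadata': {'shape': [8], 'chunks': [8], 'dtype': '<i4'},
}, create=True, open=True).result()

write_future = t.write(np.arange(8, dtype=np.int32))
write_future.copy.result()    # the source has been read; it may now be modified or reused
write_future.commit.result()  # the data has been committed to the storage layer
```

In the asynchronous path added by this commit, only the copy futures are awaited inline; the commit futures are appended to `commit_future` and resolved later by the manager's background thread.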

jax/experimental/gda_serialization/serialization_test.py

@@ -14,7 +14,6 @@
"""Tests for serialization and deserialization of GDA."""
import pathlib
import unittest
from absl.testing import absltest
import jax
@@ -24,25 +23,15 @@ from jax.config import config
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization
from jax.experimental.maps import Mesh
import numpy as np
config.parse_flags_with_absl()
def create_global_mesh(mesh_shape, axis_names):
size = util.prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f'Test requires {size} local devices')
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
global_mesh = Mesh(mesh_devices, axis_names)
return global_mesh
class CheckpointTest(jtu.JaxTestCase):
def test_checkpointing(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
num = util.prod(global_input_shape)
@@ -66,7 +55,7 @@ class CheckpointTest(jtu.JaxTestCase):
# Third GDA
def cb3(index):
return np.array([])
global_mesh1d = create_global_mesh((8,), ('x',))
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
@@ -101,7 +90,7 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertEqual(m3.dtype, np.float32)
def test_checkpointing_with_bigger_shape(self):
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
num = util.prod(global_input_shape)
@@ -119,7 +108,7 @@ class CheckpointTest(jtu.JaxTestCase):
serialization.run_serialization([gda1], tspecs)
m1, = serialization.run_deserialization(
[create_global_mesh((4, 2), ('x', 'y'))],
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
[P('x', 'y')],
tspecs,
[(12, 2)],
@@ -140,6 +129,47 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
def test_async_checkpointing(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
num = util.prod(global_input_shape)
# First GDA
global_input_data1 = np.arange(num).reshape(global_input_shape)
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb1)
temp_ckpt_dir1 = pathlib.Path(self.create_tempdir('temp_first').full_path)
ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')
s_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(temp_ckpt_dir1)])
manager = serialization.GlobalAsyncCheckpointManager(
temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
manager.serialize([gda1], s_tspecs)
manager.wait_until_finished()
d_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir1)])
m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
np.array([[0], [2]]))
self.assertArraysEqual(m1.local_shards[1].data.to_py(),
np.array([[1], [3]]))
self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)
# Will throw a `file already exists` error when `tf.io.gfile.rename` is called.
# `wait_until_finished` will raise the error.
with self.assertRaises(Exception):
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
manager1 = serialization.GlobalAsyncCheckpointManager(
temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
manager1.serialize([gda1], s_tspecs)
manager1.wait_until_finished()
def test_spec_has_metadata(self):
spec = {
'a': {