Take temp_checkpoint_dir and final_checkpoint_dir as the arguments to serialize instead of the __init__. THis is because this manager will be defined at the top where the directories may not yet be known.

PiperOrigin-RevId: 446104174
This commit is contained in:
Yash Katariya 2022-05-02 21:22:11 -07:00 committed by jax authors
parent 888e5c6958
commit de7a872e1b
2 changed files with 24 additions and 20 deletions

View File

@ -225,9 +225,7 @@ class GlobalAsyncCheckpointManager:
```
"""
def __init__(self, temp_checkpoint_dir, final_checkpoint_dir, timeout_secs=300):
self.temp_checkpoint_dir = temp_checkpoint_dir
self.final_checkpoint_dir = final_checkpoint_dir
def __init__(self, timeout_secs=300):
self._timeout_secs = timeout_secs
self._commit_futures = None
@ -242,7 +240,7 @@ class GlobalAsyncCheckpointManager:
'GlobalAsyncCheckpointManager is deleted before '
'serialization is completed.')
def _thread_func(self):
def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
try:
for future in self._commit_futures:
for f in future:
@ -250,7 +248,7 @@ class GlobalAsyncCheckpointManager:
logging.info('Commit to storage layer has completed.')
current_process = jax.process_index()
lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
all_lockfile_paths = [
os.path.join(lockfiles_dir, f'lockfile_{p}')
for p in range(jax.process_count())
@ -283,9 +281,9 @@ class GlobalAsyncCheckpointManager:
tf.io.gfile.remove(lockfiles_dir)
logging.info('Lockfiles directory removed.')
logging.info('Renaming %s to %s', self.temp_checkpoint_dir, self.final_checkpoint_dir)
tf.io.gfile.rename(self.temp_checkpoint_dir, self.final_checkpoint_dir)
logging.info('Finished saving GDA checkpoint to `%s`.', self.final_checkpoint_dir)
logging.info('Renaming %s to %s', temp_checkpoint_dir, final_checkpoint_dir)
tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
logging.info('Finished saving GDA checkpoint to `%s`.', final_checkpoint_dir)
break
else:
logging.info('Thread sleeping for 60 seconds.')
@ -294,12 +292,14 @@ class GlobalAsyncCheckpointManager:
except Exception as e:
self._exception = e
def _start_commit_thread(self):
self._thread = threading.Thread(target=self._thread_func)
def _start_commit_thread(self, temp_checkpoint_dir, final_checkpoint_dir):
self._thread = threading.Thread(
target=self._thread_func,
args=(temp_checkpoint_dir, final_checkpoint_dir))
self._thread.start()
def _write_lockfiles(self):
lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
def _write_lockfiles(self, temp_checkpoint_dir):
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
tf.io.gfile.mkdir(lockfiles_dir)
for p in range(jax.process_count()):
@ -326,7 +326,8 @@ class GlobalAsyncCheckpointManager:
self.check_for_errors()
def serialize(self, gdas, tensorstore_specs):
def serialize(self, gdas, tensorstore_specs, *, temp_checkpoint_dir,
final_checkpoint_dir):
"""Serializes GlobalDeviceArrays via TensorStore asynchronously.
TensorStore writes to a storage layer in 2 steps:
@ -341,13 +342,17 @@ class GlobalAsyncCheckpointManager:
Args:
gdas: GlobalDeviceArrays that should be serialized.
tensorstore_specs: TensorStore specs that are used to serialize GDAs.
temp_checkpoint_dir: Temporary checkpoint directory where the checkpoints
will be written.
final_checkpoint_dir: Final checkpoint directory where the checkpoints
will be moved from `temp_checkpoint_dir`.
"""
logging.info('Waiting for thread to finish serialization.')
self.wait_until_finished()
# Process 0 writes lock files for all processes.
if jax.process_index() == 0:
self._write_lockfiles()
self._write_lockfiles(temp_checkpoint_dir)
self._commit_futures = [[] for _ in range(len(tensorstore_specs))]
@ -357,7 +362,7 @@ class GlobalAsyncCheckpointManager:
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())
self._start_commit_thread()
self._start_commit_thread(temp_checkpoint_dir, final_checkpoint_dir)
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
global_shapes=None, dtypes=None):

View File

@ -146,9 +146,9 @@ class CheckpointTest(jtu.JaxTestCase):
s_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(temp_ckpt_dir1)])
manager = serialization.GlobalAsyncCheckpointManager(
temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
manager.serialize([gda1], s_tspecs)
manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
final_checkpoint_dir=ckpt_dir1)
manager.wait_until_finished()
d_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir1)])
@ -164,8 +164,7 @@ class CheckpointTest(jtu.JaxTestCase):
# `wait_until_finished` will raise the error.
with self.assertRaises(Exception):
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
manager1 = serialization.GlobalAsyncCheckpointManager(
temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
manager1 = serialization.GlobalAsyncCheckpointManager()
manager1.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
final_checkpoint_dir=ckpt_dir1)
manager1.wait_until_finished()