mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Take temp_checkpoint_dir and final_checkpoint_dir as arguments to serialize instead of __init__. This is because the manager will be defined at the top, where the directories may not yet be known.
PiperOrigin-RevId: 446104174
This commit is contained in:
parent
888e5c6958
commit
de7a872e1b
@ -225,9 +225,7 @@ class GlobalAsyncCheckpointManager:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, temp_checkpoint_dir, final_checkpoint_dir, timeout_secs=300):
|
||||
self.temp_checkpoint_dir = temp_checkpoint_dir
|
||||
self.final_checkpoint_dir = final_checkpoint_dir
|
||||
def __init__(self, timeout_secs=300):
|
||||
self._timeout_secs = timeout_secs
|
||||
|
||||
self._commit_futures = None
|
||||
@ -242,7 +240,7 @@ class GlobalAsyncCheckpointManager:
|
||||
'GlobalAsyncCheckpointManager is deleted before '
|
||||
'serialization is completed.')
|
||||
|
||||
def _thread_func(self):
|
||||
def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
|
||||
try:
|
||||
for future in self._commit_futures:
|
||||
for f in future:
|
||||
@ -250,7 +248,7 @@ class GlobalAsyncCheckpointManager:
|
||||
logging.info('Commit to storage layer has completed.')
|
||||
|
||||
current_process = jax.process_index()
|
||||
lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
|
||||
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
|
||||
all_lockfile_paths = [
|
||||
os.path.join(lockfiles_dir, f'lockfile_{p}')
|
||||
for p in range(jax.process_count())
|
||||
@ -283,9 +281,9 @@ class GlobalAsyncCheckpointManager:
|
||||
tf.io.gfile.remove(lockfiles_dir)
|
||||
logging.info('Lockfiles directory removed.')
|
||||
|
||||
logging.info('Renaming %s to %s', self.temp_checkpoint_dir, self.final_checkpoint_dir)
|
||||
tf.io.gfile.rename(self.temp_checkpoint_dir, self.final_checkpoint_dir)
|
||||
logging.info('Finished saving GDA checkpoint to `%s`.', self.final_checkpoint_dir)
|
||||
logging.info('Renaming %s to %s', temp_checkpoint_dir, final_checkpoint_dir)
|
||||
tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
|
||||
logging.info('Finished saving GDA checkpoint to `%s`.', final_checkpoint_dir)
|
||||
break
|
||||
else:
|
||||
logging.info('Thread sleeping for 60 seconds.')
|
||||
@ -294,12 +292,14 @@ class GlobalAsyncCheckpointManager:
|
||||
except Exception as e:
|
||||
self._exception = e
|
||||
|
||||
def _start_commit_thread(self):
|
||||
self._thread = threading.Thread(target=self._thread_func)
|
||||
def _start_commit_thread(self, temp_checkpoint_dir, final_checkpoint_dir):
|
||||
self._thread = threading.Thread(
|
||||
target=self._thread_func,
|
||||
args=(temp_checkpoint_dir, final_checkpoint_dir))
|
||||
self._thread.start()
|
||||
|
||||
def _write_lockfiles(self):
|
||||
lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
|
||||
def _write_lockfiles(self, temp_checkpoint_dir):
|
||||
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
|
||||
tf.io.gfile.mkdir(lockfiles_dir)
|
||||
|
||||
for p in range(jax.process_count()):
|
||||
@ -326,7 +326,8 @@ class GlobalAsyncCheckpointManager:
|
||||
|
||||
self.check_for_errors()
|
||||
|
||||
def serialize(self, gdas, tensorstore_specs):
|
||||
def serialize(self, gdas, tensorstore_specs, *, temp_checkpoint_dir,
|
||||
final_checkpoint_dir):
|
||||
"""Serializes GlobalDeviceArrays via TensorStore asynchronously.
|
||||
|
||||
TensorStore writes to a storage layer in 2 steps:
|
||||
@ -341,13 +342,17 @@ class GlobalAsyncCheckpointManager:
|
||||
Args:
|
||||
gdas: GlobalDeviceArrays that should be serialized.
|
||||
tensorstore_specs: TensorStore specs that are used to serialize GDAs.
|
||||
temp_checkpoint_dir: Temporary checkpoint directory where the checkpoints
|
||||
will be written.
|
||||
final_checkpoint_dir: Final checkpoint directory where the checkpoints
|
||||
will be moved from `temp_checkpoint_dir`.
|
||||
"""
|
||||
logging.info('Waiting for thread to finish serialization.')
|
||||
self.wait_until_finished()
|
||||
|
||||
# Process 0 writes lock files for all processes.
|
||||
if jax.process_index() == 0:
|
||||
self._write_lockfiles()
|
||||
self._write_lockfiles(temp_checkpoint_dir)
|
||||
|
||||
self._commit_futures = [[] for _ in range(len(tensorstore_specs))]
|
||||
|
||||
@ -357,7 +362,7 @@ class GlobalAsyncCheckpointManager:
|
||||
return await asyncio.gather(*future_writer)
|
||||
asyncio.run(_run_serializer())
|
||||
|
||||
self._start_commit_thread()
|
||||
self._start_commit_thread(temp_checkpoint_dir, final_checkpoint_dir)
|
||||
|
||||
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
|
||||
global_shapes=None, dtypes=None):
|
||||
|
@ -146,9 +146,9 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
s_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(temp_ckpt_dir1)])
|
||||
|
||||
manager = serialization.GlobalAsyncCheckpointManager(
|
||||
temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
|
||||
manager.serialize([gda1], s_tspecs)
|
||||
manager = serialization.GlobalAsyncCheckpointManager()
|
||||
manager.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
|
||||
final_checkpoint_dir=ckpt_dir1)
|
||||
manager.wait_until_finished()
|
||||
|
||||
d_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir1)])
|
||||
@ -164,8 +164,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
# `wait_until_finished` will raise the error.
|
||||
with self.assertRaises(Exception):
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
manager1 = serialization.GlobalAsyncCheckpointManager(
|
||||
temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
|
||||
manager1 = serialization.GlobalAsyncCheckpointManager()
|
||||
manager1.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
|
||||
final_checkpoint_dir=ckpt_dir1)
|
||||
manager1.wait_until_finished()
|
||||
|
Loading…
x
Reference in New Issue
Block a user