mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add fully asynchronous checkpointing. This allows training to continue while the checkpoint is being committed.
PiperOrigin-RevId: 446083057
This commit is contained in:
parent
939233e769
commit
b7293d5683
@ -14,8 +14,12 @@
|
||||
"""GlobalDeviceArray serialization and deserialization."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
from jax._src.util import prod
|
||||
@ -24,6 +28,11 @@ from jax.experimental.maps import Mesh
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tensorstore as ts
|
||||
import tensorflow.compat.v2 as tf
|
||||
# internal import
|
||||
|
||||
|
||||
TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
|
||||
|
||||
|
||||
async def create_async_gda_from_callback(
|
||||
@ -88,26 +97,28 @@ def get_tensorstore_spec(ckpt_path: str):
|
||||
return spec
|
||||
|
||||
|
||||
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
                          commit_future=None):
  """Writes the local shards of a GlobalDeviceArray through TensorStore.

  Only shards whose `replica_id` is 0 are written.

  Args:
    gda_inp: The GlobalDeviceArray whose local shards are written.
    tensorstore_spec: TensorStore spec (dict) describing the destination. When
      `_spec_has_metadata` reports no metadata, a 'metadata' entry is added to
      this dict in place.
    commit_future: Optional list. When given, the commit future of every shard
      write is appended to it and only the copy future is awaited, so the
      caller is responsible for waiting on the collected commit futures. When
      None, every write is awaited until it has been committed.

  Returns:
    The result of `asyncio.gather` over the per-shard write coroutines.
  """
  # 'metadata' may not be present at the top level (for example, if we are using
  # a 'cast' driver).
  if not _spec_has_metadata(tensorstore_spec):
    tensorstore_spec['metadata'] = _get_metadata(gda_inp)

  # Use the shared module-level context rather than building one per call.
  t = await ts.open(
      ts.Spec(tensorstore_spec), create=True, open=True, context=TS_CONTEXT)

  async def _write_array(shard):
    if shard.replica_id == 0:
      write_future = t[shard.index].write(shard.data)
      if commit_future is not None:
        assert isinstance(commit_future, list)
        # Hand the commit future to the caller and wait only for the copy.
        commit_future.append(write_future.commit)
        await write_future.copy
      else:
        await write_future.commit

  future_write_state = jax.tree_util.tree_map(_write_array,
                                              gda_inp.local_shards)
  return await asyncio.gather(*future_write_state)
|
||||
|
||||
|
||||
@ -120,7 +131,7 @@ def run_serialization(gdas, tensorstore_specs):
|
||||
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
|
||||
global_shape=None, dtype=None):
|
||||
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
|
||||
t = ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT).result()
|
||||
shape = t.shape if global_shape is None else global_shape
|
||||
requires_padding = prod(shape) > prod(t.shape)
|
||||
|
||||
@ -157,3 +168,198 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
|
||||
[None] * len(tensorstore_specs) if dtypes is None else dtypes)
|
||||
return await asyncio.gather(*future_gdas)
|
||||
return asyncio.run(_run_deserializer())
|
||||
|
||||
|
||||
def no_lockfiles_exists(paths):
  """Returns True when `tf.io.gfile.exists` is False for every path in `paths`.

  Args:
    paths: Iterable of file paths to check.

  Returns:
    True if none of the paths exists (also True for an empty iterable),
    False as soon as one existing path is found.
  """
  return not any(tf.io.gfile.exists(f) for f in paths)
|
||||
|
||||
|
||||
class _RetryWithTimeout:
|
||||
def __init__(self, secs):
|
||||
self.secs = secs
|
||||
|
||||
def __enter__(self):
|
||||
self.timeout_after = time.time() + self.secs
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
pass
|
||||
|
||||
@property
|
||||
def timed_out(self):
|
||||
return time.time() > self.timeout_after
|
||||
|
||||
|
||||
class GlobalAsyncCheckpointManager:
  """Responsible for serializing GDAs via TensorStore.

  This class manages the state of an ongoing asynchronous checkpoint.

  For example, say a checkpoint happens on every step. If you checkpoint on
  step 1 and after some computation the model is on step 2, step 1's
  checkpoint may not have finished committing to the storage layer yet. Until
  that is finished, the checkpoint for step 2 needs to be blocked. Keeping
  this state in a class makes that possible.

  Example:

  Below is a simplified training loop:

  ```
  manager = GlobalAsyncCheckpointManager(
      temp_checkpoint_dir=..., final_checkpoint_dir=...)

  # Restore checkpoint if available or initialize the train_state from
  # init_fn().
  train_state = manager.deserialize(...)

  while ...:
    if step % num_steps_between_checkpoints == 0:
      manager.serialize(train_state, tensorstore_specs)
      train_state = train_step(train_state, input)
      # This is a non-blocking call.
      manager.check_for_errors()

  manager.serialize(train_state, tensorstore_specs)
  # Wait before the end of the program for the checkpoint to finish. This is a
  # blocking call.
  manager.wait_until_finished()
  ```
  """

  def __init__(self, temp_checkpoint_dir, final_checkpoint_dir, timeout_secs=300):
    """Initializes the manager.

    Args:
      temp_checkpoint_dir: Directory that holds the lockfiles while a
        checkpoint is in flight; it is renamed to `final_checkpoint_dir` by
        process 0 once no lockfiles remain.
      final_checkpoint_dir: Destination of the rename of `temp_checkpoint_dir`.
      timeout_secs: Time budget, in seconds, for each of the two polling loops
        run by the background commit thread.
    """
    self.temp_checkpoint_dir = temp_checkpoint_dir
    self.final_checkpoint_dir = final_checkpoint_dir
    self._timeout_secs = timeout_secs

    # One list of commit futures per TensorStore spec; set by `serialize`.
    self._commit_futures = None
    # Background thread waiting on the commit futures, or None.
    self._thread = None
    # Exception captured in the background thread, or None.
    self._exception = None

  def __del__(self):
    # `__del__` also runs when `__init__` never completed, in which case
    # `_thread` may not exist; do not raise from the finalizer in that case.
    thread = getattr(self, '_thread', None)
    if thread is not None and thread.is_alive():
      logging.warning('Please add `.wait_until_finished()` in the main thread '
                      'before your program finishes because there is a '
                      'possibility of losing errors raised if the '
                      'GlobalAsyncCheckpointManager is deleted before '
                      'serialization is completed.')

  def _thread_func(self):
    """Body of the commit thread.

    Waits for all commit futures, removes this process's lockfile and, on
    process 0, waits until no lockfiles remain before renaming the temporary
    checkpoint directory to the final one. Any exception is stored in
    `self._exception` instead of propagating, so that `check_for_errors` can
    re-raise it in the calling thread.
    """
    try:
      for future in self._commit_futures:
        for f in future:
          f.result()
      logging.info('Commit to storage layer has completed.')

      current_process = jax.process_index()
      lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
      all_lockfile_paths = [
          os.path.join(lockfiles_dir, f'lockfile_{p}')
          for p in range(jax.process_count())
      ]
      current_process_lockfile = os.path.join(lockfiles_dir,
                                              f'lockfile_{current_process}')

      # Process 0 writes every lockfile (see `_write_lockfiles`), so this
      # process polls until its own lockfile is visible before removing it.
      with _RetryWithTimeout(self._timeout_secs) as t:
        while not tf.io.gfile.exists(current_process_lockfile):
          if t.timed_out:
            raise RuntimeError('Terminating after waiting for '
                               f'{self._timeout_secs} secs for lockfile to appear')
          logging.info('Waiting for current process %s lockfile to appear.',
                       current_process)
          time.sleep(60)

      tf.io.gfile.remove(current_process_lockfile)
      logging.info('Lockfile removed for process %s', current_process)

      # This while loop will not trigger until all commits have finished.
      if current_process == 0:
        with _RetryWithTimeout(self._timeout_secs) as t:
          while True:
            if t.timed_out:
              raise RuntimeError('Terminating after waiting for '
                                 f'{self._timeout_secs} secs for '
                                 'finishing the serialization.')
            # Mark as done when no lockfiles exist.
            if no_lockfiles_exists(all_lockfile_paths):
              tf.io.gfile.remove(lockfiles_dir)
              logging.info('Lockfiles directory removed.')

              logging.info('Renaming %s to %s', self.temp_checkpoint_dir,
                           self.final_checkpoint_dir)
              tf.io.gfile.rename(self.temp_checkpoint_dir,
                                 self.final_checkpoint_dir)
              logging.info('Finished saving GDA checkpoint to `%s`.',
                           self.final_checkpoint_dir)
              break
            else:
              logging.info('Thread sleeping for 60 seconds.')
              time.sleep(60)

    except Exception as e:
      # Surfaced later by `check_for_errors` / `wait_until_finished`.
      self._exception = e

  def _start_commit_thread(self):
    """Starts the background thread that runs `_thread_func`."""
    self._thread = threading.Thread(target=self._thread_func)
    self._thread.start()

  def _write_lockfiles(self):
    """Creates one lockfile per process under `<temp_checkpoint_dir>/lockfiles`.

    Raises:
      ValueError: If the number of entries in the lockfiles directory differs
        from `jax.process_count()` after writing.
    """
    lockfiles_dir = os.path.join(self.temp_checkpoint_dir, 'lockfiles')
    tf.io.gfile.mkdir(lockfiles_dir)

    for p in range(jax.process_count()):
      with tf.io.gfile.GFile(
          os.path.join(lockfiles_dir, f'lockfile_{p}'), mode='w') as f:
        f.write('File to track if all chunks have been written.')

    if len(tf.io.gfile.listdir(lockfiles_dir)) != jax.process_count():
      raise ValueError("Process 0 couldn't write all the lockfiles.")

    logging.info('Lock files for all processes have been written by process 0.')

  def check_for_errors(self):
    """Re-raises an exception captured by the commit thread, if any.

    Does not wait for the commit thread.

    Raises:
      Exception: Whatever exception `_thread_func` stored.
    """
    if self._exception is not None:
      # Clears self._exception so it is only raised once.
      exception = self._exception
      self._exception = None
      raise exception  # pylint: disable=raising-bad-type

  def wait_until_finished(self):
    """Blocks until the commit thread (if any) ends, then re-raises its error."""
    if self._thread is not None:
      self._thread.join()
      self._thread = None

    self.check_for_errors()

  def serialize(self, gdas, tensorstore_specs):
    """Serializes GlobalDeviceArrays via TensorStore asynchronously.

    TensorStore writes to a storage layer in 2 steps:
    *  Reading/copying from the source after which the source can be modified.
         * Returns a copy future.
    *  Writing/committing to the storage layer.
         * Returns a commit future.

    In asynchronous mode, the serialization waits for the commit future to
    finish in a separate thread allowing other computation to proceed.

    This call first waits for a previously started commit thread to finish
    and re-raises any error it recorded.

    Args:
      gdas: GlobalDeviceArrays that should be serialized.
      tensorstore_specs: TensorStore specs that are used to serialize GDAs.
    """
    logging.info('Waiting for thread to finish serialization.')
    self.wait_until_finished()

    # Process 0 writes lock files for all processes.
    if jax.process_index() == 0:
      self._write_lockfiles()

    self._commit_futures = [[] for _ in range(len(tensorstore_specs))]

    async def _run_serializer():
      future_writer = jax.tree_map(async_serialize, gdas,
                                   tensorstore_specs, self._commit_futures)
      return await asyncio.gather(*future_writer)
    asyncio.run(_run_serializer())

    self._start_commit_thread()

  def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
                  global_shapes=None, dtypes=None):
    """Forwards all arguments to `run_deserialization` and returns its result."""
    return run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
                               global_shapes, dtypes)
|
||||
|
@ -14,7 +14,6 @@
|
||||
"""Tests for serialization and deserialization of GDA."""
|
||||
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
@ -24,25 +23,15 @@ from jax.config import config
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
from jax.experimental.maps import Mesh
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def create_global_mesh(mesh_shape, axis_names):
|
||||
size = util.prod(mesh_shape)
|
||||
if len(jax.devices()) < size:
|
||||
raise unittest.SkipTest(f'Test requires {size} local devices')
|
||||
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
|
||||
global_mesh = Mesh(mesh_devices, axis_names)
|
||||
return global_mesh
|
||||
|
||||
|
||||
class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
def test_checkpointing(self):
|
||||
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
num = util.prod(global_input_shape)
|
||||
@ -66,7 +55,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
# Third GDA
|
||||
def cb3(index):
|
||||
return np.array([])
|
||||
global_mesh1d = create_global_mesh((8,), ('x',))
|
||||
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
|
||||
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
|
||||
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
|
||||
|
||||
@ -101,7 +90,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_with_bigger_shape(self):
|
||||
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
@ -119,7 +108,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[create_global_mesh((4, 2), ('x', 'y'))],
|
||||
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
|
||||
[P('x', 'y')],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
@ -140,6 +129,47 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
|
||||
|
||||
def test_async_checkpointing(self):
  """Round-trips one GDA through GlobalAsyncCheckpointManager.

  Also checks that a second checkpoint attempt using the same temporary
  directory and an already existing final directory surfaces an error.
  """
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  mesh_axes = P('x', 'y')
  num = util.prod(global_input_shape)

  # First GDA
  global_input_data1 = np.arange(num).reshape(global_input_shape)
  def cb1(index):
    return global_input_data1[index]
  gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                         mesh_axes, cb1)
  temp_ckpt_dir1 = pathlib.Path(self.create_tempdir('temp_first').full_path)
  ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')

  s_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(temp_ckpt_dir1)])

  manager = serialization.GlobalAsyncCheckpointManager(
      temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
  manager.serialize([gda1], s_tspecs)
  manager.wait_until_finished()

  d_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir1)])
  m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
  self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                         np.array([[0], [2]]))
  self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                         np.array([[1], [3]]))
  self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
  self.assertEqual(m1.dtype, np.int32)

  # Will throw `file already exists` error when `tf.io.gfile.rename`.
  # `wait_until_finished` will raise the error.
  # NOTE(review): the directories are configured on the manager; `serialize`
  # takes no directory keyword arguments. Passing them raised TypeError, which
  # satisfied `assertRaises(Exception)` without running the checkpoint path.
  ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
  manager1 = serialization.GlobalAsyncCheckpointManager(
      temp_checkpoint_dir=temp_ckpt_dir1, final_checkpoint_dir=ckpt_dir1)
  with self.assertRaises(Exception):
    manager1.serialize([gda1], s_tspecs)
    manager1.wait_until_finished()
|
||||
|
||||
def test_spec_has_metadata(self):
|
||||
spec = {
|
||||
'a': {
|
||||
|
Loading…
x
Reference in New Issue
Block a user