Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Use JAX's distributed system for fully asynchronous checkpointing.
PiperOrigin-RevId: 449380175
This commit is contained in:
parent 720d09c7df
commit 0574eb2141
@@ -32,8 +32,14 @@ def initialize(coordinator_address: Optional[str] = None,
process_id: Optional[int] = None):
"""Initialize distributed system for topology discovery.

Currently, calling ``initialize`` sets up the multi-host GPU backend, and
is not required for CPU or TPU backends.
Currently, calling ``initialize`` sets up the multi-host GPU backend and Cloud
TPU backend.

If you are on GPU platform, you will have to provide the coordinator_address
and other args to the `initialize` API.

If you are on TPU platform, the coordinator_address and other args will be
auto detected but you have the option to provide it too.

Args:
coordinator_address: IP address and port of the coordinator. The choice of
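To make the docstring above concrete, here is a minimal sketch of how a multi-process GPU job might call this API. The coordinator address, the process count, and the `num_processes` keyword are illustrative assumptions rather than values taken from this diff.

```
import jax

# Hypothetical two-process GPU job; every process runs the same script.
jax.distributed.initialize(
    coordinator_address='192.168.0.1:1234',  # assumed address of process 0
    num_processes=2,                          # assumed total number of processes
    process_id=0)                             # each process passes its own id

# On Cloud TPU, per the updated docstring, the arguments are auto detected:
# jax.distributed.initialize()
```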
@@ -14,7 +14,6 @@
"""GlobalDeviceArray serialization and deserialization."""

import asyncio
import os
import re
import threading
import time
@@ -22,6 +21,7 @@ from typing import Callable
from absl import logging

import jax
from jax._src import distributed
from jax._src.util import prod
from jax.experimental import global_device_array as gda
from jax.experimental.maps import Mesh
@@ -29,10 +29,12 @@ import jax.numpy as jnp
import numpy as np
import tensorstore as ts
import tensorflow.compat.v2 as tf
# internal import


TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
_REMOVED_VALUE = 'Value removed'
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
_CHECKPOINT_FAILURE = 'checkpoint_write_failed'


async def create_async_gda_from_callback(
@@ -170,9 +172,6 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
return asyncio.run(_run_deserializer())


no_lockfiles_exists = lambda paths: all(not tf.io.gfile.exists(f) for f in paths)


class _RetryWithTimeout:
def __init__(self, secs):
self.secs = secs
@@ -189,6 +188,10 @@ class _RetryWithTimeout:
return time.time() > self.timeout_after


def _get_key(key: str):
return f'checkpoint_{key}'


class GlobalAsyncCheckpointManager:
"""Responsible for serializing GDAs via TensorStore.

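The surrounding hunks only show fragments of `_RetryWithTimeout`. Below is a minimal sketch of such a helper, assuming (as its uses in this diff suggest) that it is a context manager exposing a `timed_out` property; the `__enter__`/`__exit__` bodies are guesses, not the committed code.

```
import time


class _RetryWithTimeout:
  """Context manager that reports whether a polling loop has passed its deadline."""

  def __init__(self, secs):
    self.secs = secs

  def __enter__(self):
    # Record the absolute deadline when the block is entered.
    self.timeout_after = time.time() + self.secs
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    return False  # never suppress exceptions raised inside the block

  @property
  def timed_out(self):
    return time.time() > self.timeout_after
```

The hunks below use it as `with _RetryWithTimeout(self._timeout_secs) as t:` around polling loops that check `t.timed_out` on each iteration.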
@@ -205,6 +208,9 @@ class GlobalAsyncCheckpointManager:
Below is a simplified training loop:

```
# Call this at the start of your program.
jax.distributed.initialize()

manager = GlobalAsyncCheckpointManager()

# Restore checkpoint if available or initialize the train_state from
@@ -227,13 +233,21 @@ class GlobalAsyncCheckpointManager:
```
"""

def __init__(self, timeout_secs=600):
def __init__(self, timeout_secs=300):
self._timeout_secs = timeout_secs
self._timeout_in_ms = self._timeout_secs * 1000

self._commit_futures = None
self._thread = None
self._exception = None

if distributed.distributed_client is None:
raise ValueError('Please initialize the distributed system via '
'`jax.distributed.initialize()` at the start of your '
'program.')
self._client = distributed.distributed_client
self._final_ckpt_dir = None

def __del__(self):
if self._thread is not None and self._thread.is_alive():
logging.warning('Please add `.wait_until_finished()` in the main thread '
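The constructor above stores `distributed.distributed_client`, whose key-value store replaces the lockfiles as the coordination mechanism. A minimal sketch of the two calls the rest of the diff relies on follows; the key name and timeout below are illustrative, but both methods appear verbatim in the hunks after this.

```
from jax._src import distributed

# Requires jax.distributed.initialize() to have been called on every process.
client = distributed.distributed_client

# Any process can publish a string value under a string key...
client.key_value_set('checkpoint_0', 'Lock value for process 0')

# ...and any process can block until a value is available for that key,
# or until the timeout (in milliseconds) expires.
value = client.blocking_key_value_get('checkpoint_0', 300 * 1000)
```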
@@ -250,48 +264,47 @@ class GlobalAsyncCheckpointManager:
logging.info('Commit to storage layer has completed.')

current_process = jax.process_index()
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
all_lockfile_paths = [
os.path.join(lockfiles_dir, f'lockfile_{p}')
for p in range(jax.process_count())
]
current_process_lockfile = os.path.join(lockfiles_dir,
f'lockfile_{current_process}')

# TODO(yashkatariya): Add a method to the distributed system to wait for
# the key's value to change.
# Value already exists -- wait until value is NOT _REMOVED_VALUE.
with _RetryWithTimeout(self._timeout_secs) as t:
while not tf.io.gfile.exists(current_process_lockfile):
while self._client.blocking_key_value_get(
_get_key(str(current_process)), self._timeout_in_ms) == _REMOVED_VALUE:
if t.timed_out:
raise RuntimeError('Terminating after waiting for '
f'{self._timeout_secs} secs for lockfile to appear')
logging.info('Waiting for current process %s lockfile to appear.',
raise TimeoutError('Terminating after waiting for '
f'{self._timeout_secs} secs for lock value to appear.')
logging.info('Waiting for current process %s lock value to appear.',
current_process)
time.sleep(60)

tf.io.gfile.remove(current_process_lockfile)
logging.info('Lockfile removed for process %s', current_process)
self._client.key_value_set(_get_key(str(current_process)), _REMOVED_VALUE)
logging.info('Lock value removed for process %s', current_process)

# This while loop will not trigger until all commits have finished.
if current_process == 0:
with _RetryWithTimeout(self._timeout_secs) as t:
while True:
if t.timed_out:
raise RuntimeError('Terminating after waiting for '
raise TimeoutError('Terminating after waiting for '
f'{self._timeout_secs} secs for '
'finishing the serialization.')
# Mark as done when no lockfiles exist.
if no_lockfiles_exists(all_lockfile_paths):
tf.io.gfile.rmtree(lockfiles_dir)
logging.info('Lockfiles directory removed.')

# Mark as done when no lock values exist.
if all(
self._client.blocking_key_value_get(
_get_key(str(p)), self._timeout_in_ms) == _REMOVED_VALUE
for p in range(jax.process_count())):
logging.info('Renaming %s to %s', temp_checkpoint_dir, final_checkpoint_dir)
tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
logging.info('Finished saving GDA checkpoint to `%s`.', final_checkpoint_dir)
self._client.key_value_set(_get_key(self._final_ckpt_dir), _CHECKPOINT_SUCCESS)
break
else:
logging.info('Thread sleeping for 60 seconds.')
time.sleep(60)

except Exception as e:
self._client.key_value_set(_get_key(self._final_ckpt_dir), _CHECKPOINT_FAILURE)
self._exception = e

def _start_commit_thread(self, temp_checkpoint_dir, final_checkpoint_dir):
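Taken together, the hunk above swaps the lockfile dance for a key-value handshake. A condensed restatement of the new flow follows, as a sketch that borrows `_get_key`, `_REMOVED_VALUE`, and `_CHECKPOINT_SUCCESS` from this module; it omits the timeout handling and is not the committed code.

```
def _commit_handshake(client, timeout_in_ms, temp_dir, final_dir):
  """Sketch: what each process does after its local TensorStore commit finishes."""
  process = jax.process_index()

  # Wait until process 0 has published a fresh lock value for this process,
  # then overwrite it with the sentinel to signal that our commit is done.
  while client.blocking_key_value_get(
      _get_key(str(process)), timeout_in_ms) == _REMOVED_VALUE:
    time.sleep(60)
  client.key_value_set(_get_key(str(process)), _REMOVED_VALUE)

  # Process 0 polls until every process has signalled completion, then renames
  # the temporary directory and records overall success under the final path.
  if process == 0:
    while not all(
        client.blocking_key_value_get(_get_key(str(p)), timeout_in_ms) == _REMOVED_VALUE
        for p in range(jax.process_count())):
      time.sleep(60)
    tf.io.gfile.rename(temp_dir, final_dir)
    client.key_value_set(_get_key(final_dir), _CHECKPOINT_SUCCESS)
```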
@@ -300,21 +313,16 @@ class GlobalAsyncCheckpointManager:
args=(temp_checkpoint_dir, final_checkpoint_dir))
self._thread.start()

def _write_lockfiles(self, temp_checkpoint_dir):
lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
tf.io.gfile.mkdir(lockfiles_dir)

def _write_lock_values(self):
write_count = 0
for p in range(jax.process_count()):
with tf.io.gfile.GFile(
os.path.join(lockfiles_dir, f'lockfile_{p}'), mode='w') as f:
f.write('File to track if all chunks have been written.')
# TODO(yashkatariya): Make the key value store writes safe if checkpoint
# managers are created concurrently.
self._client.key_value_set(_get_key(str(p)), f'Lock value for process {str(p)}')
write_count += 1

if write_count != jax.process_count():
raise ValueError("Process 0 couldn't write all the lockfiles.")

logging.info('Lock files for all processes have been written by process 0.')
raise ValueError("Process 0 couldn't write all the lock values.")
logging.info('Lock values for all processes have been written by process 0.')

def check_for_errors(self):
if self._exception is not None:
@@ -330,6 +338,22 @@ class GlobalAsyncCheckpointManager:
self.check_for_errors()

if self._final_ckpt_dir is not None:
with _RetryWithTimeout(self._timeout_secs) as t:
if t.timed_out:
raise TimeoutError("Process 0 didn't finish the rename from "
"temporary to final checkpoint directory")
val = self._client.blocking_key_value_get(
_get_key(self._final_ckpt_dir), self._timeout_in_ms)
while val != _CHECKPOINT_SUCCESS:
if val == _CHECKPOINT_FAILURE:
raise ValueError(
'Checkpoint write failed. Please check the error message on '
'all hosts to see the real failure.')
logging.info('Waiting for final checkpoint directory to exist on '
'process %s', jax.process_index())
time.sleep(15)

def serialize(self, gdas, tensorstore_specs, *, temp_checkpoint_dir,
final_checkpoint_dir):
"""Serializes GlobalDeviceArrays via TensorStore asynchronously.
@@ -356,7 +380,7 @@ class GlobalAsyncCheckpointManager:

# Process 0 writes lock files for all processes.
if jax.process_index() == 0:
self._write_lockfiles(temp_checkpoint_dir)
self._write_lock_values()

self._commit_futures = [[] for _ in range(len(tensorstore_specs))]

@@ -366,6 +390,9 @@ class GlobalAsyncCheckpointManager:
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())

# Used in wait_until_finished to check on process != 0, if the checkpoint
# has finished writing.
self._final_ckpt_dir = final_checkpoint_dir
self._start_commit_thread(temp_checkpoint_dir, final_checkpoint_dir)

def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
@@ -129,46 +129,6 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])

def test_async_checkpointing(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
num = util.prod(global_input_shape)

# First GDA
global_input_data1 = np.arange(num).reshape(global_input_shape)
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb1)
temp_ckpt_dir1 = pathlib.Path(self.create_tempdir('temp_first').full_path)
ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')

s_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(temp_ckpt_dir1)])

manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
final_checkpoint_dir=ckpt_dir1)
manager.wait_until_finished()

d_tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir1)])
m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
np.array([[0], [2]]))
self.assertArraysEqual(m1.local_shards[1].data.to_py(),
np.array([[1], [3]]))
self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)

# Will throw `file already exists` error when `tf.io.gfile.rename`.
# `wait_until_finished` will raise the error.
with self.assertRaises(Exception):
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
manager1 = serialization.GlobalAsyncCheckpointManager()
manager1.serialize([gda1], s_tspecs, temp_checkpoint_dir=temp_ckpt_dir1,
final_checkpoint_dir=ckpt_dir1)
manager1.wait_until_finished()

def test_spec_has_metadata(self):
spec = {
'a': {