Rename gda_serialization to array_serialization, but keep gda_serialization around until the rename is included in a JAX release, so that OSS projects can migrate to array_serialization.
PiperOrigin-RevId: 521055760
This commit is contained in:
parent 0b31e8b822
commit d27a80dbfa
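For downstream users, the migration is the one-line import change exercised by the test diff further down:

```python
# Old import (keeps working via the gda_serialization shim until it is removed):
from jax.experimental.gda_serialization import serialization

# New import:
from jax.experimental.array_serialization import serialization
```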
2 .github/workflows/ci-build.yaml (vendored)
@@ -160,4 +160,4 @@ jobs:
JAX_ARRAY: 1
run: |
pytest -n auto --tb=short docs
-pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization --ignore=jax/collect_profile.py
+pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/experimental/gda_serialization --ignore=jax/collect_profile.py
@@ -9,6 +9,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.9

* Deprecations
+  * `jax.experimental.gda_serialization` is deprecated and has been renamed to
+    `jax.experimental.array_serialization`.
+    Please change your imports to use `jax.experimental.array_serialization`.
  * The `in_axis_resources` and `out_axis_resources` arguments of pjit have been
    deprecated. Please use `in_shardings` and `out_shardings` respectively.
  * The function `jax.numpy.msort` has been removed. It has been deprecated since
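The pjit entry above is likewise a pure keyword rename; a minimal sketch (sharding values elided as `...`):

```python
from jax.experimental.pjit import pjit

# Deprecated spelling:
# f_jitted = pjit(f, in_axis_resources=..., out_axis_resources=...)

# New spelling:
# f_jitted = pjit(f, in_shardings=..., out_shardings=...)
```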
13 jax/experimental/array_serialization/__init__.py (new file)
@@ -0,0 +1,13 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
502 jax/experimental/array_serialization/serialization.py (new file)
@@ -0,0 +1,502 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GlobalDeviceArray serialization and deserialization."""

import abc
import asyncio
import itertools
import logging
from functools import partial
import os
import re
import threading
from typing import Callable, Sequence, Optional, Dict, Any

import jax
from jax._src import distributed
from jax._src import array
from jax._src import sharding
from jax._src import sharding_impls
from jax._src import typing
import jax.numpy as jnp
import numpy as np
import tensorstore as ts


TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
_REMOVED_VALUE = 'Value removed'
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
_module_unique_count = itertools.count()

logger = logging.getLogger(__name__)


async def create_async_array_from_callback(
    global_shape: array.Shape,
    inp_sharding: sharding_impls.XLACompatibleSharding,
    data_callback: Callable[[array.Index], asyncio.Future],
):
  device_to_index_map = inp_sharding.devices_indices_map(global_shape)
  addressable_da = inp_sharding._addressable_device_assignment
  future_arrays = [data_callback(device_to_index_map[d])  # type: ignore
                   for d in addressable_da]
  # Pause here and come back to `from_async_callback()` when future_arrays are
  # ready. device_put cannot happen with future_arrays.
  local_arrays = await asyncio.gather(*future_arrays)

  dbs = [jax.device_put(array, device)
         for array, device in zip(local_arrays, addressable_da)]
  return array.make_array_from_single_device_arrays(
      global_shape, inp_sharding, dbs)


def _get_metadata(arr):
  if arr.dtype == jnp.bfloat16:
    # Tensorstore uses 'bfloat16', not '<V2'.
    dtype = 'bfloat16'
  else:
    dtype = np.dtype(arr.dtype).str
  local_shape = arr.addressable_data(0).shape
  return {
      'compressor': {
          'id': 'gzip'
      },
      'shape': arr.shape,
      'chunks': np.array(np.maximum(1, local_shape)),
      'dtype': dtype,
  }
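# For example, a float32 array with global shape (8, 4) whose local shards
# have shape (2, 4) yields
#   {'compressor': {'id': 'gzip'}, 'shape': (8, 4),
#    'chunks': np.array([2, 4]), 'dtype': '<f4'}
# on a little-endian platform, where np.dtype('float32').str == '<f4'.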


def _spec_has_metadata(tree):
  if not isinstance(tree, dict):
    return False
  return 'metadata' in tree or any(
      _spec_has_metadata(subtree) for _, subtree in tree.items())


def _get_kvstore_for_gcs(ckpt_path: str):
  m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
  if m is None:
    raise ValueError('The ckpt_path should contain the bucket name and the '
                     f'file path inside the bucket. Got: {ckpt_path}')
  gcs_bucket = m.group(1)
  path_without_bucket = m.group(2)
  return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}


def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
  is_gcs_path = ckpt_path.startswith('gs://')
  # Normalize path to exclude trailing '/'. In the GCS path case, we will need
  # to fix the path prefix to add back the stripped '/'.
  ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://')
  spec = {'driver': 'zarr', 'kvstore': {}}
  if ocdbt:
    prefix = 'gs' if is_gcs_path else 'file'
    spec['kvstore'] = {
        'driver': 'ocdbt',
        'base': f'{prefix}://{os.path.dirname(ckpt_path)}',
        'path': os.path.basename(ckpt_path),
    }
  else:
    if is_gcs_path:
      spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path)
    else:
      spec['kvstore'] = {'driver': 'file', 'path': ckpt_path}

  return spec
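# For example (hypothetical paths): get_tensorstore_spec('gs://my-bucket/ckpt/arr')
# returns
#   {'driver': 'zarr',
#    'kvstore': {'driver': 'gcs', 'bucket': 'my-bucket', 'path': 'ckpt/arr'}}
# while get_tensorstore_spec('/tmp/ckpt/arr', ocdbt=True) returns
#   {'driver': 'zarr',
#    'kvstore': {'driver': 'ocdbt', 'base': 'file:///tmp/ckpt', 'path': 'arr'}}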


# Lifted from T5X.
class _LimitInFlightBytes:
  """Limits in-flight bytes when reading/writing checkpoints per process."""

  def __init__(self, num_bytes):
    self._max_bytes = num_bytes
    self._available_bytes = num_bytes
    self._cv = asyncio.Condition(lock=asyncio.Lock())

  async def wait_for_bytes(self, requested_bytes):
    if requested_bytes >= self._max_bytes:
      raise ValueError('Requested more bytes than we reserved space for: '
                       f'{requested_bytes} > {self._max_bytes}')
    async with self._cv:
      await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
      self._available_bytes -= requested_bytes
      assert self._available_bytes >= 0

  async def release_bytes(self, requested_bytes):
    async with self._cv:
      self._available_bytes += requested_bytes
      assert self._available_bytes <= self._max_bytes
      self._cv.notify_all()


async def async_serialize(
    arr_inp, tensorstore_spec, commit_future=None, context=TS_CONTEXT
):
  if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
      arr_inp.is_fully_addressable):
    raise ValueError('Passing fully addressable Arrays to a multiprocess '
                     'serialization is not allowed.')
  # 'metadata' may not be present at the top level (for example, if we are
  # using a 'cast' driver).
  if not _spec_has_metadata(tensorstore_spec):
    tensorstore_spec['metadata'] = _get_metadata(arr_inp)

  if jax.process_index() == 0:
    open_future = ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=context,
    )
    # Asynchronous case.
    if commit_future is not None:
      assert isinstance(commit_future, list)
      commit_future.append(open_future)
    else:
      await open_future

  # `ts.open` runs twice for process 0 because for the first time, we just get
  # the future to be awaited upon in the background thread. The second one runs
  # with `assume_metadata=True` which does no I/O operation and returns the
  # tensorstore object.
  # For every process other than `0`, we open with `assume_metadata=True`.
  t = await ts.open(
      ts.Spec(tensorstore_spec),
      open=True,
      assume_metadata=True,
      context=context,
  )

  async def _write_array(shard):
    if shard.replica_id == 0:
      write_future = t[shard.index].write(shard.data)
      if commit_future is not None:
        assert isinstance(commit_future, list)
        commit_future.append(write_future.commit)
        await write_future.copy
      else:
        await write_future.commit

  if isinstance(arr_inp, array.ArrayImpl):
    local_shards = arr_inp.addressable_shards
  else:
    local_shards = arr_inp.addressable_shards
  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
  return await asyncio.gather(*future_write_state)


def run_serialization(arrays, tensorstore_specs):
  async def _run_serializer():
    future_writer = jax.tree_util.tree_map(async_serialize, arrays,
                                           tensorstore_specs)
    return await asyncio.gather(*future_writer)
  asyncio.run(_run_serializer())


def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
  rank = t.rank
  num_bytes = t.dtype.numpy_dtype.itemsize
  chunk_template = t.chunk_layout.read_chunk_template
  origin = t.domain.origin
  shape = t.domain.shape
  chunk_origin = chunk_template.origin
  chunk_shape = chunk_template.shape

  # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
  # For those, instead of returning a near-infinite memory footprint, estimate
  # the footprint as the entire shape.
  for i in range(rank):
    if not chunk_template[i].finite:
      return t.domain.size * num_bytes

  # Otherwise, if we have a chunked driver, estimate based on chunk size.
  for i in range(rank):
    origin_value = origin[i]
    chunk_origin_value = chunk_origin[i]
    chunk_size = chunk_shape[i]
    lower = origin_value - chunk_origin_value
    upper = origin_value + shape[i] - chunk_origin_value
    lower_aligned = lower // chunk_size * chunk_size
    upper_aligned = -(-upper // chunk_size) * chunk_size
    num_bytes *= (upper_aligned - lower_aligned)

  return num_bytes
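# For example, a 1-D float32 store with origin 0, shape 10, and chunk size 4
# spans ceil(10 / 4) = 3 chunks: lower_aligned = 0 and upper_aligned = 12, so
# the estimate is 4 bytes * 12 = 48 bytes -- whole chunks, not just the
# 40 bytes of live data.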


async def async_deserialize(
    in_sharding,
    tensorstore_spec,
    global_shape=None,
    dtype=None,
    byte_limiter: Optional[_LimitInFlightBytes] = None,
    context=TS_CONTEXT,
):
  t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=context)
  shape = t.shape if global_shape is None else global_shape
  new_shard_shape = in_sharding.shard_shape(tuple(shape))

  async def cb(index):
    # This may be needed because the shape the array was saved with is smaller
    # than the requested shape of the array in which it will be reloaded. So
    # the extra values will be filled with 0s.
    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
    restricted_domain = t.domain.intersect(requested_domain)

    requested_bytes = estimate_read_memory_footprint(t[restricted_domain])

    # Limit the bytes read for every shard.
    if byte_limiter is not None:
      await byte_limiter.wait_for_bytes(requested_bytes)

    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(
        t[restricted_domain])

    if dtype is not None:
      # Cast while reloading on process to avoid 2 copies on device if the
      # casting is done on device.
      out = out.astype(dtype)

    if byte_limiter is not None:
      await byte_limiter.release_bytes(requested_bytes)
    return out

  return await create_async_array_from_callback(tuple(shape), in_sharding, cb)


def run_deserialization(shardings: Sequence[sharding.Sharding],
                        tensorstore_specs: Sequence[Dict[str, Any]],
                        global_shapes: Optional[Sequence[array.Shape]] = None,
                        dtypes: Optional[Sequence[typing.DTypeLike]] = None,
                        concurrent_gb: int = 32):
  concurrent_bytes = concurrent_gb * 10**9

  async def _run_deserializer():
    # Object should be created once per process.
    byte_limiter = _LimitInFlightBytes(concurrent_bytes)

    future_arrays = jax.tree_util.tree_map(
        partial(async_deserialize, byte_limiter=byte_limiter),
        shardings, tensorstore_specs,
        [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
        [None] * len(tensorstore_specs) if dtypes is None else dtypes)
    return await asyncio.gather(*future_arrays)
  return asyncio.run(_run_deserializer())


def _get_key(key: str):
  return f'tensorstore_checkpoint_{key}'


class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
  """Interface for checkpointing GDAs asynchronously.

  This class manages the state of an ongoing asynchronous checkpoint.

  For example, say a checkpoint happens on every step. If you checkpoint on
  step 1 and, after some computation, the model reaches step 2, step 1's
  checkpoint may still be committing to the storage layer. Until that commit
  finishes, the checkpoint for step 2 has to be blocked. Keeping this state on
  a class is what makes that blocking possible.

  Example:

  Below is a simplified training loop:

  ```
  # Call this at the start of your program.
  jax.distributed.initialize()

  manager = GlobalAsyncCheckpointManager()

  # Restore a checkpoint if available, or initialize the train_state from
  # init_fn().
  train_state = manager.deserialize(...)

  while ...:
    if step % num_steps_between_checkpoints == 0:
      manager.serialize(train_state, temp_checkpoint_dir=...,
                        final_checkpoint_dir=...)
    train_state = train_step(train_state, input)
    # This is a non-blocking call.
    manager.check_for_errors()

  manager.serialize(train_state, temp_checkpoint_dir=...,
                    final_checkpoint_dir=...)
  # Wait before the end of the program for the checkpoint to finish. This is a
  # blocking call.
  manager.wait_until_finished()
  ```
  """

  @abc.abstractmethod
  def check_for_errors(self):
    """Checks if any errors have been raised in the child thread.

    This is a non-blocking call that can be called in the main thread.
    """

  @abc.abstractmethod
  def wait_until_finished(self):
    """Blocks until serialization has finished."""

  @abc.abstractmethod
  def serialize(self, arrays, tensorstore_specs, *,
                on_commit_callback: Callable[[], None]):
    """Serializes GDAs to TensorStore."""

  @abc.abstractmethod
  def deserialize(self, shardings: Sequence[sharding.Sharding],
                  tensorstore_specs: Sequence[Dict[str, Any]],
                  global_shapes: Optional[Sequence[array.Shape]] = None,
                  dtypes: Optional[Sequence[typing.DTypeLike]] = None):
    """Deserializes GDAs from TensorStore."""


class AsyncManager:

  def __init__(self, timeout_secs=300):
    self._timeout_secs = timeout_secs
    self._timeout_in_ms = self._timeout_secs * 1000

    self._commit_futures = None
    self._thread = None
    self._exception = None

    if distributed.global_state.client is None:
      raise ValueError('Please initialize the distributed system via '
                       '`jax.distributed.initialize()` at the start of your '
                       'program.')
    self._client = distributed.global_state.client
    self._count = None

  def __del__(self):
    if self._thread is not None and self._thread.is_alive():
      logger.warning('Please add `.wait_until_finished()` in the main thread '
                     'before your program finishes, because errors raised in '
                     'the commit thread may be lost if this class is deleted '
                     'before writing is completed.')

  def _thread_func(self):
    try:
      current_process = jax.process_index()
      logger.info('Starting commit to storage layer by process: %s',
                  current_process)
      for future in self._commit_futures:
        future.result()
      logger.info('Finished committing to storage layer by process: %s',
                  current_process)

      # All processes will wait at the barrier. When all processes are at the
      # barrier, the barrier will be satisfied. If not, then it will time out.
      key_for_barrier = _get_key(self._count)
      logger.info('Key used for barrier is %s for process %s',
                  key_for_barrier, current_process)
      self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
      logger.info('Finished waiting at barrier for process %s',
                  current_process)

      if current_process == 0:
        self._on_commit_callback()
        logger.info('on_commit_callback successfully ran!')
        self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
        logger.info('Process 0 successfully set key %s in the kv store',
                    key_for_barrier)

    except Exception as e:
      self._exception = e

  def _start_async_commit(self, on_commit_callback):
    self._count = next(_module_unique_count)

    self._on_commit_callback = on_commit_callback
    self._thread = threading.Thread(target=self._thread_func)
    self._thread.start()

  def check_for_errors(self):
    if self._exception is not None:
      # Clears self._exception so it is only raised once.
      exception = self._exception
      self._exception = None
      raise exception  # pylint: disable=raising-bad-type

  def wait_until_finished(self):
    if self._thread is not None:
      self._thread.join()
      self._thread = None
      logger.info('Thread joined successfully')

    self.check_for_errors()
    logger.info('Error check finished successfully')

    if self._count is not None:
      # Block until process 0 writes the success value to the key-value store.
      # If it fails to write it, then `blocking_key_value_get` will time out.
      get_key = _get_key(self._count)
      self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
      logger.info('blocking_key_value_get on key %s was successfully '
                  'completed.', get_key)

  def _add_futures(self, futures: Sequence[asyncio.Future]):
    self._commit_futures = futures


class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
  """Responsible for serializing GDAs via TensorStore."""

  def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
    """Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously.

    TensorStore writes to a storage layer in 2 steps:
    * Reading/copying from the source, after which the source can be modified.
      * Returns a copy future.
    * Writing/committing to the storage layer.
      * Returns a commit future.

    In asynchronous mode, the serialization waits for the commit future to
    finish in a separate thread, allowing other computation to proceed.

    Args:
      arrays: GlobalDeviceArrays or Arrays that should be serialized.
      tensorstore_specs: TensorStore specs that are used to serialize GDAs or
        Arrays.
      on_commit_callback: Called on process 0 once every commit future has
        finished, e.g. to move the checkpoint from a temporary to a final
        directory.
    """
    logger.info('Waiting for previous serialization to finish.')
    self.wait_until_finished()

    commit_futures = [[] for _ in range(len(tensorstore_specs))]

    async def _run_serializer():
      future_writer = jax.tree_util.tree_map(
          async_serialize, arrays, tensorstore_specs, commit_futures)
      return await asyncio.gather(*future_writer)

    asyncio.run(_run_serializer())

    self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])

    # Used in wait_until_finished to check on process != 0, if the checkpoint
    # has finished writing.
    self._start_async_commit(on_commit_callback)

  def deserialize(self, shardings: Sequence[sharding.Sharding],
                  tensorstore_specs: Sequence[Dict[str, Any]],
                  global_shapes: Optional[Sequence[array.Shape]] = None,
                  dtypes: Optional[Sequence[typing.DTypeLike]] = None):
    self.wait_until_finished()
    return run_deserialization(shardings, tensorstore_specs,
                               global_shapes, dtypes)
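A minimal single-process sketch of the new module's round trip (hypothetical checkpoint path; multi-host jobs should initialize `jax.distributed` and use `GlobalAsyncCheckpointManager` instead):

```python
import jax.numpy as jnp
from jax.experimental.array_serialization import serialization

arr = jnp.arange(16.0).reshape(4, 4)
spec = serialization.get_tensorstore_spec('/tmp/ckpt/arr')  # hypothetical path

# Blocking write, then blocking read back with the same sharding and shape.
serialization.run_serialization([arr], [spec])
restored, = serialization.run_deserialization(
    [arr.sharding], [spec], global_shapes=[arr.shape], dtypes=[arr.dtype])
```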
@@ -23,7 +23,7 @@ from jax.config import config
from jax._src import array
from jax.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P
-from jax.experimental.gda_serialization import serialization
+from jax.experimental.array_serialization import serialization
import numpy as np
import tensorstore as ts
@@ -1,10 +0,0 @@
-# Serialization and De-serialization of GlobalDeviceArray via tensorstore
-
-Warning: This directory is going to move in the near future. Please use at your
-own risk.
-
-To use this library, please install tensorstore and JAX.
-
-```bash
-pip install -U tensorstore
-```
@@ -13,492 +13,9 @@
# limitations under the License.
"""GlobalDeviceArray serialization and deserialization."""

-import abc
-import asyncio
-import itertools
-import logging
-from functools import partial
-import os
-import re
-import threading
-from typing import Callable, Sequence, Optional, Dict, Any
+from jax.experimental.array_serialization.serialization import *  # noqa: F403
+from jax.experimental.array_serialization.serialization import (
+    _LimitInFlightBytes as _LimitInFlightBytes,
+    _get_metadata as _get_metadata,
+)
-
-import jax
-from jax._src import distributed
-from jax._src.config import config
-from jax._src import array
-from jax._src import sharding
-from jax._src import sharding_impls
-from jax._src import typing
-from jax.sharding import Mesh
-import jax.numpy as jnp
-import numpy as np
-import tensorstore as ts
-
-
-TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
-_REMOVED_VALUE = 'Value removed'
-_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
-_module_unique_count = itertools.count()
-
-logger = logging.getLogger(__name__)
-
-
-async def create_async_array_from_callback(
-    global_shape: array.Shape,
-    inp_sharding: sharding_impls.XLACompatibleSharding,
-    data_callback: Callable[[array.Index], asyncio.Future],
-):
-  device_to_index_map = inp_sharding.devices_indices_map(global_shape)
-  addressable_da = inp_sharding._addressable_device_assignment
-  future_arrays = [data_callback(device_to_index_map[d])  # type: ignore
-                   for d in addressable_da]
-  # Pause here and come back to `from_async_callback()` when future_arrays are
-  # ready. device_put cannot happen with future_arrays.
-  local_arrays = await asyncio.gather(*future_arrays)
-
-  dbs = [jax.device_put(array, device)
-         for array, device in zip(local_arrays, addressable_da)]
-  return array.make_array_from_single_device_arrays(
-      global_shape, inp_sharding, dbs)
-
-
-def _get_metadata(arr):
-  if arr.dtype == jnp.bfloat16:
-    # Tensorstore uses 'bfloat16', not '<V2'.
-    dtype = 'bfloat16'
-  else:
-    dtype = np.dtype(arr.dtype).str
-  local_shape = arr.addressable_data(0).shape
-  return {
-      'compressor': {
-          'id': 'gzip'
-      },
-      'shape': arr.shape,
-      'chunks': np.array(np.maximum(1, local_shape)),
-      'dtype': dtype,
-  }
-
-
-def _spec_has_metadata(tree):
-  if not isinstance(tree, dict):
-    return False
-  return 'metadata' in tree or any(
-      _spec_has_metadata(subtree) for _, subtree in tree.items())
-
-
-def _get_kvstore_for_gcs(ckpt_path: str):
-  m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
-  if m is None:
-    raise ValueError('The ckpt_path should contain the bucket name and the '
-                     f'file path inside the bucket. Got: {ckpt_path}')
-  gcs_bucket = m.group(1)
-  path_without_bucket = m.group(2)
-  return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}
-
-
-def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
-  is_gcs_path = ckpt_path.startswith('gs://')
-  # Normalize path to exclude trailing '/'. In the GCS path case, we will need
-  # to fix the path prefix to add back the stripped '/'.
-  ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://')
-  spec = {'driver': 'zarr', 'kvstore': {}}
-  if ocdbt:
-    prefix = 'gs' if is_gcs_path else 'file'
-    spec['kvstore'] = {
-        'driver': 'ocdbt',
-        'base': f'{prefix}://{os.path.dirname(ckpt_path)}',
-        'path': os.path.basename(ckpt_path),
-    }
-  else:
-    if is_gcs_path:
-      spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path)
-    else:
-      spec['kvstore'] = {'driver': 'file', 'path': ckpt_path}
-
-  return spec
-
-
-# Lifted from T5X.
-class _LimitInFlightBytes:
-  """Limits in-flight bytes when reading/writing checkpoints per process."""
-
-  def __init__(self, num_bytes):
-    self._max_bytes = num_bytes
-    self._available_bytes = num_bytes
-    self._cv = asyncio.Condition(lock=asyncio.Lock())
-
-  async def wait_for_bytes(self, requested_bytes):
-    if requested_bytes >= self._max_bytes:
-      raise ValueError('Requested more bytes than we reserved space for: '
-                       f'{requested_bytes} > {self._max_bytes}')
-    async with self._cv:
-      await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
-      self._available_bytes -= requested_bytes
-      assert self._available_bytes >= 0
-
-  async def release_bytes(self, requested_bytes):
-    async with self._cv:
-      self._available_bytes += requested_bytes
-      assert self._available_bytes <= self._max_bytes
-      self._cv.notify_all()
-
-
-async def async_serialize(
-    arr_inp, tensorstore_spec, commit_future=None, context=TS_CONTEXT
-):
-  if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
-      arr_inp.is_fully_addressable):
-    raise ValueError('Passing fully addressable Arrays to a multiprocess '
-                     'serialization is not allowed.')
-  # 'metadata' may not be present at the top level (for example, if we are
-  # using a 'cast' driver).
-  if not _spec_has_metadata(tensorstore_spec):
-    tensorstore_spec['metadata'] = _get_metadata(arr_inp)
-
-  if jax.process_index() == 0:
-    open_future = ts.open(
-        ts.Spec(tensorstore_spec),
-        create=True,
-        open=True,
-        context=context,
-    )
-    # Asynchronous case.
-    if commit_future is not None:
-      assert isinstance(commit_future, list)
-      commit_future.append(open_future)
-    else:
-      await open_future
-
-  # `ts.open` runs twice for process 0 because for the first time, we just get
-  # the future to be awaited upon in the background thread. The second one runs
-  # with `assume_metadata=True` which does no I/O operation and returns the
-  # tensorstore object.
-  # For every process other than `0`, we open with `assume_metadata=True`.
-  t = await ts.open(
-      ts.Spec(tensorstore_spec),
-      open=True,
-      assume_metadata=True,
-      context=context,
-  )
-
-  async def _write_array(shard):
-    if shard.replica_id == 0:
-      write_future = t[shard.index].write(shard.data)
-      if commit_future is not None:
-        assert isinstance(commit_future, list)
-        commit_future.append(write_future.commit)
-        await write_future.copy
-      else:
-        await write_future.commit
-
-  if isinstance(arr_inp, array.ArrayImpl):
-    local_shards = arr_inp.addressable_shards
-  else:
-    local_shards = arr_inp.addressable_shards
-  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
-  return await asyncio.gather(*future_write_state)
-
-
-def run_serialization(arrays, tensorstore_specs):
-  async def _run_serializer():
-    future_writer = jax.tree_util.tree_map(async_serialize, arrays,
-                                           tensorstore_specs)
-    return await asyncio.gather(*future_writer)
-  asyncio.run(_run_serializer())
-
-
-def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
-  rank = t.rank
-  num_bytes = t.dtype.numpy_dtype.itemsize
-  chunk_template = t.chunk_layout.read_chunk_template
-  origin = t.domain.origin
-  shape = t.domain.shape
-  chunk_origin = chunk_template.origin
-  chunk_shape = chunk_template.shape
-
-  # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
-  # For those, instead of returning a near-infinite memory footprint, estimate
-  # the footprint as the entire shape.
-  for i in range(rank):
-    if not chunk_template[i].finite:
-      return t.domain.size * num_bytes
-
-  # Otherwise, if we have a chunked driver, estimate based on chunk size.
-  for i in range(rank):
-    origin_value = origin[i]
-    chunk_origin_value = chunk_origin[i]
-    chunk_size = chunk_shape[i]
-    lower = origin_value - chunk_origin_value
-    upper = origin_value + shape[i] - chunk_origin_value
-    lower_aligned = lower // chunk_size * chunk_size
-    upper_aligned = -(-upper // chunk_size) * chunk_size
-    num_bytes *= (upper_aligned - lower_aligned)
-
-  return num_bytes
-
-
-async def async_deserialize(
-    in_sharding,
-    tensorstore_spec,
-    global_shape=None,
-    dtype=None,
-    byte_limiter: Optional[_LimitInFlightBytes] = None,
-    context=TS_CONTEXT,
-):
-  t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=context)
-  shape = t.shape if global_shape is None else global_shape
-  new_shard_shape = in_sharding.shard_shape(tuple(shape))
-
-  async def cb(index):
-    # This may be needed because the shape the array was saved with is smaller
-    # than the requested shape of the array in which it will be reloaded. So
-    # the extra values will be filled with 0s.
-    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
-    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
-    restricted_domain = t.domain.intersect(requested_domain)
-
-    requested_bytes = estimate_read_memory_footprint(t[restricted_domain])
-
-    # Limit the bytes read for every shard.
-    if byte_limiter is not None:
-      await byte_limiter.wait_for_bytes(requested_bytes)
-
-    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(
-        t[restricted_domain])
-
-    if dtype is not None:
-      # Cast while reloading on process to avoid 2 copies on device if the
-      # casting is done on device.
-      out = out.astype(dtype)
-
-    if byte_limiter is not None:
-      await byte_limiter.release_bytes(requested_bytes)
-    return out
-
-  return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
-
-
-def run_deserialization(shardings: Sequence[sharding.Sharding],
-                        tensorstore_specs: Sequence[Dict[str, Any]],
-                        global_shapes: Optional[Sequence[array.Shape]] = None,
-                        dtypes: Optional[Sequence[typing.DTypeLike]] = None,
-                        concurrent_gb: int = 32):
-  concurrent_bytes = concurrent_gb * 10**9
-
-  async def _run_deserializer():
-    # Object should be created once per process.
-    byte_limiter = _LimitInFlightBytes(concurrent_bytes)
-
-    future_arrays = jax.tree_util.tree_map(
-        partial(async_deserialize, byte_limiter=byte_limiter),
-        shardings, tensorstore_specs,
-        [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
-        [None] * len(tensorstore_specs) if dtypes is None else dtypes)
-    return await asyncio.gather(*future_arrays)
-  return asyncio.run(_run_deserializer())
-
-
-def _get_key(key: str):
-  return f'tensorstore_checkpoint_{key}'
-
-
-class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
-  """Interface for checkpointing GDAs asynchronously.
-
-  This class manages the state of an ongoing asynchronous checkpoint.
-
-  For example, say a checkpoint happens on every step. If you checkpoint on
-  step 1 and, after some computation, the model reaches step 2, step 1's
-  checkpoint may still be committing to the storage layer. Until that commit
-  finishes, the checkpoint for step 2 has to be blocked. Keeping this state on
-  a class is what makes that blocking possible.
-
-  Example:
-
-  Below is a simplified training loop:
-
-  ```
-  # Call this at the start of your program.
-  jax.distributed.initialize()
-
-  manager = GlobalAsyncCheckpointManager()
-
-  # Restore a checkpoint if available, or initialize the train_state from
-  # init_fn().
-  train_state = manager.deserialize(...)
-
-  while ...:
-    if step % num_steps_between_checkpoints == 0:
-      manager.serialize(train_state, temp_checkpoint_dir=...,
-                        final_checkpoint_dir=...)
-    train_state = train_step(train_state, input)
-    # This is a non-blocking call.
-    manager.check_for_errors()
-
-  manager.serialize(train_state, temp_checkpoint_dir=...,
-                    final_checkpoint_dir=...)
-  # Wait before the end of the program for the checkpoint to finish. This is a
-  # blocking call.
-  manager.wait_until_finished()
-  ```
-  """
-
-  @abc.abstractmethod
-  def check_for_errors(self):
-    """Checks if any errors have been raised in the child thread.
-
-    This is a non-blocking call that can be called in the main thread.
-    """
-
-  @abc.abstractmethod
-  def wait_until_finished(self):
-    """Blocks until serialization has finished."""
-
-  @abc.abstractmethod
-  def serialize(self, arrays, tensorstore_specs, *,
-                on_commit_callback: Callable[[], None]):
-    """Serializes GDAs to TensorStore."""
-
-  @abc.abstractmethod
-  def deserialize(self, shardings: Sequence[sharding.Sharding],
-                  tensorstore_specs: Sequence[Dict[str, Any]],
-                  global_shapes: Optional[Sequence[array.Shape]] = None,
-                  dtypes: Optional[Sequence[typing.DTypeLike]] = None):
-    """Deserializes GDAs from TensorStore."""
-
-
-class AsyncManager:
-
-  def __init__(self, timeout_secs=300):
-    self._timeout_secs = timeout_secs
-    self._timeout_in_ms = self._timeout_secs * 1000
-
-    self._commit_futures = None
-    self._thread = None
-    self._exception = None
-
-    if distributed.global_state.client is None:
-      raise ValueError('Please initialize the distributed system via '
-                       '`jax.distributed.initialize()` at the start of your '
-                       'program.')
-    self._client = distributed.global_state.client
-    self._count = None
-
-  def __del__(self):
-    if self._thread is not None and self._thread.is_alive():
-      logger.warning('Please add `.wait_until_finished()` in the main thread '
-                     'before your program finishes, because errors raised in '
-                     'the commit thread may be lost if this class is deleted '
-                     'before writing is completed.')
-
-  def _thread_func(self):
-    try:
-      current_process = jax.process_index()
-      logger.info('Starting commit to storage layer by process: %s',
-                  current_process)
-      for future in self._commit_futures:
-        future.result()
-      logger.info('Finished committing to storage layer by process: %s',
-                  current_process)
-
-      # All processes will wait at the barrier. When all processes are at the
-      # barrier, the barrier will be satisfied. If not, then it will time out.
-      key_for_barrier = _get_key(self._count)
-      logger.info('Key used for barrier is %s for process %s',
-                  key_for_barrier, current_process)
-      self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
-      logger.info('Finished waiting at barrier for process %s',
-                  current_process)
-
-      if current_process == 0:
-        self._on_commit_callback()
-        logger.info('on_commit_callback successfully ran!')
-        self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
-        logger.info('Process 0 successfully set key %s in the kv store',
-                    key_for_barrier)
-
-    except Exception as e:
-      self._exception = e
-
-  def _start_async_commit(self, on_commit_callback):
-    self._count = next(_module_unique_count)
-
-    self._on_commit_callback = on_commit_callback
-    self._thread = threading.Thread(target=self._thread_func)
-    self._thread.start()
-
-  def check_for_errors(self):
-    if self._exception is not None:
-      # Clears self._exception so it is only raised once.
-      exception = self._exception
-      self._exception = None
-      raise exception  # pylint: disable=raising-bad-type
-
-  def wait_until_finished(self):
-    if self._thread is not None:
-      self._thread.join()
-      self._thread = None
-      logger.info('Thread joined successfully')
-
-    self.check_for_errors()
-    logger.info('Error check finished successfully')
-
-    if self._count is not None:
-      # Block until process 0 writes the success value to the key-value store.
-      # If it fails to write it, then `blocking_key_value_get` will time out.
-      get_key = _get_key(self._count)
-      self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
-      logger.info('blocking_key_value_get on key %s was successfully '
-                  'completed.', get_key)
-
-  def _add_futures(self, futures: Sequence[asyncio.Future]):
-    self._commit_futures = futures
-
-
-class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
-  """Responsible for serializing GDAs via TensorStore."""
-
-  def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
-    """Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously.
-
-    TensorStore writes to a storage layer in 2 steps:
-    * Reading/copying from the source, after which the source can be modified.
-      * Returns a copy future.
-    * Writing/committing to the storage layer.
-      * Returns a commit future.
-
-    In asynchronous mode, the serialization waits for the commit future to
-    finish in a separate thread, allowing other computation to proceed.
-
-    Args:
-      arrays: GlobalDeviceArrays or Arrays that should be serialized.
-      tensorstore_specs: TensorStore specs that are used to serialize GDAs or
-        Arrays.
-      on_commit_callback: Called on process 0 once every commit future has
-        finished, e.g. to move the checkpoint from a temporary to a final
-        directory.
-    """
-    logger.info('Waiting for previous serialization to finish.')
-    self.wait_until_finished()
-
-    commit_futures = [[] for _ in range(len(tensorstore_specs))]
-
-    async def _run_serializer():
-      future_writer = jax.tree_util.tree_map(
-          async_serialize, arrays, tensorstore_specs, commit_futures)
-      return await asyncio.gather(*future_writer)
-
-    asyncio.run(_run_serializer())
-
-    self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])
-
-    # Used in wait_until_finished to check on process != 0, if the checkpoint
-    # has finished writing.
-    self._start_async_commit(on_commit_callback)
-
-  def deserialize(self, shardings: Sequence[sharding.Sharding],
-                  tensorstore_specs: Sequence[Dict[str, Any]],
-                  global_shapes: Optional[Sequence[array.Shape]] = None,
-                  dtypes: Optional[Sequence[typing.DTypeLike]] = None):
-    self.wait_until_finished()
-    return run_deserialization(shardings, tensorstore_specs,
-                               global_shapes, dtypes)
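Because the old module now re-exports everything from the new one, both import paths resolve to the same objects until the shim is removed:

```python
from jax.experimental.gda_serialization import serialization as old
from jax.experimental.array_serialization import serialization as new

# The shim forwards to the new module, so the classes are the same objects.
assert old.GlobalAsyncCheckpointManager is new.GlobalAsyncCheckpointManager
```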