Add a limiter for in-flight bytes. Read a shard from TensorStore only when enough bytes are available under the limit. This currently applies only to deserialization.
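For context, the limit is exposed through `run_deserialization`. A minimal usage sketch (the mesh, partition spec, checkpoint path, and the 4 GB budget below are placeholders, not part of this change):

    # Hypothetical example; global_mesh, mesh_axes, and the path are placeholders.
    from jax.experimental.gda_serialization import serialization

    tspecs = [serialization.get_tensorstore_spec('/tmp/ckpt/arr0')]
    gda, = serialization.run_deserialization(
        [global_mesh], [mesh_axes], tspecs,
        concurrent_gb=4)  # allow roughly 4 GB of shard reads in flight per process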

PiperOrigin-RevId: 458586521
Yash Katariya 2022-07-01 19:26:32 -07:00 committed by jax authors
parent f12af93258
commit 8a23605462
2 changed files with 102 additions and 20 deletions

jax/experimental/gda_serialization/serialization.py

@@ -15,9 +15,10 @@
 import abc
 import asyncio
+from functools import partial
 import re
 import threading
-from typing import Callable, Sequence
+from typing import Callable, Sequence, Optional

 from absl import logging

 import jax
@@ -98,6 +99,28 @@ def get_tensorstore_spec(ckpt_path: str):
   return spec


+# Lifted from T5X.
+class _LimitInFlightBytes:
+  """Limits in-flight bytes when reading/writing checkpoints per process."""
+
+  def __init__(self, num_bytes):
+    self._max_bytes = num_bytes
+    self._available_bytes = num_bytes
+    self._cv = asyncio.Condition(lock=asyncio.Lock())
+
+  async def wait_for_bytes(self, requested_bytes):
+    async with self._cv:
+      await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
+      self._available_bytes -= requested_bytes
+      assert self._available_bytes >= 0
+
+  async def release_bytes(self, requested_bytes):
+    async with self._cv:
+      self._available_bytes += requested_bytes
+      assert self._available_bytes <= self._max_bytes
+      self._cv.notify_all()
+
+
 async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
                           commit_future=None):
   # 'metadata' may not be present at the top level (for example, if we are using
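Aside (not part of the diff): `_LimitInFlightBytes` is a standard asyncio condition-variable pattern. A sketch of how a caller acquires and releases budget, with invented sizes and a hypothetical `read_one_shard` helper:

    # Hypothetical demo of the acquire/release pattern; not from this commit.
    import asyncio

    async def read_one_shard(limiter, nbytes):
      await limiter.wait_for_bytes(nbytes)   # blocks until budget is available
      try:
        pass  # ... perform the shard read here ...
      finally:
        await limiter.release_bytes(nbytes)  # return budget and wake waiters

    async def main():
      limiter = _LimitInFlightBytes(10**9)   # ~1 GB budget
      # 100 requests of 100 MB each; the limiter keeps most of them waiting.
      await asyncio.gather(*(read_one_shard(limiter, 10**8) for _ in range(100)))

Note that `wait_for_bytes` uses a strict `>`, so a single request can never consume the entire budget: in this sketch at most 9 of the 100 MB reads run concurrently, not 10.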
@@ -145,41 +168,76 @@ def run_serialization(gdas, tensorstore_specs):
   asyncio.run(_run_serializer())


-async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
-                            global_shape=None, dtype=None):
-  t = ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT).result()
-  shape = t.shape if global_shape is None else global_shape
-  requires_padding = prod(shape) > prod(t.shape)
+def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
+  rank = t.rank
+  num_bytes = t.dtype.numpy_dtype.itemsize
+  if rank == 0:
+    return num_bytes
+
+  chunk_template = t.chunk_layout.read_chunk_template
+  origin = t.domain.origin
+  shape = t.domain.shape
+  chunk_origin = chunk_template.origin
+  chunk_shape = chunk_template.shape

-  if requires_padding:
-    new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)
+  # Round the requested read out to whole chunks in every dimension, since
+  # TensorStore materializes reads one full chunk at a time.
+  for i in range(rank):
+    origin_value = origin[i]
+    chunk_origin_value = chunk_origin[i]
+    chunk_size = chunk_shape[i]
+    lower = origin_value - chunk_origin_value
+    upper = origin_value + shape[i] - chunk_origin_value
+    lower_aligned = lower // chunk_size * chunk_size
+    upper_aligned = -(-upper // chunk_size) * chunk_size
+    num_bytes *= (upper_aligned - lower_aligned)
+
+  return num_bytes
+
+
+async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
+                            global_shape=None, dtype=None,
+                            byte_limiter: Optional[_LimitInFlightBytes] = None):
+  t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
+  shape = t.shape if global_shape is None else global_shape
+  new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)

   async def cb(index):
-    if requires_padding:
-      # This is needed because the shape the array was saved with is smaller
-      # than the requested shape of the array in which it will be reloaded. So
-      # the extra values will be filled with 0s.
-      out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
-      requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
-      restricted_domain = t.domain.intersect(requested_domain)
-      await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(t[restricted_domain])
-    else:
-      out = await t[index].read()
+    # This may be needed because the shape the array was saved with is smaller
+    # than the requested shape of the array in which it will be reloaded. So
+    # the extra values will be filled with 0s.
+    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
+    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
+    restricted_domain = t.domain.intersect(requested_domain)
+    requested_bytes = estimate_read_memory_footprint(t[restricted_domain])
+    # Limit the bytes read for every shard.
+    if byte_limiter is not None:
+      await byte_limiter.wait_for_bytes(requested_bytes)
+    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(
+        t[restricted_domain])
+    # Release before returning so the reservation is not leaked on the
+    # casting path below.
+    if byte_limiter is not None:
+      await byte_limiter.release_bytes(requested_bytes)
     if dtype is not None:
       # Cast while reloading on process to avoid 2 copies on device if the
       # casting is done on device.
       return out.astype(dtype)
     return out

   return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)


 def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
-                        global_shapes=None, dtypes=None):
+                        global_shapes=None, dtypes=None, concurrent_gb=32):
+  concurrent_bytes = concurrent_gb * 10**9
+
   async def _run_deserializer():
+    # Object should be created once per process.
+    byte_limiter = _LimitInFlightBytes(concurrent_bytes)
     future_gdas = jax.tree_map(
-        async_deserialize, global_meshes, mesh_axes, tensorstore_specs,
+        partial(async_deserialize, byte_limiter=byte_limiter),
+        global_meshes, mesh_axes, tensorstore_specs,
         [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
         [None] * len(tensorstore_specs) if dtypes is None else dtypes)
     return await asyncio.gather(*future_gdas)
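To make the chunk-alignment arithmetic in `estimate_read_memory_footprint` concrete, here is a self-contained sketch of the per-dimension computation with invented numbers (the helper name `estimate_1d` and all values are hypothetical):

    # Worked example; shapes and chunk sizes are made up for illustration.
    def estimate_1d(itemsize, origin, extent, chunk_origin, chunk_size):
      lower = origin - chunk_origin
      upper = origin + extent - chunk_origin
      lower_aligned = lower // chunk_size * chunk_size        # round down
      upper_aligned = -(-upper // chunk_size) * chunk_size    # round up
      return itemsize * (upper_aligned - lower_aligned)

    # Reading elements [10, 110) of a float32 array stored in 64-element
    # chunks touches two whole chunks, i.e. 128 elements = 512 bytes:
    assert estimate_1d(4, 10, 100, 0, 64) == 512

This is why the limiter reserves chunk-aligned bytes rather than the raw size of the requested slice: TensorStore reads whole chunks into memory.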

jax/experimental/gda_serialization/serialization_test.py

@@ -129,6 +129,30 @@ class CheckpointTest(jtu.JaxTestCase):
     for l in m1.local_shards:
       self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])

+  def test_checkpointing_scalar(self):
+    global_mesh = jtu.create_global_mesh((2,), ('x',))
+    global_input_shape = ()
+    data = np.array(4)
+    gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
+                                           P(None), lambda idx: data[idx])
+
+    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
+    ckpt_paths = [str(ckpt_dir1)]
+    tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
+
+    serialization.run_serialization([gda1], tspecs)
+
+    m1, = serialization.run_deserialization(
+        [jtu.create_global_mesh((2,), ('x',))],
+        [P(None)],
+        tspecs,
+        [()],
+        [np.float32]
+    )
+
+    for l in m1.local_shards:
+      self.assertArraysEqual(l.data.to_py(), data.astype(np.float32))
+
   def test_spec_has_metadata(self):
     spec = {
         'a': {