mirror of https://github.com/ROCm/jax.git
Add a limiter for in-flight bytes. A shard is read from TensorStore only when enough bytes are available. This currently applies only to deserialization.
PiperOrigin-RevId: 458586521
parent f12af93258
commit 8a23605462
@@ -15,9 +15,10 @@
 import abc
 import asyncio
+from functools import partial
 import re
 import threading
-from typing import Callable, Sequence
+from typing import Callable, Sequence, Optional
 from absl import logging
 
 import jax
@@ -98,6 +99,28 @@ def get_tensorstore_spec(ckpt_path: str):
   return spec
 
 
+# Lifted from T5X.
+class _LimitInFlightBytes:
+  """Limits in-flight bytes when reading/writing checkpoints per process."""
+
+  def __init__(self, num_bytes):
+    self._max_bytes = num_bytes
+    self._available_bytes = num_bytes
+    self._cv = asyncio.Condition(lock=asyncio.Lock())
+
+  async def wait_for_bytes(self, requested_bytes):
+    async with self._cv:
+      await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
+      self._available_bytes -= requested_bytes
+      assert self._available_bytes >= 0
+
+  async def release_bytes(self, requested_bytes):
+    async with self._cv:
+      self._available_bytes += requested_bytes
+      assert self._available_bytes <= self._max_bytes
+      self._cv.notify_all()
+
+
 async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
                           commit_future=None):
   # 'metadata' may not be present at the top level (for example, if we are using
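The `_LimitInFlightBytes` class added above is essentially a counting semaphore denominated in bytes rather than permits. Below is a minimal sketch of how callers gate concurrent reads with it; the `fake_read` coroutine and the byte sizes are hypothetical, and the sketch assumes the class from the hunk above is in scope:

```python
import asyncio

async def fake_read(limiter, nbytes):
  # Block until `nbytes` fit under the in-flight budget, then "read".
  await limiter.wait_for_bytes(nbytes)
  try:
    await asyncio.sleep(0.01)  # stand-in for an actual TensorStore shard read
  finally:
    await limiter.release_bytes(nbytes)

async def main():
  limiter = _LimitInFlightBytes(2 * 1024**3)  # ~2 GB in flight at once
  # Four 800 MiB reads: only two can be in flight at a time under a 2 GB cap.
  await asyncio.gather(*(fake_read(limiter, 800 * 1024**2) for _ in range(4)))

asyncio.run(main())
```

Note that the `wait_for` predicate uses a strict `>`, so a single request as large as the entire budget would never be admitted; the limit is expected to comfortably exceed the largest shard.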
@@ -145,41 +168,76 @@ def run_serialization(gdas, tensorstore_specs):
   asyncio.run(_run_serializer())
 
 
-async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
-                            global_shape=None, dtype=None):
-  t = ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT).result()
-  shape = t.shape if global_shape is None else global_shape
-  requires_padding = prod(shape) > prod(t.shape)
-
-  if requires_padding:
-    new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)
+def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
+  rank = t.rank
+  num_bytes = t.dtype.numpy_dtype.itemsize
+  if rank == 0:
+    return num_bytes
+  chunk_template = t.chunk_layout.read_chunk_template
+  origin = t.domain.origin
+  shape = t.domain.shape
+  chunk_origin = chunk_template.origin
+  chunk_shape = chunk_template.shape
+
+  for i in range(rank):
+    origin_value = origin[i]
+    chunk_origin_value = chunk_origin[i]
+    chunk_size = chunk_shape[i]
+    lower = origin_value - chunk_origin_value
+    upper = origin_value + shape[i] - chunk_origin_value
+    lower_aligned = lower // chunk_size * chunk_size
+    upper_aligned = -(-upper // chunk_size) * chunk_size
+    num_bytes *= (upper_aligned - lower_aligned)
+  return num_bytes
+
+
+async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
+                            global_shape=None, dtype=None,
+                            byte_limiter: Optional[_LimitInFlightBytes] = None):
+  t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
+  shape = t.shape if global_shape is None else global_shape
+  new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)
 
   async def cb(index):
-    if requires_padding:
-      # This is needed because the shape the array was saved with is smaller
-      # than the requested shape of the array in which it will be reloaded. So
-      # the extra values will be filled with 0s.
-      out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
-      requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
-      restricted_domain = t.domain.intersect(requested_domain)
-      await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(t[restricted_domain])
-    else:
-      out = await t[index].read()
+    # This may be needed because the shape the array was saved with is smaller
+    # than the requested shape of the array in which it will be reloaded. So
+    # the extra values will be filled with 0s.
+    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
+    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
+    restricted_domain = t.domain.intersect(requested_domain)
+
+    requested_bytes = estimate_read_memory_footprint(t[restricted_domain])
+
+    # Limit the bytes read for every shard.
+    if byte_limiter is not None:
+      await byte_limiter.wait_for_bytes(requested_bytes)
+
+    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(
+        t[restricted_domain])
+
     if dtype is not None:
       # Cast while reloading on process to avoid 2 copies on device if the
       # casting is done on device.
-      return out.astype(dtype)
+      out = out.astype(dtype)
+
+    if byte_limiter is not None:
+      await byte_limiter.release_bytes(requested_bytes)
     return out
 
   return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)
 
 
 def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
-                        global_shapes=None, dtypes=None):
+                        global_shapes=None, dtypes=None, concurrent_gb=32):
+  concurrent_bytes = concurrent_gb * 10**9
+
   async def _run_deserializer():
+    # Object should be created once per process.
+    byte_limiter = _LimitInFlightBytes(concurrent_bytes)
+
     future_gdas = jax.tree_map(
-        async_deserialize, global_meshes, mesh_axes, tensorstore_specs,
+        partial(async_deserialize, byte_limiter=byte_limiter),
+        global_meshes, mesh_axes, tensorstore_specs,
         [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
         [None] * len(tensorstore_specs) if dtypes is None else dtypes)
     return await asyncio.gather(*future_gdas)
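The new `estimate_read_memory_footprint` helper rounds the requested region out to whole chunks in every dimension, since TensorStore fetches full chunks from storage. A plain-Python sketch of the per-dimension arithmetic with made-up numbers; the `aligned_extent` helper is illustrative, not part of the commit:

```python
def aligned_extent(origin, extent, chunk_origin, chunk_size):
  # Round the requested interval [origin, origin + extent) outward to
  # chunk boundaries, mirroring the loop body in the hunk above.
  lower = origin - chunk_origin
  upper = origin + extent - chunk_origin
  lower_aligned = lower // chunk_size * chunk_size
  upper_aligned = -(-upper // chunk_size) * chunk_size  # ceiling division
  return upper_aligned - lower_aligned

# Reading elements [2, 10) with chunk size 4 touches chunks covering [0, 12),
# so 12 elements are fetched even though only 8 were requested.
assert aligned_extent(origin=2, extent=8, chunk_origin=0, chunk_size=4) == 12
```

Multiplying these per-dimension extents by the element size gives the byte count that `wait_for_bytes` reserves for each shard read.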
@@ -129,6 +129,30 @@ class CheckpointTest(jtu.JaxTestCase):
     for l in m1.local_shards:
       self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
 
+  def test_checkpointing_scalar(self):
+    global_mesh = jtu.create_global_mesh((2,), ('x',))
+    global_input_shape = ()
+    data = np.array(4)
+    gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
+                                           P(None), lambda idx: data[idx])
+    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
+
+    ckpt_paths = [str(ckpt_dir1)]
+    tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
+
+    serialization.run_serialization([gda1], tspecs)
+
+    m1, = serialization.run_deserialization(
+        [jtu.create_global_mesh((2,), ('x',))],
+        [P(None)],
+        tspecs,
+        [()],
+        [np.float32]
+    )
+
+    for l in m1.local_shards:
+      self.assertArraysEqual(l.data.to_py(), data.astype(np.float32))
+
   def test_spec_has_metadata(self):
     spec = {
         'a': {
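With the new `concurrent_gb` argument, callers of `run_deserialization` can bound peak host memory during a restore. A hypothetical call sketch; `global_mesh`, `mesh_axes`, and `tspecs` are placeholders built as in the tests above:

```python
from jax.experimental.gda_serialization import serialization

# Restore one GDA while admitting at most ~16 GB of in-flight shard reads
# per process (the default budget is 32 GB).
gdas = serialization.run_deserialization(
    [global_mesh], [mesh_axes], tspecs,
    concurrent_gb=16)
```

Since the limiter is created once inside `_run_deserializer`, the budget is shared across every array being restored in that call, not applied per array.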