Roll forward: Improve tensorstore I/O efficiency

Reverts 5462d2e3930c6202ffd66aea37d5876cc5f78dbb

PiperOrigin-RevId: 650332835
This commit is contained in:
jax authors 2024-07-08 12:12:17 -07:00 committed by jax authors
parent 2036325b04
commit d60e2201e7
2 changed files with 185 additions and 14 deletions

View File

@ -25,7 +25,7 @@ import os
import re
import threading
import time
from typing import Any
from typing import Any, Optional
import jax
from jax._src import array
@ -191,6 +191,7 @@ async def async_serialize(
context=TS_CONTEXT,
primary_host: int | None = 0,
replica_id: int = 0,
transaction: Optional[ts.Transaction] = None,
):
"""Serialize an array using TensorStore.
@ -204,8 +205,10 @@ async def async_serialize(
primary_host: Primary host, which indicates the host that will be treated as
the "leader". If None, all hosts are treated as the primary. DO NOT USE
unless you are sure you know what you are doing.
replica_id: Allows overriding the shard replica id that will be saved.
DO NOT USE unless you are sure you know what you are doing.
replica_id: Allows overriding the shard replica id that will be saved. DO
NOT USE unless you are sure you know what you are doing.
transaction: TensorStore transaction to use for opening and writing the
array. If not specified, a non-transactional write will be used.
"""
if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
arr_inp.is_fully_addressable):
@ -232,6 +235,7 @@ async def async_serialize(
create=True,
open=True,
context=context,
transaction=transaction,
)
# Asynchronous case.
if commit_future is not None:
@ -251,11 +255,22 @@ async def async_serialize(
open=True,
assume_metadata=True,
context=context,
transaction=transaction,
)
async def _write_array(shard):
if shard.replica_id == replica_id:
write_future = t[shard.index].write(shard.data)
write_future = t[shard.index].write(
shard.data,
# Avoid additional copy of input array into the TensorStore chunk
# cache. If `arr_inp` is a jax.Array, the result of converting
# it to a NumPy array, as is done internally by TensorStore, is
# guaranteed to be immutable and therefore it is safe to retain a
# reference indefinitely.
can_reference_source_data_indefinitely=isinstance(
arr_inp, array.ArrayImpl
),
)
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(write_future.commit)
@ -567,7 +582,14 @@ class AsyncManager:
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
"""Responsible for serializing GDAs via TensorStore."""
def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
def serialize(
self,
arrays,
tensorstore_specs,
*,
on_commit_callback,
transaction: Optional[ts.Transaction] = None,
):
"""Serializes Arrays or Arrays via TensorStore asynchronously.
TensorStore writes to a storage layer in 2 steps:
@ -587,32 +609,52 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
have finished writing their checkpoints to disk. Filesystems where
atomic rename operations are supported, you can rename from the
temporary directory to the final directory. On GCS, you write to the
final directory directly and in `on_commit_callback` you write a
success file indicating that the serialization was successful because
GCS does not support atomic rename operations.
final directory directly and in `on_commit_callback` you write a success
file indicating that the serialization was successful because GCS does
not support atomic rename operations.
transaction: Optional TensorStore transaction to use.
"""
logger.info('Waiting for previous serialization to finish.')
self.wait_until_finished()
commit_futures = [[] for _ in range(len(tensorstore_specs))]
commit_futures: list[ts.Future] = []
async def _run_serializer():
future_writer = jax.tree_util.tree_map(
async_serialize, arrays, tensorstore_specs, commit_futures)
lambda arr_inp, tensorstore_spec: async_serialize(
arr_inp,
tensorstore_spec,
commit_future=commit_futures,
transaction=transaction,
),
arrays,
tensorstore_specs,
)
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())
self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])
self._add_futures(commit_futures)
# Used in wait_until_finished to check on process != 0, if the checkpoint
# has finished writing.
self._start_async_commit(on_commit_callback)
def serialize_with_paths(self, arrays: Sequence[jax.Array],
paths: Sequence[str], *, on_commit_callback):
def serialize_with_paths(
self,
arrays: Sequence[jax.Array],
paths: Sequence[str],
*,
on_commit_callback,
transaction: Optional[ts.Transaction] = None,
):
tspecs = jax.tree.map(get_tensorstore_spec, paths)
self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback)
self.serialize(
arrays,
tspecs,
on_commit_callback=on_commit_callback,
transaction=transaction,
)
def deserialize(self, shardings: Sequence[sharding.Sharding | Layout],
tensorstore_specs: Sequence[dict[str, Any]],

View File

@ -96,6 +96,41 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertGreater(peak, 30_000_000)
tm.stop()
def test_memory_consumption_for_save(self):
global_mesh = jtu.create_global_mesh((1, 1), ('x', 'y'))
inp_shape = (16 * 1024, 16 * 1024)
pspec = P('x', 'y')
num = math.prod(inp_shape)
sharding = NamedSharding(global_mesh, pspec)
src = jnp.arange(num, dtype=np.int32).reshape(inp_shape)
inp = array.make_array_from_callback(
inp_shape, sharding, lambda idx: src[idx]
)
ckpt_dir = pathlib.Path(self.create_tempdir('memprofsave').full_path)
tspec = serialization.get_tensorstore_spec(str(ckpt_dir))
tspec['metadata'] = {
'shape': inp.shape,
'compressor': None,
'chunks': inp.shape,
}
is_cpu = jtu.test_device_matches(['cpu'])
tm.start()
try:
manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize(
[inp],
[tspec],
on_commit_callback=partial(
self._on_commit_callback, ckpt_dir, ckpt_dir
),
)
manager.wait_until_finished()
unused_current, peak = tm.get_traced_memory()
self.assertLess(peak, src.nbytes * (1 * (not is_cpu) + 0.5))
finally:
tm.stop()
def test_checkpointing_with_path_variant(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp_shape = (8, 2)
@ -196,6 +231,100 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
self.assertEqual(m3.dtype, np.float32)
def test_checkpointing_ocdbt_transaction(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp_shape = (8, 2)
pspec = P('x', 'y')
num = math.prod(inp_shape)
# First Array
global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
a1 = array.make_array_from_callback(
inp_shape,
NamedSharding(global_mesh, pspec),
lambda idx: global_input_data1[idx],
)
ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
ckpt_path1 = ckpt_dir / 'first'
# Second Array
global_input_data2 = np.arange(num, num + num, dtype=np.int32).reshape(
inp_shape
)
a2 = array.make_array_from_callback(
inp_shape,
NamedSharding(global_mesh, pspec),
lambda idx: global_input_data2[idx],
)
ckpt_path2 = ckpt_dir / 'second'
# Third Array
def cb3(_):
return np.array([], dtype=np.float32)
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
a3 = array.make_array_from_callback(
(0,), NamedSharding(global_mesh1d, P(None)), cb3
)
ckpt_path3 = ckpt_dir / 'third'
ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)]
tspecs = jax.tree_util.tree_map(
lambda p: serialization.get_tensorstore_spec(p, ocdbt=True), ckpt_paths
)
manager = serialization.GlobalAsyncCheckpointManager()
with ts.Transaction(atomic=True) as transaction:
manager.serialize(
[a1, a2, a3],
tspecs,
on_commit_callback=partial(
self._on_commit_callback, ckpt_dir, ckpt_dir
),
transaction=transaction,
)
manager.wait_until_finished()
m1, m2, m3 = serialization.run_deserialization(
[
NamedSharding(global_mesh, pspec),
NamedSharding(global_mesh, P('x')),
NamedSharding(global_mesh1d, P(None)),
],
tspecs,
)
self.assertIsInstance(m1, array.ArrayImpl)
self.assertArraysEqual(
np.asarray(m1.addressable_shards[0].data),
np.array([[0], [2]], dtype=np.int32),
)
self.assertArraysEqual(
np.asarray(m1.addressable_shards[1].data),
np.array([[1], [3]], dtype=np.int32),
)
self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
self.assertEqual(m1.dtype, np.int32)
self.assertIsInstance(m2, array.ArrayImpl)
self.assertArraysEqual(
np.asarray(m2.addressable_shards[0].data),
np.array([[16, 17], [18, 19]], dtype=np.int32),
)
self.assertArraysEqual(
np.asarray(m2.addressable_shards[1].data),
np.array([[16, 17], [18, 19]], dtype=np.int32),
)
self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32)
self.assertIsInstance(m3, array.ArrayImpl)
for i, s in enumerate(m3.addressable_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)
self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
self.assertEqual(m3.dtype, np.float32)
@parameterized.product(input_dtype=[np.int32, jnp.bfloat16])
def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))