Roll forward: Improve tensorstore I/O efficiency

Reverts 5462d2e3930c6202ffd66aea37d5876cc5f78dbb

PiperOrigin-RevId: 650332835

commit d60e2201e7
parent 2036325b04
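
Editorial note: this change threads an optional `ts.Transaction` through `async_serialize`, `GlobalAsyncCheckpointManager.serialize`, and `serialize_with_paths`, so a multi-array checkpoint can be committed atomically, and it passes `can_reference_source_data_indefinitely` to TensorStore writes so immutable `jax.Array` inputs are not copied into the chunk cache. A minimal single-process sketch of the new transactional save path, distilled from the test added below (the checkpoint path and no-op commit callback are hypothetical):

    import jax.numpy as jnp
    import tensorstore as ts
    from jax.experimental.array_serialization import serialization

    arr = jnp.arange(16).reshape(4, 4)
    # Hypothetical local path; any kvstore supported by TensorStore works.
    tspec = serialization.get_tensorstore_spec('/tmp/ckpt/a', ocdbt=True)

    manager = serialization.GlobalAsyncCheckpointManager()
    # With atomic=True, either every write in the transaction commits or none do.
    with ts.Transaction(atomic=True) as txn:
      manager.serialize(
          [arr],
          [tspec],
          on_commit_callback=lambda: None,  # hypothetical no-op callback
          transaction=txn,
      )
    manager.wait_until_finished()
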
--- a/jax/experimental/array_serialization/serialization.py
+++ b/jax/experimental/array_serialization/serialization.py
@@ -25,7 +25,7 @@ import os
 import re
 import threading
 import time
-from typing import Any
+from typing import Any, Optional
 
 import jax
 from jax._src import array
@@ -191,6 +191,7 @@ async def async_serialize(
     context=TS_CONTEXT,
     primary_host: int | None = 0,
     replica_id: int = 0,
+    transaction: Optional[ts.Transaction] = None,
 ):
   """Serialize an array using TensorStore.
 
@@ -204,8 +205,10 @@ async def async_serialize(
     primary_host: Primary host, which indicates the host that will be treated as
       the "leader". If None, all hosts are treated as the primary. DO NOT USE
       unless you are sure you know what you are doing.
-    replica_id: Allows overriding the shard replica id that will be saved.
-      DO NOT USE unless you are sure you know what you are doing.
+    replica_id: Allows overriding the shard replica id that will be saved. DO
+      NOT USE unless you are sure you know what you are doing.
+    transaction: TensorStore transaction to use for opening and writing the
+      array. If not specified, a non-transactional write will be used.
   """
   if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
       arr_inp.is_fully_addressable):
@@ -232,6 +235,7 @@ async def async_serialize(
         create=True,
         open=True,
         context=context,
+        transaction=transaction,
     )
     # Asynchronous case.
     if commit_future is not None:
@@ -251,11 +255,22 @@ async def async_serialize(
       open=True,
       assume_metadata=True,
       context=context,
+      transaction=transaction,
   )
 
   async def _write_array(shard):
     if shard.replica_id == replica_id:
-      write_future = t[shard.index].write(shard.data)
+      write_future = t[shard.index].write(
+          shard.data,
+          # Avoid additional copy of input array into the TensorStore chunk
+          # cache. If `arr_inp` is a jax.Array, the result of converting
+          # it to a NumPy array, as is done internally by TensorStore, is
+          # guaranteed to be immutable and therefore it is safe to retain a
+          # reference indefinitely.
+          can_reference_source_data_indefinitely=isinstance(
+              arr_inp, array.ArrayImpl
+          ),
+      )
       if commit_future is not None:
         assert isinstance(commit_future, list)
         commit_future.append(write_future.commit)
@@ -567,7 +582,14 @@ class AsyncManager:
 class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
   """Responsible for serializing GDAs via TensorStore."""
 
-  def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
+  def serialize(
+      self,
+      arrays,
+      tensorstore_specs,
+      *,
+      on_commit_callback,
+      transaction: Optional[ts.Transaction] = None,
+  ):
     """Serializes Arrays or Arrays via TensorStore asynchronously.
 
     TensorStore writes to a storage layer in 2 steps:
@@ -587,32 +609,52 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
         have finished writing their checkpoints to disk. Filesystems where
         atomic rename operations are supported, you can rename from the
         temporary directory to the final directory. On GCS, you write to the
-        final directory directly and in `on_commit_callback` you write a
-        success file indicating that the serialization was successful because
-        GCS does not support atomic rename operations.
+        final directory directly and in `on_commit_callback` you write a success
+        file indicating that the serialization was successful because GCS does
+        not support atomic rename operations.
+      transaction: Optional TensorStore transaction to use.
     """
     logger.info('Waiting for previous serialization to finish.')
     self.wait_until_finished()
 
-    commit_futures = [[] for _ in range(len(tensorstore_specs))]
+    commit_futures: list[ts.Future] = []
 
     async def _run_serializer():
       future_writer = jax.tree_util.tree_map(
-          async_serialize, arrays, tensorstore_specs, commit_futures)
+          lambda arr_inp, tensorstore_spec: async_serialize(
+              arr_inp,
+              tensorstore_spec,
+              commit_future=commit_futures,
+              transaction=transaction,
+          ),
+          arrays,
+          tensorstore_specs,
+      )
       return await asyncio.gather(*future_writer)
 
     asyncio.run(_run_serializer())
 
-    self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])
+    self._add_futures(commit_futures)
 
     # Used in wait_until_finished to check on process != 0, if the checkpoint
     # has finished writing.
     self._start_async_commit(on_commit_callback)
 
-  def serialize_with_paths(self, arrays: Sequence[jax.Array],
-                           paths: Sequence[str], *, on_commit_callback):
+  def serialize_with_paths(
+      self,
+      arrays: Sequence[jax.Array],
+      paths: Sequence[str],
+      *,
+      on_commit_callback,
+      transaction: Optional[ts.Transaction] = None,
+  ):
     tspecs = jax.tree.map(get_tensorstore_spec, paths)
-    self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback)
+    self.serialize(
+        arrays,
+        tspecs,
+        on_commit_callback=on_commit_callback,
+        transaction=transaction,
+    )
 
   def deserialize(self, shardings: Sequence[sharding.Sharding | Layout],
                   tensorstore_specs: Sequence[dict[str, Any]],
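Editorial note: the serialization.py changes above also flatten the per-spec commit-future lists into a single `list[ts.Future]` that every `async_serialize` call appends to. A sketch of driving `async_serialize` directly under that convention, assuming a single process (the array and path are hypothetical):

    import asyncio
    import jax.numpy as jnp
    from jax.experimental.array_serialization import serialization

    arr = jnp.ones((8, 8))
    tspec = serialization.get_tensorstore_spec('/tmp/ckpt/direct')  # hypothetical

    # async_serialize appends one TensorStore commit future per local shard.
    commit_futures = []

    async def run():
      # Returns once the data has been copied out of `arr`; durability is
      # signaled later by the collected commit futures.
      await serialization.async_serialize(
          arr, tspec, commit_future=commit_futures
      )

    asyncio.run(run())
    for f in commit_futures:
      f.result()  # block until the write is durable
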
--- a/jax/experimental/array_serialization/serialization_test.py
+++ b/jax/experimental/array_serialization/serialization_test.py
@@ -96,6 +96,41 @@ class CheckpointTest(jtu.JaxTestCase):
     self.assertGreater(peak, 30_000_000)
     tm.stop()
 
+  def test_memory_consumption_for_save(self):
+    global_mesh = jtu.create_global_mesh((1, 1), ('x', 'y'))
+    inp_shape = (16 * 1024, 16 * 1024)
+    pspec = P('x', 'y')
+    num = math.prod(inp_shape)
+    sharding = NamedSharding(global_mesh, pspec)
+    src = jnp.arange(num, dtype=np.int32).reshape(inp_shape)
+    inp = array.make_array_from_callback(
+        inp_shape, sharding, lambda idx: src[idx]
+    )
+    ckpt_dir = pathlib.Path(self.create_tempdir('memprofsave').full_path)
+    tspec = serialization.get_tensorstore_spec(str(ckpt_dir))
+    tspec['metadata'] = {
+        'shape': inp.shape,
+        'compressor': None,
+        'chunks': inp.shape,
+    }
+
+    is_cpu = jtu.test_device_matches(['cpu'])
+    tm.start()
+    try:
+      manager = serialization.GlobalAsyncCheckpointManager()
+      manager.serialize(
+          [inp],
+          [tspec],
+          on_commit_callback=partial(
+              self._on_commit_callback, ckpt_dir, ckpt_dir
+          ),
+      )
+      manager.wait_until_finished()
+      unused_current, peak = tm.get_traced_memory()
+      self.assertLess(peak, src.nbytes * (1 * (not is_cpu) + 0.5))
+    finally:
+      tm.stop()
+
   def test_checkpointing_with_path_variant(self):
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     inp_shape = (8, 2)
@@ -196,6 +231,100 @@ class CheckpointTest(jtu.JaxTestCase):
       self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
     self.assertEqual(m3.dtype, np.float32)
 
+  def test_checkpointing_ocdbt_transaction(self):
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    inp_shape = (8, 2)
+    pspec = P('x', 'y')
+    num = math.prod(inp_shape)
+
+    # First Array
+    global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
+    a1 = array.make_array_from_callback(
+        inp_shape,
+        NamedSharding(global_mesh, pspec),
+        lambda idx: global_input_data1[idx],
+    )
+    ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
+    ckpt_path1 = ckpt_dir / 'first'
+
+    # Second Array
+    global_input_data2 = np.arange(num, num + num, dtype=np.int32).reshape(
+        inp_shape
+    )
+    a2 = array.make_array_from_callback(
+        inp_shape,
+        NamedSharding(global_mesh, pspec),
+        lambda idx: global_input_data2[idx],
+    )
+    ckpt_path2 = ckpt_dir / 'second'
+
+    # Third Array
+    def cb3(_):
+      return np.array([], dtype=np.float32)
+
+    global_mesh1d = jtu.create_global_mesh((8,), ('x',))
+    a3 = array.make_array_from_callback(
+        (0,), NamedSharding(global_mesh1d, P(None)), cb3
+    )
+    ckpt_path3 = ckpt_dir / 'third'
+
+    ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)]
+    tspecs = jax.tree_util.tree_map(
+        lambda p: serialization.get_tensorstore_spec(p, ocdbt=True), ckpt_paths
+    )
+
+    manager = serialization.GlobalAsyncCheckpointManager()
+    with ts.Transaction(atomic=True) as transaction:
+      manager.serialize(
+          [a1, a2, a3],
+          tspecs,
+          on_commit_callback=partial(
+              self._on_commit_callback, ckpt_dir, ckpt_dir
+          ),
+          transaction=transaction,
+      )
+    manager.wait_until_finished()
+
+    m1, m2, m3 = serialization.run_deserialization(
+        [
+            NamedSharding(global_mesh, pspec),
+            NamedSharding(global_mesh, P('x')),
+            NamedSharding(global_mesh1d, P(None)),
+        ],
+        tspecs,
+    )
+
+    self.assertIsInstance(m1, array.ArrayImpl)
+    self.assertArraysEqual(
+        np.asarray(m1.addressable_shards[0].data),
+        np.array([[0], [2]], dtype=np.int32),
+    )
+    self.assertArraysEqual(
+        np.asarray(m1.addressable_shards[1].data),
+        np.array([[1], [3]], dtype=np.int32),
+    )
+    self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
+    self.assertEqual(m1.dtype, np.int32)
+
+    self.assertIsInstance(m2, array.ArrayImpl)
+    self.assertArraysEqual(
+        np.asarray(m2.addressable_shards[0].data),
+        np.array([[16, 17], [18, 19]], dtype=np.int32),
+    )
+    self.assertArraysEqual(
+        np.asarray(m2.addressable_shards[1].data),
+        np.array([[16, 17], [18, 19]], dtype=np.int32),
+    )
+    self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
+    self.assertEqual(m2.dtype, np.int32)
+
+    self.assertIsInstance(m3, array.ArrayImpl)
+    for i, s in enumerate(m3.addressable_shards):
+      self.assertEqual(s.index, (slice(None),))
+      self.assertEqual(s.replica_id, i)
+      self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
+    self.assertEqual(m3.dtype, np.float32)
+
   @parameterized.product(input_dtype=[np.int32, jnp.bfloat16])
   def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype):
     global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
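Editorial note: `can_reference_source_data_indefinitely` is plain TensorStore API: it lets `write` retain a reference to the source buffer instead of copying it into the chunk cache, and is safe only when the caller guarantees the buffer is never mutated afterwards. A standalone sketch against raw TensorStore (driver choice, path, and data are hypothetical):

    import numpy as np
    import tensorstore as ts

    t = ts.open(
        {
            'driver': 'zarr',
            'kvstore': {'driver': 'file', 'path': '/tmp/ts-demo'},
            'metadata': {'shape': [4, 4], 'chunks': [4, 4], 'dtype': '<i4'},
        },
        create=True,
        open=True,
    ).result()

    src = np.arange(16, dtype=np.int32).reshape(4, 4)
    src.setflags(write=False)  # freeze the buffer so retaining it is safe
    # Skip the defensive copy; TensorStore may keep referencing `src`.
    t.write(src, can_reference_source_data_indefinitely=True).result()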