Support deserializing from non-chunked storage such as tensorstore.array.

PiperOrigin-RevId: 473080245
This commit is contained in:
reinerp 2022-09-08 14:07:55 -07:00 committed by jax authors
parent 0400db959b
commit fc84b27289
2 changed files with 22 additions and 0 deletions

View File

@ -213,6 +213,14 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
chunk_origin = chunk_template.origin
chunk_shape = chunk_template.shape
# Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
# For those, instead of returning a near-infinite memory footprint, estimate
# the footprint as the entire shape.
for i in range(rank):
if not chunk_template[i].finite:
return t.domain.size * num_bytes
# Otherwise, if we have a chunked driver, estimate based on chunk size.
for i in range(rank):
origin_value = origin[i]
chunk_origin_value = chunk_origin[i]
@ -222,6 +230,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
lower_aligned = lower // chunk_size * chunk_size
upper_aligned = -(-upper // chunk_size) * chunk_size
num_bytes *= (upper_aligned - lower_aligned)
return num_bytes

View File

@ -27,6 +27,7 @@ from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization
import numpy as np
import tensorstore as ts
config.parse_flags_with_absl()
@ -218,6 +219,18 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
def test_deserialize_tensorstore_array(self):
  """Checks deserialization from a spec produced by `ts.array`.

  `ts.array` is the non-chunked, in-memory TensorStore source this test
  targets. The array is deserialized onto a 2-device mesh with an
  unpartitioned spec, and every local shard is compared to the source data.
  """
  # Pass the axis names as a real one-element tuple: a bare ('x') is just the
  # string 'x', not a tuple.
  global_mesh = jtu.create_global_mesh((2,), ('x',))
  data = np.arange(1024)
  # The spec of the wrapped NumPy array is what run_deserialization reads from.
  tspec = ts.array(data).spec()
  m1, = serialization.run_deserialization(
      [global_mesh],
      [P(None)],
      [tspec]
  )
  # With P(None) each local shard is expected to hold the whole array.
  for l in m1.local_shards:
    self.assertArraysEqual(np.asarray(l.data), data)
def test_spec_has_metadata(self):
spec = {
'a': {