mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Support deserializing from non-chunked storage such as tensorstore.array.
PiperOrigin-RevId: 473080245
This commit is contained in:
parent
0400db959b
commit
fc84b27289
@ -213,6 +213,14 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
|
||||
chunk_origin = chunk_template.origin
|
||||
chunk_shape = chunk_template.shape
|
||||
|
||||
# Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
|
||||
# For those, instead of returning a near-infinite memory footprint, estimate
|
||||
# the footprint as the entire shape.
|
||||
for i in range(rank):
|
||||
if not chunk_template[i].finite:
|
||||
return t.domain.size * num_bytes
|
||||
|
||||
# Otherwise, if we have a chunked driver, estimate based on chunk size.
|
||||
for i in range(rank):
|
||||
origin_value = origin[i]
|
||||
chunk_origin_value = chunk_origin[i]
|
||||
@ -222,6 +230,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
|
||||
lower_aligned = lower // chunk_size * chunk_size
|
||||
upper_aligned = -(-upper // chunk_size) * chunk_size
|
||||
num_bytes *= (upper_aligned - lower_aligned)
|
||||
|
||||
return num_bytes
|
||||
|
||||
|
||||
|
@ -27,6 +27,7 @@ from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
import numpy as np
|
||||
import tensorstore as ts
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -218,6 +219,18 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
|
||||
|
||||
def test_deserialize_tensorstore_array(self):
  """Deserializes from a spec built out of an in-memory `ts.array`.

  The inline 'array' driver is not chunked, so this exercises the
  non-chunked path of deserialization. Every local shard of the result
  is compared against the source data.
  """
  # Axis names are given as a one-element tuple; a bare `('x')` is just the
  # string 'x'.
  global_mesh = jtu.create_global_mesh((2,), ('x',))
  data = np.arange(1024)
  tspec = ts.array(data).spec()
  # One mesh / partition spec / tensorstore spec in, one array out.
  (restored,) = serialization.run_deserialization(
      [global_mesh],
      [P(None)],
      [tspec],
  )
  for shard in restored.local_shards:
    self.assertArraysEqual(np.asarray(shard.data), data)
|
||||
|
||||
def test_spec_has_metadata(self):
|
||||
spec = {
|
||||
'a': {
|
||||
|
Loading…
x
Reference in New Issue
Block a user