Remove path from the serde API as tspec encompasses those things.

PiperOrigin-RevId: 425727733
This commit is contained in:
Yash Katariya 2022-02-01 15:16:24 -08:00 committed by jax authors
parent 4e47de66fc
commit dcca99b052
2 changed files with 13 additions and 15 deletions

View File

@ -80,10 +80,9 @@ def get_tensorstore_spec(ckpt_path: str):
return spec
async def async_serialize(ckpt_path: str, gda: gda.GlobalDeviceArray,
tensorstore_spec):
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
if not tensorstore_spec.get('metadata'):
tensorstore_spec['metadata'] = _get_metadata(gda)
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
t = await ts.open(
ts.Spec(tensorstore_spec),
@ -97,19 +96,19 @@ async def async_serialize(ckpt_path: str, gda: gda.GlobalDeviceArray,
if shard.replica_id == 0:
await t[shard.index].write(shard.data)
future_write_state = jax.tree_util.tree_map(_write_array, tuple(gda.local_shards))
future_write_state = jax.tree_util.tree_map(_write_array,
tuple(gda_inp.local_shards))
return await asyncio.gather(*future_write_state)
def run_serialization(ckpt_paths, gdas, tensorstore_specs):
def run_serialization(gdas, tensorstore_specs):
async def _run_serializer():
future_writer = jax.tree_map(async_serialize, ckpt_paths, gdas,
tensorstore_specs)
future_writer = jax.tree_map(async_serialize, gdas, tensorstore_specs)
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())
async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
async def async_deserialize(mesh, mesh_axes, tensorstore_spec):
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
async def cb(index):
@ -118,9 +117,9 @@ async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
return await create_async_gda_from_callback(t.shape, mesh, mesh_axes, cb)
def run_deserialization(ckpt_paths, global_meshes, mesh_axes, tensorstore_specs):
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs):
async def _run_deserializer():
future_gdas = jax.tree_map(async_deserialize, ckpt_paths, global_meshes,
mesh_axes, tensorstore_specs)
future_gdas = jax.tree_map(async_deserialize, global_meshes, mesh_axes,
tensorstore_specs)
return await asyncio.gather(*future_gdas)
return asyncio.run(_run_deserializer())

View File

@ -64,11 +64,10 @@ class CheckpointTest(jtu.JaxTestCase):
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2)]
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization(ckpt_paths, [gda1, gda2], tspecs)
serialization.run_serialization([gda1, gda2], tspecs)
m1, m2 = serialization.run_deserialization(ckpt_paths,
[global_mesh, global_mesh],
[mesh_axes, ['x']], tspecs)
m1, m2 = serialization.run_deserialization(
[global_mesh, global_mesh], [mesh_axes, ['x']], tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
np.array([[0], [2]]))