mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove path from the serde API as tspec encompasses those things.
PiperOrigin-RevId: 425727733
This commit is contained in:
parent
4e47de66fc
commit
dcca99b052
@ -80,10 +80,9 @@ def get_tensorstore_spec(ckpt_path: str):
|
||||
return spec
|
||||
|
||||
|
||||
async def async_serialize(ckpt_path: str, gda: gda.GlobalDeviceArray,
|
||||
tensorstore_spec):
|
||||
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
|
||||
if not tensorstore_spec.get('metadata'):
|
||||
tensorstore_spec['metadata'] = _get_metadata(gda)
|
||||
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
|
||||
|
||||
t = await ts.open(
|
||||
ts.Spec(tensorstore_spec),
|
||||
@ -97,19 +96,19 @@ async def async_serialize(ckpt_path: str, gda: gda.GlobalDeviceArray,
|
||||
if shard.replica_id == 0:
|
||||
await t[shard.index].write(shard.data)
|
||||
|
||||
future_write_state = jax.tree_util.tree_map(_write_array, tuple(gda.local_shards))
|
||||
future_write_state = jax.tree_util.tree_map(_write_array,
|
||||
tuple(gda_inp.local_shards))
|
||||
return await asyncio.gather(*future_write_state)
|
||||
|
||||
|
||||
def run_serialization(ckpt_paths, gdas, tensorstore_specs):
|
||||
def run_serialization(gdas, tensorstore_specs):
|
||||
async def _run_serializer():
|
||||
future_writer = jax.tree_map(async_serialize, ckpt_paths, gdas,
|
||||
tensorstore_specs)
|
||||
future_writer = jax.tree_map(async_serialize, gdas, tensorstore_specs)
|
||||
return await asyncio.gather(*future_writer)
|
||||
asyncio.run(_run_serializer())
|
||||
|
||||
|
||||
async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec):
|
||||
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
|
||||
|
||||
async def cb(index):
|
||||
@ -118,9 +117,9 @@ async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
|
||||
return await create_async_gda_from_callback(t.shape, mesh, mesh_axes, cb)
|
||||
|
||||
|
||||
def run_deserialization(ckpt_paths, global_meshes, mesh_axes, tensorstore_specs):
|
||||
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs):
|
||||
async def _run_deserializer():
|
||||
future_gdas = jax.tree_map(async_deserialize, ckpt_paths, global_meshes,
|
||||
mesh_axes, tensorstore_specs)
|
||||
future_gdas = jax.tree_map(async_deserialize, global_meshes, mesh_axes,
|
||||
tensorstore_specs)
|
||||
return await asyncio.gather(*future_gdas)
|
||||
return asyncio.run(_run_deserializer())
|
||||
|
@ -64,11 +64,10 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2)]
|
||||
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization(ckpt_paths, [gda1, gda2], tspecs)
|
||||
serialization.run_serialization([gda1, gda2], tspecs)
|
||||
|
||||
m1, m2 = serialization.run_deserialization(ckpt_paths,
|
||||
[global_mesh, global_mesh],
|
||||
[mesh_axes, ['x']], tspecs)
|
||||
m1, m2 = serialization.run_deserialization(
|
||||
[global_mesh, global_mesh], [mesh_axes, ['x']], tspecs)
|
||||
|
||||
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
|
||||
np.array([[0], [2]]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user