Take shardings as a parameter to deserialize and run_deserialization instead of mesh and pspecs.

PiperOrigin-RevId: 479346552
This commit is contained in:
Yash Katariya 2022-10-06 10:20:14 -07:00 committed by jax authors
parent 219c574d8f
commit d174b3dce3
2 changed files with 59 additions and 37 deletions

View File

@ -19,7 +19,7 @@ import itertools
from functools import partial
import re
import threading
from typing import Callable, Sequence, Optional
from typing import Callable, Sequence, Optional, Dict, Any
from absl import logging
import jax
@ -28,6 +28,7 @@ from jax._src.config import config
from jax.experimental import global_device_array as gda
from jax._src import array
from jax._src import sharding
from jax._src import typing
from jax.experimental.maps import Mesh
import jax.numpy as jnp
import numpy as np
@ -152,7 +153,7 @@ class _LimitInFlightBytes:
async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
arr_inp.is_fully_addressable()):
raise ValueError('Passing fully addressable Arrays to a multi-host '
raise ValueError('Passing fully addressable Arrays to a multiprocess '
'serialization is not allowed.')
# 'metadata' may not be present at the top level (for example, if we are using
# a 'cast' driver).
@ -232,12 +233,12 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
return num_bytes
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
async def async_deserialize(in_sharding, tensorstore_spec,
global_shape=None, dtype=None,
byte_limiter: Optional[_LimitInFlightBytes] = None):
t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
shape = t.shape if global_shape is None else global_shape
new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)
new_shard_shape = in_sharding.shard_shape(tuple(shape))
async def cb(index):
# This may be needed because the shape the array was saved with is smaller
@ -266,14 +267,21 @@ async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
return out
if config.jax_array:
inp_sharding = sharding.MeshPspecSharding(mesh, mesh_axes)
return await create_async_array_from_callback(tuple(shape), inp_sharding, cb)
return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
else:
return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)
if not isinstance(in_sharding, sharding.MeshPspecSharding):
raise ValueError('Deserializing a GlobalDeviceArray is only possible with '
'a `MeshPspecSharding` which consists of a `mesh` and '
f'`pspec`, but got {in_sharding}')
return await create_async_gda_from_callback(
tuple(shape), in_sharding.mesh, in_sharding.spec, cb)
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
global_shapes=None, dtypes=None, concurrent_gb=32):
def run_deserialization(shardings: Sequence[sharding.Sharding],
tensorstore_specs: Sequence[Dict[str, Any]],
global_shapes: Optional[Sequence[array.Shape]] = None,
dtypes: Optional[Sequence[typing.DTypeLike]] = None,
concurrent_gb: int = 32):
concurrent_bytes = concurrent_gb * 10**9
async def _run_deserializer():
@ -282,7 +290,7 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
future_arrays = jax.tree_util.tree_map(
partial(async_deserialize, byte_limiter=byte_limiter),
global_meshes, mesh_axes, tensorstore_specs,
shardings, tensorstore_specs,
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
[None] * len(tensorstore_specs) if dtypes is None else dtypes)
return await asyncio.gather(*future_arrays)
@ -351,8 +359,10 @@ class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
"""Serializes GDAs to TensorStore."""
@abc.abstractmethod
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
global_shapes=None, dtypes=None):
def deserialize(self, shardings: Sequence[sharding.Sharding],
tensorstore_specs: Sequence[Dict[str, Any]],
global_shapes: Optional[Sequence[array.Shape]] = None,
dtypes: Optional[Sequence[typing.DTypeLike]] = None):
"""Deserializes GDAs from TensorStore."""
@ -479,8 +489,10 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
# has finished writing.
self._start_async_commit(on_commit_callback)
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
global_shapes=None, dtypes=None):
def deserialize(self, shardings: Sequence[sharding.Sharding],
tensorstore_specs: Sequence[Dict[str, Any]],
global_shapes: Optional[Sequence[array.Shape]] = None,
dtypes: Optional[Sequence[typing.DTypeLike]] = None):
self.wait_until_finished()
return run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
return run_deserialization(shardings, tensorstore_specs,
global_shapes, dtypes)

View File

@ -22,7 +22,7 @@ from jax._src import util
from jax._src import config as jax_config
from jax.config import config
from jax._src import array
from jax._src.sharding import MeshPspecSharding
from jax._src.sharding import MeshPspecSharding, OpShardingSharding
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization
@ -71,8 +71,9 @@ class CheckpointTest(jtu.JaxTestCase):
serialization.run_serialization([gda1, gda2, gda3], tspecs)
m1, m2, m3 = serialization.run_deserialization(
[global_mesh, global_mesh, global_mesh1d],
[mesh_axes, P('x'), P(None)],
[MeshPspecSharding(global_mesh, mesh_axes),
MeshPspecSharding(global_mesh, P('x')),
MeshPspecSharding(global_mesh1d, P(None))],
tspecs)
self.assertArraysEqual(np.asarray(m1.local_shards[0].data),
@ -130,8 +131,9 @@ class CheckpointTest(jtu.JaxTestCase):
serialization.run_serialization([a1, a2, a3], tspecs)
m1, m2, m3 = serialization.run_deserialization(
[global_mesh, global_mesh, global_mesh1d],
[pspec, P('x'), P(None)],
[MeshPspecSharding(global_mesh, pspec),
MeshPspecSharding(global_mesh, P('x')),
MeshPspecSharding(global_mesh1d, P(None))],
tspecs)
self.assertIsInstance(m1, array.ArrayImpl)
@ -177,9 +179,10 @@ class CheckpointTest(jtu.JaxTestCase):
serialization.run_serialization([gda1], tspecs)
ds = MeshPspecSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
m1, = serialization.run_deserialization(
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
[P('x', 'y')],
[ds],
tspecs,
[(12, 2)],
[np.float32]
@ -199,6 +202,14 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
with self.assertRaisesRegex(
ValueError,
'Deserializing a GlobalDeviceArray is only possible with '
'a `MeshPspecSharding`'):
new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
serialization.run_deserialization([new_ds], tspecs, [(12, 2)], [np.float32])
@jax_config.jax_array(True)
def test_checkpointing_with_bigger_shape_jax_array(self):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@ -217,13 +228,10 @@ class CheckpointTest(jtu.JaxTestCase):
serialization.run_serialization([arr], tspecs)
m1, = serialization.run_deserialization(
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
[P('x', 'y')],
tspecs,
[(12, 2)],
[np.float32]
)
ds = MeshPspecSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
[np.float32])
expected_data = {
0: np.array([[0], [2], [4]], dtype=np.float32),
@ -239,6 +247,11 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.addressable_shards:
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
for l in m2.addressable_shards:
self.assertArraysEqual(l.data, global_input_data1.astype('float32'))
def test_checkpointing_scalar_gda(self):
if config.jax_array:
self.skipTest('GDA and Array cannot be enabled together.')
@ -255,8 +268,7 @@ class CheckpointTest(jtu.JaxTestCase):
serialization.run_serialization([gda1], tspecs)
m1, = serialization.run_deserialization(
[jtu.create_global_mesh((2,), ('x'))],
[P(None)],
[MeshPspecSharding(jtu.create_global_mesh((2,), ('x')), P(None))],
tspecs,
[()],
[np.float32]
@ -279,10 +291,10 @@ class CheckpointTest(jtu.JaxTestCase):
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([gda1], tspecs)
ds = MeshPspecSharding(jtu.create_global_mesh((2,), ('x')), P(None))
m1, = serialization.run_deserialization(
[jtu.create_global_mesh((2,), ('x'))],
[P(None)],
[ds],
tspecs,
[()],
[np.float32]
@ -298,8 +310,7 @@ class CheckpointTest(jtu.JaxTestCase):
data = np.arange(1024)
tspec = ts.array(data).spec()
m1, = serialization.run_deserialization(
[global_mesh],
[P(None)],
[MeshPspecSharding(global_mesh, P(None))],
[tspec]
)
for l in m1.local_shards:
@ -311,8 +322,7 @@ class CheckpointTest(jtu.JaxTestCase):
data = np.arange(1024)
tspec = ts.array(data).spec()
m1, = serialization.run_deserialization(
[global_mesh],
[P(None)],
[MeshPspecSharding(global_mesh, P(None))],
[tspec]
)
for l in m1.addressable_shards: