mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Take `shardings` as a parameter to `deserialize` and `run_deserialization` instead of `mesh` and `pspecs`.
PiperOrigin-RevId: 479346552
This commit is contained in:
parent
219c574d8f
commit
d174b3dce3
@ -19,7 +19,7 @@ import itertools
|
||||
from functools import partial
|
||||
import re
|
||||
import threading
|
||||
from typing import Callable, Sequence, Optional
|
||||
from typing import Callable, Sequence, Optional, Dict, Any
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
@ -28,6 +28,7 @@ from jax._src.config import config
|
||||
from jax.experimental import global_device_array as gda
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax._src import typing
|
||||
from jax.experimental.maps import Mesh
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
@ -152,7 +153,7 @@ class _LimitInFlightBytes:
|
||||
async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
|
||||
if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
|
||||
arr_inp.is_fully_addressable()):
|
||||
raise ValueError('Passing fully addressable Arrays to a multi-host '
|
||||
raise ValueError('Passing fully addressable Arrays to a multiprocess '
|
||||
'serialization is not allowed.')
|
||||
# 'metadata' may not be present at the top level (for example, if we are using
|
||||
# a 'cast' driver).
|
||||
@ -232,12 +233,12 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
|
||||
return num_bytes
|
||||
|
||||
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
|
||||
async def async_deserialize(in_sharding, tensorstore_spec,
|
||||
global_shape=None, dtype=None,
|
||||
byte_limiter: Optional[_LimitInFlightBytes] = None):
|
||||
t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
|
||||
shape = t.shape if global_shape is None else global_shape
|
||||
new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)
|
||||
new_shard_shape = in_sharding.shard_shape(tuple(shape))
|
||||
|
||||
async def cb(index):
|
||||
# This may be needed because the shape the array was saved with is smaller
|
||||
@ -266,14 +267,21 @@ async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
|
||||
return out
|
||||
|
||||
if config.jax_array:
|
||||
inp_sharding = sharding.MeshPspecSharding(mesh, mesh_axes)
|
||||
return await create_async_array_from_callback(tuple(shape), inp_sharding, cb)
|
||||
return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
|
||||
else:
|
||||
return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)
|
||||
if not isinstance(in_sharding, sharding.MeshPspecSharding):
|
||||
raise ValueError('Deserializing a GlobalDeviceArray is only possible with '
|
||||
'a `MeshPspecSharding` which consists of a `mesh` and '
|
||||
f'`pspec`, but got {in_sharding}')
|
||||
return await create_async_gda_from_callback(
|
||||
tuple(shape), in_sharding.mesh, in_sharding.spec, cb)
|
||||
|
||||
|
||||
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
|
||||
global_shapes=None, dtypes=None, concurrent_gb=32):
|
||||
def run_deserialization(shardings: Sequence[sharding.Sharding],
|
||||
tensorstore_specs: Sequence[Dict[str, Any]],
|
||||
global_shapes: Optional[Sequence[array.Shape]] = None,
|
||||
dtypes: Optional[Sequence[typing.DTypeLike]] = None,
|
||||
concurrent_gb: int = 32):
|
||||
concurrent_bytes = concurrent_gb * 10**9
|
||||
|
||||
async def _run_deserializer():
|
||||
@ -282,7 +290,7 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
|
||||
|
||||
future_arrays = jax.tree_util.tree_map(
|
||||
partial(async_deserialize, byte_limiter=byte_limiter),
|
||||
global_meshes, mesh_axes, tensorstore_specs,
|
||||
shardings, tensorstore_specs,
|
||||
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
|
||||
[None] * len(tensorstore_specs) if dtypes is None else dtypes)
|
||||
return await asyncio.gather(*future_arrays)
|
||||
@ -351,8 +359,10 @@ class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
|
||||
"""Serializes GDAs to TensorStore."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
|
||||
global_shapes=None, dtypes=None):
|
||||
def deserialize(self, shardings: Sequence[sharding.Sharding],
|
||||
tensorstore_specs: Sequence[Dict[str, Any]],
|
||||
global_shapes: Optional[Sequence[array.Shape]] = None,
|
||||
dtypes: Optional[Sequence[typing.DTypeLike]] = None):
|
||||
"""Deserializes GDAs from TensorStore."""
|
||||
|
||||
|
||||
@ -479,8 +489,10 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
|
||||
# has finished writing.
|
||||
self._start_async_commit(on_commit_callback)
|
||||
|
||||
def deserialize(self, global_meshes, mesh_axes, tensorstore_specs,
|
||||
global_shapes=None, dtypes=None):
|
||||
def deserialize(self, shardings: Sequence[sharding.Sharding],
|
||||
tensorstore_specs: Sequence[Dict[str, Any]],
|
||||
global_shapes: Optional[Sequence[array.Shape]] = None,
|
||||
dtypes: Optional[Sequence[typing.DTypeLike]] = None):
|
||||
self.wait_until_finished()
|
||||
return run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
|
||||
return run_deserialization(shardings, tensorstore_specs,
|
||||
global_shapes, dtypes)
|
||||
|
@ -22,7 +22,7 @@ from jax._src import util
|
||||
from jax._src import config as jax_config
|
||||
from jax.config import config
|
||||
from jax._src import array
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
from jax._src.sharding import MeshPspecSharding, OpShardingSharding
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
@ -71,8 +71,9 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
serialization.run_serialization([gda1, gda2, gda3], tspecs)
|
||||
|
||||
m1, m2, m3 = serialization.run_deserialization(
|
||||
[global_mesh, global_mesh, global_mesh1d],
|
||||
[mesh_axes, P('x'), P(None)],
|
||||
[MeshPspecSharding(global_mesh, mesh_axes),
|
||||
MeshPspecSharding(global_mesh, P('x')),
|
||||
MeshPspecSharding(global_mesh1d, P(None))],
|
||||
tspecs)
|
||||
|
||||
self.assertArraysEqual(np.asarray(m1.local_shards[0].data),
|
||||
@ -130,8 +131,9 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
serialization.run_serialization([a1, a2, a3], tspecs)
|
||||
|
||||
m1, m2, m3 = serialization.run_deserialization(
|
||||
[global_mesh, global_mesh, global_mesh1d],
|
||||
[pspec, P('x'), P(None)],
|
||||
[MeshPspecSharding(global_mesh, pspec),
|
||||
MeshPspecSharding(global_mesh, P('x')),
|
||||
MeshPspecSharding(global_mesh1d, P(None))],
|
||||
tspecs)
|
||||
|
||||
self.assertIsInstance(m1, array.ArrayImpl)
|
||||
@ -177,9 +179,10 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
ds = MeshPspecSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
|
||||
[P('x', 'y')],
|
||||
[ds],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
[np.float32]
|
||||
@ -199,6 +202,14 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Deserializing a GlobalDeviceArray is only possible with '
|
||||
'a `MeshPspecSharding`'):
|
||||
new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
|
||||
serialization.run_deserialization([new_ds], tspecs, [(12, 2)], [np.float32])
|
||||
|
||||
|
||||
@jax_config.jax_array(True)
|
||||
def test_checkpointing_with_bigger_shape_jax_array(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
@ -217,13 +228,10 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
serialization.run_serialization([arr], tspecs)
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[jtu.create_global_mesh((4, 2), ('x', 'y'))],
|
||||
[P('x', 'y')],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
[np.float32]
|
||||
)
|
||||
ds = MeshPspecSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
|
||||
|
||||
m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
|
||||
[np.float32])
|
||||
|
||||
expected_data = {
|
||||
0: np.array([[0], [2], [4]], dtype=np.float32),
|
||||
@ -239,6 +247,11 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.addressable_shards:
|
||||
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
|
||||
|
||||
new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
|
||||
m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
|
||||
for l in m2.addressable_shards:
|
||||
self.assertArraysEqual(l.data, global_input_data1.astype('float32'))
|
||||
|
||||
def test_checkpointing_scalar_gda(self):
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
@ -255,8 +268,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[jtu.create_global_mesh((2,), ('x'))],
|
||||
[P(None)],
|
||||
[MeshPspecSharding(jtu.create_global_mesh((2,), ('x')), P(None))],
|
||||
tspecs,
|
||||
[()],
|
||||
[np.float32]
|
||||
@ -279,10 +291,10 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
ds = MeshPspecSharding(jtu.create_global_mesh((2,), ('x')), P(None))
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[jtu.create_global_mesh((2,), ('x'))],
|
||||
[P(None)],
|
||||
[ds],
|
||||
tspecs,
|
||||
[()],
|
||||
[np.float32]
|
||||
@ -298,8 +310,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
data = np.arange(1024)
|
||||
tspec = ts.array(data).spec()
|
||||
m1, = serialization.run_deserialization(
|
||||
[global_mesh],
|
||||
[P(None)],
|
||||
[MeshPspecSharding(global_mesh, P(None))],
|
||||
[tspec]
|
||||
)
|
||||
for l in m1.local_shards:
|
||||
@ -311,8 +322,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
data = np.arange(1024)
|
||||
tspec = ts.array(data).spec()
|
||||
m1, = serialization.run_deserialization(
|
||||
[global_mesh],
|
||||
[P(None)],
|
||||
[MeshPspecSharding(global_mesh, P(None))],
|
||||
[tspec]
|
||||
)
|
||||
for l in m1.addressable_shards:
|
||||
|
Loading…
x
Reference in New Issue
Block a user