mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add checkpointing support for Array similar to GDA.
PiperOrigin-RevId: 469271107
This commit is contained in:
parent
384776f0c9
commit
7cdb7e1471
@ -264,6 +264,9 @@ class Array:
|
||||
self._check_if_deleted()
|
||||
return list(self.sharding.device_set)
|
||||
|
||||
def to_py(self) -> np.ndarray:
  """Returns the array's value as a NumPy array.

  This is a thin accessor: it hands back `self._value` unchanged.
  """
  return self._value
|
||||
|
||||
@pxla.maybe_cached_property
|
||||
def addressable_shards(self) -> Sequence[Shard]:
|
||||
self._check_if_deleted()
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import enum
|
||||
import itertools
|
||||
from functools import partial
|
||||
import re
|
||||
@ -25,6 +26,8 @@ from absl import logging
|
||||
import jax
|
||||
from jax._src import distributed
|
||||
from jax.experimental import global_device_array as gda
|
||||
from jax.experimental import array
|
||||
from jax.experimental import sharding
|
||||
from jax.experimental.maps import Mesh
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
@ -37,6 +40,24 @@ _CHECKPOINT_SUCCESS = 'checkpoint_write_success'
|
||||
_module_unique_count = itertools.count()
|
||||
|
||||
|
||||
async def create_async_array_from_callback(
    global_shape: array.Shape,
    inp_sharding: sharding.XLACompatibleSharding,
    data_callback: Callable[[array.Index], asyncio.Future],
):
  """Builds an `array.Array` from per-device data produced asynchronously.

  Args:
    global_shape: Shape of the full (global) array.
    inp_sharding: Sharding used to look up, for every device, the index of the
      global array that the device holds.
    data_callback: Called once per addressable device with that device's
      index. Must return an awaitable that resolves to the data for that
      index.

  Returns:
    A committed `array.Array` with sharding `inp_sharding`, built from one
    device buffer per addressable device.
  """
  device_to_index_map = inp_sharding.devices_indices_map(global_shape)
  # Read the device list once so both loops below iterate in the same order.
  addressable_devices = inp_sharding._addressable_device_assignment
  future_arrays = [data_callback(device_to_index_map[d])  # type: ignore
                   for d in addressable_devices]
  # Pause here and resume once every callback result is ready; device_put
  # cannot be applied to results that are still pending.
  local_arrays = await asyncio.gather(*future_arrays)

  # The loop variable is deliberately not called `array`, so that it does not
  # shadow the imported `array` module used for `array.Array` below.
  dbs = [jax.device_put(local_array, device)
         for local_array, device in zip(local_arrays, addressable_devices)]
  aval = jax.ShapedArray(global_shape, dbs[0].dtype)
  return array.Array(aval, inp_sharding, dbs, committed=True)
|
||||
|
||||
|
||||
async def create_async_gda_from_callback(
|
||||
global_shape: gda.Shape,
|
||||
global_mesh: Mesh,
|
||||
@ -58,19 +79,22 @@ async def create_async_gda_from_callback(
|
||||
gda._GdaFastPathArgs(global_idx_rid, local_devices))
|
||||
|
||||
|
||||
def _get_metadata(arr):
  """Builds the TensorStore `metadata` dict describing `arr`.

  Args:
    arr: Either an `array.Array` or an object exposing `local_data(0)` (the
      GlobalDeviceArray code path). Must also expose `dtype` and `shape`.

  Returns:
    A dict with the gzip compressor setting, the global shape, the chunk
    shape and the dtype string.
  """
  if arr.dtype == jnp.bfloat16:
    # Tensorstore uses 'bfloat16', not '<V2'.
    dtype = 'bfloat16'
  else:
    dtype = np.dtype(arr.dtype).str

  # The chunk shape is taken from the first local shard.
  if isinstance(arr, array.Array):
    local_shape = arr._arrays[0].shape
  else:
    local_shape = arr.local_data(0).shape

  return {
      'compressor': {
          'id': 'gzip'
      },
      'shape': arr.shape,
      # Clamp every dimension to at least 1 so that zero-sized shard
      # dimensions do not produce a zero-sized chunk.
      'chunks': np.array(np.maximum(1, local_shape)),
      'dtype': dtype,
  }
|
||||
|
||||
@ -121,12 +145,15 @@ class _LimitInFlightBytes:
|
||||
self._cv.notify_all()
|
||||
|
||||
|
||||
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
|
||||
commit_future=None):
|
||||
async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
|
||||
if (isinstance(arr_inp, array.Array) and jax.process_count() > 1 and
|
||||
arr_inp.is_fully_addressable()):
|
||||
raise ValueError('Passing fully addressable Arrays to a multi-host '
|
||||
'serialization is not allowed.')
|
||||
# 'metadata' may not be present at the top level (for example, if we are using
|
||||
# a 'cast' driver).
|
||||
if not _spec_has_metadata(tensorstore_spec):
|
||||
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
|
||||
tensorstore_spec['metadata'] = _get_metadata(arr_inp)
|
||||
|
||||
if jax.process_index() == 0:
|
||||
open_future = ts.open(
|
||||
@ -156,14 +183,17 @@ async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
|
||||
else:
|
||||
await write_future.commit
|
||||
|
||||
future_write_state = jax.tree_util.tree_map(_write_array,
|
||||
gda_inp.local_shards)
|
||||
if isinstance(arr_inp, array.Array):
|
||||
local_shards = arr_inp.addressable_shards
|
||||
else:
|
||||
local_shards = arr_inp.local_shards
|
||||
future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
|
||||
return await asyncio.gather(*future_write_state)
|
||||
|
||||
|
||||
def run_serialization(arrays, tensorstore_specs):
  """Serializes `arrays` to TensorStore and blocks until all writes finish.

  Args:
    arrays: Pytree of arrays to serialize; each leaf is passed to
      `async_serialize`.
    tensorstore_specs: Pytree with the same structure as `arrays`, holding the
      TensorStore spec for the corresponding leaf.
  """
  async def _run_serializer():
    # tree_map pairs each array with its spec and yields one coroutine per leaf.
    future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs)
    return await asyncio.gather(*future_writer)
  asyncio.run(_run_serializer())
|
||||
|
||||
@ -189,9 +219,15 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
|
||||
return num_bytes
|
||||
|
||||
|
||||
class ArrayFlavor(enum.Enum):
  """Identifies an array type: GlobalDeviceArray (`GDA`) or `Array`."""
  GDA = 0
  Array = 1
|
||||
|
||||
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
|
||||
global_shape=None, dtype=None,
|
||||
byte_limiter: Optional[_LimitInFlightBytes] = None):
|
||||
byte_limiter: Optional[_LimitInFlightBytes] = None,
|
||||
return_arr_flavor: ArrayFlavor = ArrayFlavor.GDA):
|
||||
t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
|
||||
shape = t.shape if global_shape is None else global_shape
|
||||
new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)
|
||||
@ -222,23 +258,29 @@ async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
|
||||
await byte_limiter.release_bytes(requested_bytes)
|
||||
return out
|
||||
|
||||
return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)
|
||||
if return_arr_flavor == ArrayFlavor.Array:
|
||||
inp_sharding = sharding.MeshPspecSharding(mesh, mesh_axes)
|
||||
return await create_async_array_from_callback(tuple(shape), inp_sharding, cb)
|
||||
else:
|
||||
return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)
|
||||
|
||||
|
||||
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
                        global_shapes=None, dtypes=None, concurrent_gb=32,
                        return_arr_flavor=ArrayFlavor.GDA):
  """Deserializes arrays from TensorStore and blocks until all are loaded.

  Args:
    global_meshes: Pytree of meshes, one per array to load.
    mesh_axes: Pytree of mesh axes, one per array to load.
    tensorstore_specs: Pytree of TensorStore specs, one per array to load.
    global_shapes: Optional pytree of global shapes. When None, `None` is
      passed to `async_deserialize` for every array.
    dtypes: Optional pytree of dtypes. When None, `None` is passed to
      `async_deserialize` for every array.
    concurrent_gb: Size, in units of 10**9 bytes, given to the
      `_LimitInFlightBytes` limiter shared by all reads.
    return_arr_flavor: Forwarded to `async_deserialize` to select the type of
      the returned arrays.

  Returns:
    A list with one deserialized array per leaf.
  """
  concurrent_bytes = concurrent_gb * 10**9

  async def _run_deserializer():
    # Object should be created once per process.
    byte_limiter = _LimitInFlightBytes(concurrent_bytes)

    future_arrays = jax.tree_util.tree_map(
        partial(async_deserialize, byte_limiter=byte_limiter,
                return_arr_flavor=return_arr_flavor),
        global_meshes, mesh_axes, tensorstore_specs,
        [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
        [None] * len(tensorstore_specs) if dtypes is None else dtypes)
    return await asyncio.gather(*future_arrays)
  return asyncio.run(_run_deserializer())
|
||||
|
||||
|
||||
@ -299,10 +341,7 @@ class GlobalAsyncCheckpointManagerBase(metaclass=abc.ABCMeta):
|
||||
"""Blocks until serialization has finished."""
|
||||
|
||||
@abc.abstractmethod
def serialize(self, arrays, tensorstore_specs, *,
              on_commit_callback: Callable[[], None]):
  """Serializes GDAs or Arrays to TensorStore."""
|
||||
|
||||
@ -396,8 +435,8 @@ class AsyncManager:
|
||||
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
|
||||
"""Responsible for serializing GDAs via TensorStore."""
|
||||
|
||||
def serialize(self, gdas, tensorstore_specs, *, on_commit_callback):
|
||||
"""Serializes GlobalDeviceArrays via TensorStore asynchronously.
|
||||
def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
|
||||
"""Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously.
|
||||
|
||||
TensorStore writes to a storage layer in 2 steps:
|
||||
* Reading/copying from the source after which the source can be modified.
|
||||
@ -409,8 +448,9 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
|
||||
finish in a separate thread allowing other computation to proceed.
|
||||
|
||||
Args:
|
||||
gdas: GlobalDeviceArrays that should be serialized.
|
||||
tensorstore_specs: TensorStore specs that are used to serialize GDAs.
|
||||
arrays: GlobalDeviceArrays or Arrays that should be serialized.
|
||||
tensorstore_specs: TensorStore specs that are used to serialize GDAs or
|
||||
Arrays.
|
||||
temp_checkpoint_dir: Temporary checkpoint directory where the checkpoints
|
||||
will be written.
|
||||
final_checkpoint_dir: Final checkpoint directory where the checkpoints
|
||||
@ -423,7 +463,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
|
||||
|
||||
async def _run_serializer():
|
||||
future_writer = jax.tree_util.tree_map(
|
||||
async_serialize, gdas, tensorstore_specs, commit_futures)
|
||||
async_serialize, arrays, tensorstore_specs, commit_futures)
|
||||
return await asyncio.gather(*future_writer)
|
||||
|
||||
asyncio.run(_run_serializer())
|
||||
|
@ -20,6 +20,8 @@ import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax.config import config
|
||||
from jax.experimental import array
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
@ -53,7 +55,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)
|
||||
|
||||
# Third GDA
|
||||
def cb3(index):
|
||||
def cb3(_):
|
||||
return np.array([])
|
||||
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
|
||||
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
|
||||
@ -89,6 +91,64 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(s.data.to_py(), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_with_array(self):
  """Round-trips three `Array`s through serialization and checks the shards."""
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  inp_shape = (8, 2)
  pspec = P('x', 'y')
  num = util.prod(inp_shape)

  # First Array
  global_input_data1 = np.arange(num).reshape(inp_shape)
  a1 = array.make_array_from_callback(
      inp_shape, MeshPspecSharding(global_mesh, pspec),
      lambda idx: global_input_data1[idx])
  ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

  # Second Array
  global_input_data2 = np.arange(num, num + num).reshape(inp_shape)
  a2 = array.make_array_from_callback(
      inp_shape, MeshPspecSharding(global_mesh, pspec),
      lambda idx: global_input_data2[idx])
  ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

  # Third Array (zero-sized)
  def cb3(_):
    return np.array([])
  global_mesh1d = jtu.create_global_mesh((8,), ('x',))
  a3 = array.make_array_from_callback(
      (0,), MeshPspecSharding(global_mesh1d, P(None)), cb3)
  ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)

  ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
  tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

  serialization.run_serialization([a1, a2, a3], tspecs)

  # The second array is written with P('x', 'y') but read back with P('x'),
  # so loading with a different partition spec is covered as well.
  m1, m2, m3 = serialization.run_deserialization(
      [global_mesh, global_mesh, global_mesh1d],
      [pspec, P('x'), P(None)],
      tspecs, return_arr_flavor=serialization.ArrayFlavor.Array)

  self.assertArraysEqual(m1.addressable_shards[0].data.to_py(),
                         np.array([[0], [2]]))
  self.assertArraysEqual(m1.addressable_shards[1].data.to_py(),
                         np.array([[1], [3]]))
  self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
  self.assertEqual(m1.dtype, np.int32)

  self.assertArraysEqual(m2.addressable_shards[0].data.to_py(),
                         np.array([[16, 17], [18, 19]]))
  self.assertArraysEqual(m2.addressable_shards[1].data.to_py(),
                         np.array([[16, 17], [18, 19]]))
  self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
  self.assertEqual(m2.dtype, np.int32)

  for i, s in enumerate(m3.addressable_shards):
    self.assertEqual(s.index, (slice(None),))
    self.assertEqual(s.replica_id, i)
    self.assertArraysEqual(s.data.to_py(), np.array([]))
  self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_with_bigger_shape(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
|
@ -167,7 +167,7 @@ class MeshPspecSharding(XLACompatibleSharding):
|
||||
return set(self.mesh.devices.flat)
|
||||
|
||||
def devices_indices_map(
|
||||
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
|
||||
self, global_shape: Shape) -> Mapping[Device, Index]:
|
||||
# TODO(yashkatariya): Remove this when utilities are moved to pxla.py.
|
||||
from jax.experimental import global_device_array
|
||||
|
||||
@ -186,7 +186,7 @@ class MeshPspecSharding(XLACompatibleSharding):
|
||||
self,
|
||||
num_dimensions: int,
|
||||
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
|
||||
) -> Optional[xc.OpSharding]:
|
||||
) -> xc.OpSharding:
|
||||
from jax.experimental.pjit import get_array_mapping
|
||||
|
||||
array_mapping = get_array_mapping(self._parsed_pspec)
|
||||
@ -233,7 +233,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
return {self._device}
|
||||
|
||||
def devices_indices_map(
    self, global_shape: Shape) -> Mapping[Device, Index]:
  """Maps the single device to an index that selects the whole array.

  Args:
    global_shape: Shape of the global array; only its length is used.

  Returns:
    A one-entry dict from `self._device` to a tuple holding one `slice(None)`
    per dimension of `global_shape`.
  """
  return {self._device: (slice(None),) * len(global_shape)}
|
||||
|
||||
def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
|
||||
@ -243,7 +243,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
def _device_assignment(self) -> XLADeviceAssignment:
|
||||
return [self._device]
|
||||
|
||||
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
  """Returns the result of `_get_replicated_op_sharding()`.

  Args:
    num_dimensions: Unused by this implementation.
  """
  return _get_replicated_op_sharding()
|
||||
|
||||
|
||||
@ -324,7 +324,7 @@ class OpShardingSharding(XLACompatibleSharding):
|
||||
|
||||
# NOTE(review): lru_cache on a method includes `self` in the cache key, so the
# cache keeps up to `maxsize` instances alive, and both `self` and
# `global_shape` must be hashable. Confirm this is intended before changing
# the caching strategy.
@functools.lru_cache(maxsize=4096)
def devices_indices_map(
    self, global_shape: Shape) -> Mapping[Device, Index]:
  """Maps every device of this sharding to its index into the global array.

  Args:
    global_shape: Shape of the global array.

  Returns:
    A dict pairing each entry of `self._devices` with the index at the same
    position in the result of `pxla.op_sharding_to_indices`.
  """
  indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                        len(self._devices))
  return dict(safe_zip(self._devices, indices))
|
||||
|
Loading…
x
Reference in New Issue
Block a user