Move serialization/de-serialization of GDA into jax.

PiperOrigin-RevId: 414607092
This commit is contained in:
Yash Katariya 2021-12-06 20:04:19 -08:00 committed by jax authors
parent 5de2cec415
commit 0bb7d204ab
5 changed files with 232 additions and 1 deletions

View File

@ -132,4 +132,4 @@ jobs:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
run: |
pytest -n 1 --tb=short docs
pytest -n 1 --tb=short --doctest-modules --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py jax
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization

View File

@ -0,0 +1,10 @@
# Serialization and De-serialization of GlobalDeviceArray via tensorstore
Warning: This directory is going to move in the near future. Please use at your
own risk.
To use this library, install JAX and tensorstore. Tensorstore can be installed with:
```bash
pip install -U tensorstore
```

View File

@ -0,0 +1,13 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,100 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GlobalDeviceArray serialization and deserialization."""
import asyncio
from typing import Callable
import jax
from jax.experimental import global_device_array as gda
from jax.experimental.maps import Mesh
import jax.numpy as jnp
import numpy as np
import tensorstore as ts
async def create_async_gsda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
  """Builds a GlobalDeviceArray from an asynchronous per-shard callback.

  Args:
    global_shape: Shape forwarded to `gda.get_shard_indices` and to the
      `GlobalDeviceArray` constructor.
    global_mesh: Mesh whose `local_devices` each receive one shard.
    mesh_axes: Mesh axes forwarded to `gda.get_shard_indices` and to the
      `GlobalDeviceArray` constructor.
    data_callback: Called once per local device with that device's index;
      must return an awaitable that resolves to the data for that index.

  Returns:
    A `gda.GlobalDeviceArray` backed by one device buffer per local device.
  """
  device_to_index = gda.get_shard_indices(global_shape, global_mesh, mesh_axes)
  local_devices = global_mesh.local_devices

  # Kick off every fetch first so they can all be in flight together.
  pending_fetches = [
      data_callback(device_to_index[device]) for device in local_devices
  ]

  # The awaitables themselves cannot be handed to device_put, so suspend here
  # until every one of them has resolved to concrete data.
  host_arrays = await asyncio.gather(*pending_fetches)

  device_buffers = [
      jax.device_put(host_array, device)
      for host_array, device in zip(host_arrays, local_devices)
  ]
  return gda.GlobalDeviceArray(
      global_shape, global_mesh, mesh_axes, device_buffers)
def _get_metadata(gda):
  """Builds the metadata dict stored under the tensorstore spec's 'metadata'.

  Args:
    gda: Object exposing `dtype`, `shape` and `local_data(i).shape`
      (a GlobalDeviceArray in this module's callers).

  Returns:
    A dict with a gzip compressor entry, the array's `shape`, a `chunks`
    array taken from the shape of local shard 0, and the dtype string.
  """
  # Tensorstore uses 'bfloat16', not '<V2'.
  is_bfloat16 = gda.dtype == jnp.bfloat16
  dtype = 'bfloat16' if is_bfloat16 else np.dtype(gda.dtype).str

  global_shape = gda.shape
  # One chunk per shard: the chunk shape is the shape of the first local shard.
  chunk_shape = np.array(gda.local_data(0).shape)

  return {
      'compressor': {'id': 'gzip'},
      'shape': global_shape,
      'chunks': chunk_shape,
      'dtype': dtype,
  }
def get_tensorstore_spec(ckpt_path):
  """Returns a zarr tensorstore spec stored on the local filesystem.

  Args:
    ckpt_path: Path placed in the 'file' kvstore entry of the spec.

  Returns:
    A new dict `{'driver': 'zarr', 'kvstore': {'driver': 'file', 'path': ...}}`.
  """
  # TODO(yashkatariya): Add GCS kvstore too.
  file_kvstore = {'driver': 'file', 'path': ckpt_path}
  return {'driver': 'zarr', 'kvstore': file_kvstore}
async def async_serialize(ckpt_path: str, gda: gda.GlobalDeviceArray,
                          tensorstore_spec):
  """Writes the local shards of `gda` to the store described by the spec.

  Args:
    ckpt_path: Not read by this function; the destination comes from
      `tensorstore_spec`.
    gda: Array whose `local_shards` are written.
    tensorstore_spec: Dict spec passed to `ts.Spec`. If it has no truthy
      'metadata' entry, one derived from `gda` is stored into this dict, so
      the caller's dict is modified in place.

  Returns:
    None, once every write issued for the local shards has completed.
  """
  if not tensorstore_spec.get('metadata'):
    tensorstore_spec['metadata'] = _get_metadata(gda)

  async def _write_array(shard):
    # Only the replica-0 copy of a shard is written; other replicas are no-ops.
    if shard.replica_id != 0:
      return
    spec = ts.Spec(tensorstore_spec)
    io_context = ts.Context({'file_io_concurrency': {
        'limit': 128
    }})
    store = await ts.open(spec, create=True, open=True, context=io_context)
    await store[shard.index].write(shard.data)

  # One coroutine per local shard, all awaited together.
  pending_writes = jax.tree_util.tree_map(_write_array,
                                          tuple(gda.local_shards))
  await asyncio.gather(*pending_writes)
async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
  """Reads an array from tensorstore into a GlobalDeviceArray.

  Args:
    ckpt_path: Not read by this function; the source location comes from
      `tensorstore_spec`.
    mesh: Mesh forwarded to `create_async_gsda_from_callback`.
    mesh_axes: Mesh axes forwarded to `create_async_gsda_from_callback`.
    tensorstore_spec: Dict spec passed to `ts.Spec` to open the existing
      store.

  Returns:
    The GlobalDeviceArray produced by `create_async_gsda_from_callback`, with
    the global shape taken from the opened store.
  """
  # Await the open rather than calling `.result()` on it: a blocking wait
  # inside a coroutine stalls the event loop, which would keep several
  # deserializations run under one `asyncio.gather` from overlapping their
  # opens. This also matches how `async_serialize` opens its store.
  t = await ts.open(ts.Spec(tensorstore_spec), open=True)

  async def cb(index):
    # Fetch only the slice of the stored array that one shard needs.
    return await t[index].read()

  return await create_async_gsda_from_callback(t.shape, mesh, mesh_axes, cb)

View File

@ -0,0 +1,108 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for serialization and deserialization of GDA."""
import asyncio
import pathlib
import unittest
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax._src import util
from jax.config import config
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization
from jax.experimental.maps import Mesh
import numpy as np
# Runs at import time, i.e. before absltest.main() at the bottom of the file.
config.parse_flags_with_absl()
def create_global_mesh(mesh_shape, axis_names):
  """Creates a Mesh of the given shape from the first available devices.

  Args:
    mesh_shape: Shape of the device grid; its product is the device count
      needed.
    axis_names: Axis names passed to `Mesh`.

  Returns:
    A `Mesh` over the first `prod(mesh_shape)` entries of `jax.devices()`.

  Raises:
    unittest.SkipTest: If fewer devices than required are available.
  """
  size = util.prod(mesh_shape)
  available = jax.devices()
  if len(available) < size:
    raise unittest.SkipTest(f'Test requires {size} local devices')
  device_grid = np.array(available[:size]).reshape(mesh_shape)
  return Mesh(device_grid, axis_names)
class CheckpointTest(jtu.JaxTestCase):
  """Round-trips GlobalDeviceArrays through tensorstore checkpoints."""

  def test_checkpointing(self):
    """Serializes two arrays, restores them, and checks the local shards."""
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = ['x', 'y']
    num = util.prod(global_input_shape)

    def make_array_and_dir(start, dir_name):
      # Build the source array first, then its checkpoint directory.
      host_data = np.arange(start, start + num).reshape(global_input_shape)

      def cb(index):
        return host_data[index]

      array = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                              mesh_axes, cb)
      ckpt_dir = pathlib.Path(self.create_tempdir(dir_name).full_path)
      return array, ckpt_dir

    gsda1, ckpt_dir1 = make_array_and_dir(0, 'first')
    gsda2, ckpt_dir2 = make_array_and_dir(num, 'second')

    ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2)]
    tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    # Async Serialization below.
    async def run_serializer():
      pending_writes = jax.tree_map(
          serialization.async_serialize,
          ckpt_paths,
          [gsda1, gsda2],
          tspecs,
      )
      return await asyncio.gather(*pending_writes)

    asyncio.run(run_serializer())

    # Async Deserialization below. The second array is restored with only the
    # 'x' mesh axis rather than the axes it was created with.
    async def run():
      pending_reads = jax.tree_map(
          serialization.async_deserialize,
          ckpt_paths,
          [global_mesh, global_mesh],
          [mesh_axes, ['x']],
          tspecs,
      )
      return await asyncio.gather(*pending_reads)

    m1, m2 = asyncio.run(run())

    m1_expected = (np.array([[0], [2]]), np.array([[1], [3]]))
    for shard_idx, expected in enumerate(m1_expected):
      self.assertArraysEqual(m1.local_shards[shard_idx].data.to_py(), expected)
    self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    # The first two local shards of m2 are expected to hold the same block.
    for shard_idx in range(2):
      self.assertArraysEqual(m2.local_shards[shard_idx].data.to_py(),
                             np.array([[16, 17], [18, 19]]))
    self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)
# Script entry point: run this module's tests via absltest with JAX's loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())