mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Move serialization/de-serialization of GDA into jax.
PiperOrigin-RevId: 414607092
This commit is contained in:
parent
5de2cec415
commit
0bb7d204ab
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -132,4 +132,4 @@ jobs:
|
||||
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
|
||||
run: |
|
||||
pytest -n 1 --tb=short docs
|
||||
pytest -n 1 --tb=short --doctest-modules --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py jax
|
||||
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization
|
||||
|
10
jax/experimental/gda_serialization/README
Normal file
10
jax/experimental/gda_serialization/README
Normal file
@ -0,0 +1,10 @@
|
||||
# Serialization and De-serialization of GlobalDeviceArray via tensorstore
|
||||
|
||||
Warning: This directory is going to move in the near future. Please use at your
|
||||
own risk.
|
||||
|
||||
To use this library, please install tensorstore and JAX.
|
||||
|
||||
```bash
|
||||
pip install -U tensorstore
|
||||
```
|
13
jax/experimental/gda_serialization/__init__.py
Normal file
13
jax/experimental/gda_serialization/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
100
jax/experimental/gda_serialization/serialization.py
Normal file
100
jax/experimental/gda_serialization/serialization.py
Normal file
@ -0,0 +1,100 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""GlobalDeviceArray serialization and deserialization."""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax.experimental import global_device_array as gda
|
||||
from jax.experimental.maps import Mesh
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tensorstore as ts
|
||||
|
||||
|
||||
async def create_async_gsda_from_callback(
    global_shape: gda.Shape,
    global_mesh: Mesh,
    mesh_axes: gda.MeshAxes,
    data_callback: Callable[[gda.Index], asyncio.Future],
):
  """Builds a GlobalDeviceArray from asynchronously produced shard data.

  Args:
    global_shape: Shape of the full array, forwarded to
      `gda.get_shard_indices` and to the `GlobalDeviceArray` constructor.
    global_mesh: Mesh whose `local_devices` each receive one buffer.
    mesh_axes: Mesh axes forwarded to `gda.get_shard_indices` and to the
      `GlobalDeviceArray` constructor.
    data_callback: Called once per local device with that device's entry from
      the shard-index mapping; must return an awaitable yielding the data to
      place on that device.

  Returns:
    A `gda.GlobalDeviceArray` backed by one device buffer per local device.
  """
  shard_indices = gda.get_shard_indices(global_shape, global_mesh, mesh_axes)

  pending = []
  for device in global_mesh.local_devices:
    pending.append(data_callback(shard_indices[device]))

  # Suspend this coroutine until every shard's data is available:
  # `jax.device_put` needs the concrete arrays, not the awaitables.
  host_arrays = await asyncio.gather(*pending)

  device_buffers = []
  for host_array, device in zip(host_arrays, global_mesh.local_devices):
    device_buffers.append(jax.device_put(host_array, device))

  return gda.GlobalDeviceArray(
      global_shape, global_mesh, mesh_axes, device_buffers)
|
||||
|
||||
|
||||
def _get_metadata(gda):
|
||||
if gda.dtype == jnp.bfloat16:
|
||||
# Tensorstore uses 'bfloat16', not '<V2'.
|
||||
dtype = 'bfloat16'
|
||||
else:
|
||||
dtype = np.dtype(gda.dtype).str
|
||||
|
||||
return {
|
||||
'compressor': {
|
||||
'id': 'gzip'
|
||||
},
|
||||
'shape': gda.shape,
|
||||
'chunks': np.array(gda.local_data(0).shape),
|
||||
'dtype': dtype,
|
||||
}
|
||||
|
||||
|
||||
def get_tensorstore_spec(ckpt_path):
  """Returns a tensorstore spec dict for a zarr array stored at `ckpt_path`.

  Args:
    ckpt_path: Path placed in the spec's file-based kvstore entry.

  Returns:
    A new dict using the 'zarr' driver and a 'file' kvstore rooted at
    `ckpt_path`.
  """
  # TODO(yashkatariya): Add GCS kvstore too.
  file_kvstore = {'driver': 'file', 'path': ckpt_path}
  return {'driver': 'zarr', 'kvstore': file_kvstore}
|
||||
|
||||
|
||||
async def async_serialize(ckpt_path: str, gda: gda.GlobalDeviceArray,
                          tensorstore_spec):
  """Writes the local shards of `gda` through tensorstore.

  Args:
    ckpt_path: Not read by this function; the destination comes from
      `tensorstore_spec`.
    gda: Array whose `local_shards` are written.
    tensorstore_spec: Spec dict used to open the store. If it has no truthy
      'metadata' entry, one derived from `gda` is stored into this dict in
      place, so the caller's dict is modified.

  Returns:
    None, once every shard write has completed.
  """
  if not tensorstore_spec.get('metadata'):
    tensorstore_spec['metadata'] = _get_metadata(gda)

  async def _write_array(shard):
    # Only the shard with replica id 0 is written; other replicas are skipped.
    if shard.replica_id != 0:
      return
    io_context = {'file_io_concurrency': {'limit': 128}}
    store = await ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=ts.Context(io_context))
    await store[shard.index].write(shard.data)

  # One coroutine per local shard, all awaited together.
  pending_writes = jax.tree_util.tree_map(
      _write_array, tuple(gda.local_shards))
  await asyncio.gather(*pending_writes)
  return None
|
||||
|
||||
|
||||
async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
  """Reads an array through tensorstore into a GlobalDeviceArray.

  Args:
    ckpt_path: Not read by this function; the source location comes from
      `tensorstore_spec`.
    mesh: Mesh forwarded to `create_async_gsda_from_callback`.
    mesh_axes: Mesh axes forwarded to `create_async_gsda_from_callback`.
    tensorstore_spec: Spec dict describing the store to open.

  Returns:
    The array built by `create_async_gsda_from_callback`, with the global
    shape taken from the opened store.
  """
  # Await the open instead of blocking on `.result()`: a blocking wait inside
  # a coroutine stalls the event loop, so deserializations launched together
  # with `asyncio.gather` would otherwise run one after another.
  t = await ts.open(ts.Spec(tensorstore_spec), open=True)

  async def cb(index):
    # Read only the slice of the stored array selected by `index`.
    return await t[index].read()

  return await create_async_gsda_from_callback(t.shape, mesh, mesh_axes, cb)
|
108
jax/experimental/gda_serialization/serialization_test.py
Normal file
108
jax/experimental/gda_serialization/serialization_test.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for serialization and deserialization of GDA."""
|
||||
import asyncio
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax.config import config
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
from jax.experimental.maps import Mesh
|
||||
import numpy as np
|
||||
|
||||
# Have the JAX config parse its flags via absl at import time, before any test
# in this module runs.
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def create_global_mesh(mesh_shape, axis_names):
  """Builds a Mesh of shape `mesh_shape` from the first available devices.

  Args:
    mesh_shape: Shape of the device grid; its product is the device count
      required.
    axis_names: Names for the mesh axes, forwarded to `Mesh`.

  Returns:
    A `Mesh` over the first `prod(mesh_shape)` entries of `jax.devices()`.

  Raises:
    unittest.SkipTest: If fewer devices are available than the mesh needs.
  """
  size = util.prod(mesh_shape)
  available = jax.devices()
  if size > len(available):
    raise unittest.SkipTest(f'Test requires {size} local devices')
  device_grid = np.array(available[:size]).reshape(mesh_shape)
  return Mesh(device_grid, axis_names)
|
||||
|
||||
|
||||
class CheckpointTest(jtu.JaxTestCase):
  """Round-trip test for GlobalDeviceArray serialization via tensorstore."""

  def test_checkpointing(self):
    """Serializes two arrays and checks the shards read back from disk.

    The first array is restored with the mesh axes it was written with; the
    second is restored with only the 'x' axis, so the expected shard shapes
    differ between the two.
    """
    mesh = create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    mesh_axes = ['x', 'y']
    num = util.prod(shape)

    first_data = np.arange(num).reshape(shape)
    first_gda = GlobalDeviceArray.from_callback(
        shape, mesh, mesh_axes, lambda index: first_data[index])
    first_dir = pathlib.Path(self.create_tempdir('first').full_path)

    second_data = np.arange(num, num + num).reshape(shape)
    second_gda = GlobalDeviceArray.from_callback(
        shape, mesh, mesh_axes, lambda index: second_data[index])
    second_dir = pathlib.Path(self.create_tempdir('second').full_path)

    ckpt_paths = [str(first_dir), str(second_dir)]
    tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    # Async Serialization below.
    async def run_serializer():
      pending_writes = jax.tree_map(
          serialization.async_serialize,
          ckpt_paths,
          [first_gda, second_gda],
          tspecs,
      )
      return await asyncio.gather(*pending_writes)

    asyncio.run(run_serializer())

    # Async Deserialization below.
    async def run():
      pending_reads = jax.tree_map(
          serialization.async_deserialize,
          ckpt_paths,
          [mesh, mesh],
          [mesh_axes, ['x']],
          tspecs,
      )
      return await asyncio.gather(*pending_reads)

    restored_first, restored_second = asyncio.run(run())

    first_shards = restored_first.local_shards
    self.assertArraysEqual(first_shards[0].data.to_py(),
                           np.array([[0], [2]]))
    self.assertArraysEqual(first_shards[1].data.to_py(),
                           np.array([[1], [3]]))
    self.assertEqual(first_shards[0].data.shape, (2, 1))
    self.assertEqual(restored_first.dtype, np.int32)

    second_shards = restored_second.local_shards
    self.assertArraysEqual(second_shards[0].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertArraysEqual(second_shards[1].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertEqual(second_shards[0].data.shape, (2, 2))
    self.assertEqual(restored_second.dtype, np.int32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # Run this module's tests with the JAX-specific test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|
Loading…
x
Reference in New Issue
Block a user