Make resharding of GDA work if the shape is larger than what it was serialized with.

For example: If you serialize with shape (8, 2) and want to deserialize with global shape (12, 2).

PiperOrigin-RevId: 429680502
This commit is contained in:
Yash Katariya 2022-02-18 17:27:45 -08:00 committed by jax authors
parent c161c62878
commit 3290dd3a4d
2 changed files with 52 additions and 6 deletions

View File

@ -108,18 +108,26 @@ def run_serialization(gdas, tensorstore_specs):
asyncio.run(_run_serializer())
async def async_deserialize(mesh, mesh_axes, tensorstore_spec):
async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
shape = t.shape if global_shape is None else global_shape
new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)
async def cb(index):
return await t[index].read()
out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)
await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(t[restricted_domain])
return out
return await create_async_gda_from_callback(t.shape, mesh, mesh_axes, cb)
return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs):
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
global_shapes=None):
async def _run_deserializer():
future_gdas = jax.tree_map(async_deserialize, global_meshes, mesh_axes,
tensorstore_specs)
future_gdas = jax.tree_map(
async_deserialize, global_meshes, mesh_axes, tensorstore_specs,
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes)
return await asyncio.gather(*future_gdas)
return asyncio.run(_run_deserializer())

View File

@ -99,6 +99,44 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertArraysEqual(s.data.to_py(), np.array([]))
self.assertEqual(m3.dtype, np.float32)
def test_checkpointing_with_bigger_shape(self):
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
num = util.prod(global_input_shape)
# First GDA
global_input_data1 = np.arange(num).reshape(global_input_shape)
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
['x', 'y'], cb1)
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
ckpt_paths = [str(ckpt_dir1)]
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([gda1], tspecs)
m1, = serialization.run_deserialization(
[create_global_mesh((4, 2), ('x', 'y'))],
[['x', 'y']],
tspecs,
[(12, 2)],
)
expected_data = {
0: np.array([[0], [2], [4]]),
1: np.array([[1], [3], [5]]),
2: np.array([[6], [8], [10]]),
3: np.array([[7], [9], [11]]),
4: np.array([[12], [14], [0]]),
5: np.array([[13], [15], [0]]),
6: np.array([[0], [0], [0]]),
7: np.array([[0], [0], [0]]),
}
for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())