Handle serialization of arrays with shape (0,). These arrays are usually empty lists (np.array([]))

PiperOrigin-RevId: 426172532
This commit is contained in:
Yash Katariya 2022-02-03 09:59:25 -08:00 committed by jax authors
parent d04dce3fa2
commit b7dcc4ce01
2 changed files with 23 additions and 7 deletions

View File

@ -58,7 +58,7 @@ def _get_metadata(gda):
'id': 'gzip'
},
'shape': gda.shape,
'chunks': np.array(gda.local_data(0).shape),
'chunks': np.array(np.maximum(1, gda.local_data(0).shape)),
'dtype': dtype,
}

View File

@ -45,29 +45,39 @@ class CheckpointTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
num = util.prod(global_input_shape)
# First GDA
global_input_data1 = np.arange(num).reshape(global_input_shape)
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb1)
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
# Second GDA
global_input_data2 = np.arange(num, num + num).reshape(global_input_shape)
def cb2(index):
return global_input_data2[index]
gda2 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb2)
ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2)]
# Third GDA
def cb3(index):
return np.array([])
global_mesh1d = create_global_mesh((8,), ('x',))
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, [None], cb3)
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([gda1, gda2], tspecs)
serialization.run_serialization([gda1, gda2, gda3], tspecs)
m1, m2 = serialization.run_deserialization(
[global_mesh, global_mesh], [mesh_axes, ['x']], tspecs)
m1, m2, m3 = serialization.run_deserialization(
[global_mesh, global_mesh, global_mesh1d],
[mesh_axes, ['x'], [None]],
tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
np.array([[0], [2]]))
@ -83,6 +93,12 @@ class CheckpointTest(jtu.JaxTestCase):
self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
self.assertEqual(m2.dtype, np.int32)
for i, s in enumerate(m3.local_shards):
self.assertEqual(s.index, (slice(None),))
self.assertEqual(s.replica_id, i)
self.assertArraysEqual(s.data.to_py(), np.array([]))
self.assertEqual(m3.dtype, np.float32)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())