mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Handle serialization of arrays with shape (0,)
. These arrays are usually empty lists (np.array([])
)
PiperOrigin-RevId: 426172532
This commit is contained in:
parent
d04dce3fa2
commit
b7dcc4ce01
@ -58,7 +58,7 @@ def _get_metadata(gda):
|
||||
'id': 'gzip'
|
||||
},
|
||||
'shape': gda.shape,
|
||||
'chunks': np.array(gda.local_data(0).shape),
|
||||
'chunks': np.array(np.maximum(1, gda.local_data(0).shape)),
|
||||
'dtype': dtype,
|
||||
}
|
||||
|
||||
|
@ -45,29 +45,39 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x', 'y']
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
# First GDA
|
||||
global_input_data1 = np.arange(num).reshape(global_input_shape)
|
||||
def cb1(index):
|
||||
return global_input_data1[index]
|
||||
|
||||
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb1)
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
# Second GDA
|
||||
global_input_data2 = np.arange(num, num + num).reshape(global_input_shape)
|
||||
def cb2(index):
|
||||
return global_input_data2[index]
|
||||
|
||||
gda2 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb2)
|
||||
ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2)]
|
||||
# Third GDA
|
||||
def cb3(index):
|
||||
return np.array([])
|
||||
global_mesh1d = create_global_mesh((8,), ('x',))
|
||||
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, [None], cb3)
|
||||
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
|
||||
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1, gda2], tspecs)
|
||||
serialization.run_serialization([gda1, gda2, gda3], tspecs)
|
||||
|
||||
m1, m2 = serialization.run_deserialization(
|
||||
[global_mesh, global_mesh], [mesh_axes, ['x']], tspecs)
|
||||
m1, m2, m3 = serialization.run_deserialization(
|
||||
[global_mesh, global_mesh, global_mesh1d],
|
||||
[mesh_axes, ['x'], [None]],
|
||||
tspecs)
|
||||
|
||||
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
|
||||
np.array([[0], [2]]))
|
||||
@ -83,6 +93,12 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
|
||||
self.assertEqual(m2.dtype, np.int32)
|
||||
|
||||
for i, s in enumerate(m3.local_shards):
|
||||
self.assertEqual(s.index, (slice(None),))
|
||||
self.assertEqual(s.replica_id, i)
|
||||
self.assertArraysEqual(s.data.to_py(), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user