mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make resharding of a GlobalDeviceArray (GDA) work when the requested global shape is larger than the shape it was serialized with.
For example, you can serialize with shape (8, 2) and then deserialize with global shape (12, 2); the part of the requested shape that lies outside the stored data is filled with zeros.
PiperOrigin-RevId: 429680502
This commit is contained in:
parent
c161c62878
commit
3290dd3a4d
@ -108,18 +108,26 @@ def run_serialization(gdas, tensorstore_specs):
|
||||
asyncio.run(_run_serializer())
|
||||
|
||||
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec):
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
|
||||
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
|
||||
shape = t.shape if global_shape is None else global_shape
|
||||
new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)
|
||||
|
||||
async def cb(index):
|
||||
return await t[index].read()
|
||||
out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
|
||||
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
|
||||
restricted_domain = t.domain.intersect(requested_domain)
|
||||
await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(t[restricted_domain])
|
||||
return out
|
||||
|
||||
return await create_async_gda_from_callback(t.shape, mesh, mesh_axes, cb)
|
||||
return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
|
||||
|
||||
|
||||
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs):
|
||||
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
|
||||
global_shapes=None):
|
||||
async def _run_deserializer():
|
||||
future_gdas = jax.tree_map(async_deserialize, global_meshes, mesh_axes,
|
||||
tensorstore_specs)
|
||||
future_gdas = jax.tree_map(
|
||||
async_deserialize, global_meshes, mesh_axes, tensorstore_specs,
|
||||
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes)
|
||||
return await asyncio.gather(*future_gdas)
|
||||
return asyncio.run(_run_deserializer())
|
||||
|
@ -99,6 +99,44 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(s.data.to_py(), np.array([]))
|
||||
self.assertEqual(m3.dtype, np.float32)
|
||||
|
||||
def test_checkpointing_with_bigger_shape(self):
|
||||
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
# First GDA
|
||||
global_input_data1 = np.arange(num).reshape(global_input_shape)
|
||||
def cb1(index):
|
||||
return global_input_data1[index]
|
||||
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
['x', 'y'], cb1)
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[create_global_mesh((4, 2), ('x', 'y'))],
|
||||
[['x', 'y']],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
)
|
||||
|
||||
expected_data = {
|
||||
0: np.array([[0], [2], [4]]),
|
||||
1: np.array([[1], [3], [5]]),
|
||||
2: np.array([[6], [8], [10]]),
|
||||
3: np.array([[7], [9], [11]]),
|
||||
4: np.array([[12], [14], [0]]),
|
||||
5: np.array([[13], [15], [0]]),
|
||||
6: np.array([[0], [0], [0]]),
|
||||
7: np.array([[0], [0], [0]]),
|
||||
}
|
||||
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user