Add a dtypes option to cast host arrays when reloading from TS.

PiperOrigin-RevId: 443804229
This commit is contained in:
Yash Katariya 2022-04-22 17:59:53 -07:00 committed by jax authors
parent f682c77d3d
commit cf87e3a4a3
2 changed files with 22 additions and 14 deletions

View File

@ -118,7 +118,8 @@ def run_serialization(gdas, tensorstore_specs):
asyncio.run(_run_serializer())
async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
global_shape=None, dtype=None):
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
shape = t.shape if global_shape is None else global_shape
requires_padding = prod(shape) > prod(t.shape)
@ -135,18 +136,24 @@ async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)
await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(t[restricted_domain])
return out
else:
return await t[index].read()
out = await t[index].read()
if dtype is not None:
# Cast while reloading on process to avoid 2 copies on device if the
# casting is done on device.
return out.astype(dtype)
return out
return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
global_shapes=None):
global_shapes=None, dtypes=None):
async def _run_deserializer():
future_gdas = jax.tree_map(
async_deserialize, global_meshes, mesh_axes, tensorstore_specs,
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes)
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
[None] * len(tensorstore_specs) if dtypes is None else dtypes)
return await asyncio.gather(*future_gdas)
return asyncio.run(_run_deserializer())

View File

@ -106,7 +106,7 @@ class CheckpointTest(jtu.JaxTestCase):
num = util.prod(global_input_shape)
# First GDA
global_input_data1 = np.arange(num).reshape(global_input_shape)
global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape)
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
@ -123,17 +123,18 @@ class CheckpointTest(jtu.JaxTestCase):
[P('x', 'y')],
tspecs,
[(12, 2)],
[np.float32]
)
expected_data = {
0: np.array([[0], [2], [4]]),
1: np.array([[1], [3], [5]]),
2: np.array([[6], [8], [10]]),
3: np.array([[7], [9], [11]]),
4: np.array([[12], [14], [0]]),
5: np.array([[13], [15], [0]]),
6: np.array([[0], [0], [0]]),
7: np.array([[0], [0], [0]]),
0: np.array([[0], [2], [4]], dtype=np.float32),
1: np.array([[1], [3], [5]], dtype=np.float32),
2: np.array([[6], [8], [10]], dtype=np.float32),
3: np.array([[7], [9], [11]], dtype=np.float32),
4: np.array([[12], [14], [0]], dtype=np.float32),
5: np.array([[13], [15], [0]], dtype=np.float32),
6: np.array([[0], [0], [0]], dtype=np.float32),
7: np.array([[0], [0], [0]], dtype=np.float32),
}
for l in m1.local_shards: