mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add a dtypes option to cast host arrays when reloading from TensorStore (TS).
PiperOrigin-RevId: 443804229
This commit is contained in:
parent
f682c77d3d
commit
cf87e3a4a3
@ -118,7 +118,8 @@ def run_serialization(gdas, tensorstore_specs):
|
||||
asyncio.run(_run_serializer())
|
||||
|
||||
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
|
||||
async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
|
||||
global_shape=None, dtype=None):
|
||||
t = ts.open(ts.Spec(tensorstore_spec), open=True).result()
|
||||
shape = t.shape if global_shape is None else global_shape
|
||||
requires_padding = prod(shape) > prod(t.shape)
|
||||
@ -135,18 +136,24 @@ async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None
|
||||
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
|
||||
restricted_domain = t.domain.intersect(requested_domain)
|
||||
await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(t[restricted_domain])
|
||||
return out
|
||||
else:
|
||||
return await t[index].read()
|
||||
out = await t[index].read()
|
||||
|
||||
if dtype is not None:
|
||||
# Cast on the host while reloading: casting on device instead would
|
||||
# leave 2 copies of the array on device.
|
||||
return out.astype(dtype)
|
||||
return out
|
||||
|
||||
return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
|
||||
|
||||
|
||||
def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
                        global_shapes=None, dtypes=None):
  """Synchronously deserialize one array per TensorStore spec.

  `async_deserialize` is mapped over the inputs with `jax.tree_map`, and the
  resulting awaitables are run to completion with `asyncio.gather` inside a
  fresh event loop started by `asyncio.run`.

  Args:
    global_meshes: meshes, one per spec, passed to `async_deserialize`.
    mesh_axes: mesh axes, one per spec, passed to `async_deserialize`.
    tensorstore_specs: TensorStore specs to read from.
    global_shapes: optional per-spec shapes forwarded as `global_shape`.
      When None, every spec gets `global_shape=None`.
    dtypes: optional per-spec dtypes forwarded as `dtype`. When None, every
      spec gets `dtype=None`, i.e. no cast is requested.

  Returns:
    The list produced by `asyncio.gather`: one deserialized result per spec,
    in the order of `tensorstore_specs`.
  """
  # An omitted optional argument becomes one `None` per spec so that it lines
  # up element-wise with the other inputs handed to `jax.tree_map`.
  num_specs = len(tensorstore_specs)
  shapes = [None] * num_specs if global_shapes is None else global_shapes
  cast_dtypes = [None] * num_specs if dtypes is None else dtypes

  async def _run_deserializer():
    # Each mapped call returns a coroutine; gather awaits them concurrently.
    future_gdas = jax.tree_map(
        async_deserialize, global_meshes, mesh_axes, tensorstore_specs,
        shapes, cast_dtypes)
    return await asyncio.gather(*future_gdas)

  return asyncio.run(_run_deserializer())
|
||||
|
@ -106,7 +106,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
# First GDA
|
||||
global_input_data1 = np.arange(num).reshape(global_input_shape)
|
||||
global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape)
|
||||
def cb1(index):
|
||||
return global_input_data1[index]
|
||||
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
@ -123,17 +123,18 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
[P('x', 'y')],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
[np.float32]
|
||||
)
|
||||
|
||||
expected_data = {
|
||||
0: np.array([[0], [2], [4]]),
|
||||
1: np.array([[1], [3], [5]]),
|
||||
2: np.array([[6], [8], [10]]),
|
||||
3: np.array([[7], [9], [11]]),
|
||||
4: np.array([[12], [14], [0]]),
|
||||
5: np.array([[13], [15], [0]]),
|
||||
6: np.array([[0], [0], [0]]),
|
||||
7: np.array([[0], [0], [0]]),
|
||||
0: np.array([[0], [2], [4]], dtype=np.float32),
|
||||
1: np.array([[1], [3], [5]], dtype=np.float32),
|
||||
2: np.array([[6], [8], [10]], dtype=np.float32),
|
||||
3: np.array([[7], [9], [11]], dtype=np.float32),
|
||||
4: np.array([[12], [14], [0]], dtype=np.float32),
|
||||
5: np.array([[13], [15], [0]], dtype=np.float32),
|
||||
6: np.array([[0], [0], [0]], dtype=np.float32),
|
||||
7: np.array([[0], [0], [0]], dtype=np.float32),
|
||||
}
|
||||
|
||||
for l in m1.local_shards:
|
||||
|
Loading…
x
Reference in New Issue
Block a user