Updates Colocated Python's serialization (and deserialization) implementation to use the recently added support for string arrays.

Currently, the serialized data and its length are carried in two separate arrays: a fixed-width uint8 array (with a hard-coded maximum size) and an int32 length array, respectively.

PiperOrigin-RevId: 734299259
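For illustration only (not part of this commit), a minimal NumPy-only sketch contrasting the two encodings: the old scheme pads the serialized bytes into a fixed-size uint8 buffer and carries the true length separately, while the new scheme stores a single base64-encoded string in a NumPy 2.0 StringDType scalar. The payload and the size constant below are placeholders.

    import base64
    import numpy as np

    payload = b"serialized specs ..."  # placeholder for the pickled spec pytree

    # Old scheme: fixed-width uint8 buffer padded to a hard-coded maximum,
    # plus a separate int32 scalar carrying the true length.
    _MAX_SERIALIZED_SPECS_SIZE = 1048576
    data = np.frombuffer(
        payload + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(payload)), dtype=np.uint8
    )
    length = np.array(len(payload), dtype=np.int32)
    assert data[: int(length)].tobytes() == payload

    # New scheme: one scalar string array; base64 makes the raw bytes safe to
    # store in a text dtype (requires NumPy >= 2.0 for StringDType).
    text = np.array(
        base64.b64encode(payload).decode("ascii"), dtype=np.dtypes.StringDType()
    )
    assert base64.b64decode(text.item()) == payload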
Commit cd7f03f272 (parent 4b49c03523)
@@ -201,7 +201,7 @@ def _make_output_specs_and_push_result_fun(
   devices = specialization.devices

-  def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
+  def lowered_fun(*args, **kwargs) -> jax.Array:
     result = info.fun(*args, **kwargs)
     result_leaves, out_treedef = tree_util.tree_flatten(result)
     out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
@@ -13,11 +13,9 @@
 # limitations under the License.
 """Colocated Python serialization utilities."""

-# TODO(jmudigonda): Use a string-typed array for output structure when it
-# becomes available. Using a fixed uint8 array is only for prototyping.
-
 from __future__ import annotations

+import base64
 import collections
 import functools
 import io
@@ -37,12 +35,6 @@ import numpy as np

 DeviceList = xc.DeviceList

-# Hard-coded limit for serialized specs size.
-# TODO(jmudigonda): Use a string-typed array for output structure when it
-# becomes available. Using a fixed uint8 array is only for prototyping.
-_MAX_SERIALIZED_SPECS_SIZE = 1048576
-
-
 @jax.util.cache(max_size=None)
 def _get_cpu_device_map() -> dict[int, jax.Device]:
   """Returns a map from a device id to a matching device."""
@@ -185,23 +177,14 @@ def _deserialize(serialized: bytes) -> Any:


 def _make_specs_for_serialized_specs(
     devices: DeviceList,
-) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]:
+) -> api.ShapeDtypeStruct:
   """Makes output specs for serialized specs."""
-  # TODO(jmudigonda): Use a string-typed array for output structure when it
-  # becomes available. Using a fixed uint8 array is only for prototyping.
   mesh = jax.sharding.Mesh(tuple(devices), ("x",))
   replicated_sharding = jax.sharding.NamedSharding(
       mesh, jax.sharding.PartitionSpec()
   )
-  return (
-      api.ShapeDtypeStruct(
-          shape=(), dtype=np.int32, sharding=replicated_sharding
-      ),
-      api.ShapeDtypeStruct(
-          shape=(_MAX_SERIALIZED_SPECS_SIZE,),
-          dtype=np.uint8,
-          sharding=replicated_sharding,
-      ),
-  )
+  return api.ShapeDtypeStruct(
+      shape=(), dtype=np.dtypes.StringDType(), sharding=replicated_sharding  # type: ignore
+  )

@@ -209,49 +192,49 @@ def _serialize_specs(
     specs_treedef: tree_util.PyTreeDef,
     specs_leaves: tuple[api.ShapeDtypeStruct, ...],
     devices: DeviceList,
-) -> tuple[jax.Array, ...]:
-  """Serializes the output specs into a tuple of arrays.
+) -> jax.Array:
+  """Serializes the output specs into a jax.Array of string type.

   DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
   colocated_python. See serialize() for details.
   """
-  s = _serialize((specs_treedef, specs_leaves))
-  assert (
-      len(s) <= _MAX_SERIALIZED_SPECS_SIZE
-  ), f"Too large serialized spec size: {len(s)}"
-  # TODO(jmudigonda): Use a string-typed array for output structure when it
-  # becomes available. Using a fixed uint8 array is only for prototyping.
-  mesh = jax.sharding.Mesh(tuple(devices), ("x",))
+  if not hasattr(np.dtypes, "StringDType"):
+    raise TypeError(
+        "Serializing Colocated Python requires StringDType. Please use"
+        " NumPy 2.0.0 or later, or explicitly provide an output spec"
+        " function."
+    )
+
+  s_bytes = _serialize((specs_treedef, specs_leaves))
+  s_str = base64.b64encode(s_bytes).decode("ascii")
+  s_np_array = np.array(s_str, dtype=np.dtypes.StringDType())  # type: ignore
+
+  # TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making
+  # jax.Array via make_array_from_single_device_arrays. We should then use a
+  # sharding that spans all the execution devices - not just the addressable
+  # ones.
+  addressable_devices = devices.addressable_device_list
+  mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",))
   replicated_sharding = jax.sharding.NamedSharding(
       mesh, jax.sharding.PartitionSpec()
   )
-  len_array = jax.make_array_from_callback(
-      shape=(),
-      sharding=replicated_sharding,
-      data_callback=lambda _: np.array(len(s), dtype=np.int32),
-  )
-  data_array = jax.make_array_from_callback(
-      shape=(_MAX_SERIALIZED_SPECS_SIZE,),
-      sharding=replicated_sharding,
-      data_callback=lambda _: np.frombuffer(
-          s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)),
-          dtype=np.uint8,
-      ),
-  )
-  return len_array, data_array
+
+  out_arrays = [
+      jax.device_put(s_np_array, device) for device in addressable_devices
+  ]
+  return jax.make_array_from_single_device_arrays(
+      arrays=out_arrays, sharding=replicated_sharding, shape=(),
+  )


 def _deserialize_specs(
-    serialized_specs: tuple[jax.Array, ...],
+    serialized_specs: jax.Array,
 ) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
   """Deserializes the specs from the serialized specs.

   DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
   colocated_python. See serialize() for details.
   """
-  # TODO(jmudigonda): Use a string-typed array for output structure when it
-  # becomes available. Using a fixed uint8 array is only for prototyping.
-  len_array, data_array = serialized_specs
-  length = int(len_array.addressable_shards[0].data)
-  data = np.asarray(data_array.addressable_shards[0].data).tobytes()
-  return _deserialize(data[:length])
+  data_array = serialized_specs.addressable_shards[0].data
+  data = base64.b64decode(data_array.item().encode("ascii"))
+  return _deserialize(data)
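As a hedged sketch of the per-device assembly pattern used in the new _serialize_specs above: replicate one scalar string array across the addressable devices and stitch the copies into a single jax.Array. This assumes a JAX build with string-array support and NumPy >= 2.0; the CPU backend and the payload below are illustrative, not from the commit.

    import base64
    import jax
    import numpy as np

    # Replicate a scalar string array over the local CPU devices.
    devices = jax.local_devices(backend="cpu")
    mesh = jax.sharding.Mesh(tuple(devices), ("x",))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    encoded = base64.b64encode(b"specs bytes").decode("ascii")  # placeholder payload
    host_array = np.array(encoded, dtype=np.dtypes.StringDType())  # type: ignore

    # One single-device copy per device, assembled into a replicated jax.Array.
    shards = [jax.device_put(host_array, d) for d in devices]
    specs_array = jax.make_array_from_single_device_arrays(
        shape=(), sharding=replicated_sharding, arrays=shards
    )

    # Reading it back mirrors the new _deserialize_specs path.
    roundtripped = base64.b64decode(
        specs_array.addressable_shards[0].data.item().encode("ascii")
    )
    assert roundtripped == b"specs bytes"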
@@ -66,6 +66,14 @@ _count_colocated_python_specialization_cache_miss = jtu.count_events(

 class ColocatedPythonTest(jtu.JaxTestCase):

+  def setUp(self):
+    super().setUp()
+    if np.lib.NumpyVersion(np.__version__) < "2.0.0":
+      self.skipTest(
+          "Serialization in Colocated Python needs StringDType, and thus"
+          " requires NumPy 2.0.0 or later"
+      )
+
   def testMakeColocatedPythonProgram(self):
     def add_one(x):
       return x + 1
@@ -382,8 +390,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
       del colocated_python._testing_global_state

   def testStringProcessing(self):
-    if np.lib.NumpyVersion(np.__version__) < "2.0.0":
-      self.skipTest("StringDType requires NumPy 2.0.0 or later")
     cpu_devices = _colocated_cpu_devices(jax.local_devices())
     if len(cpu_devices) < 2:
       self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")
@@ -425,8 +431,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
     )

   def testBinaryDataProcessing(self):
-    if np.lib.NumpyVersion(np.__version__) < "2.0.0":
-      self.skipTest("StringDType requires NumPy 2.0.0 or later")
     cpu_devices = _colocated_cpu_devices(jax.local_devices())
     if len(cpu_devices) < 1:
       self.skipTest("Need at least one CPU devices")