Updates Colocated Python's serialization (and deserialization) implementation to use the recently added support for string arrays.

Previously, the serialized data and its length were carried in two separate arrays: a fixed-width bytes array (with a hard-coded maximum size) and an int32 array, respectively.

PiperOrigin-RevId: 734299259
This commit is contained in:
jax authors 2025-03-06 14:57:18 -08:00
parent 4b49c03523
commit cd7f03f272
3 changed files with 42 additions and 55 deletions

View File

@ -201,7 +201,7 @@ def _make_output_specs_and_push_result_fun(
devices = specialization.devices
def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
def lowered_fun(*args, **kwargs) -> jax.Array:
result = info.fun(*args, **kwargs)
result_leaves, out_treedef = tree_util.tree_flatten(result)
out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)

View File

@ -13,11 +13,9 @@
# limitations under the License.
"""Colocated Python serialization utilities."""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
from __future__ import annotations
import base64
import collections
import functools
import io
@ -37,12 +35,6 @@ import numpy as np
DeviceList = xc.DeviceList
# Hard-coded limit for serialized specs size.
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
_MAX_SERIALIZED_SPECS_SIZE = 1048576
@jax.util.cache(max_size=None)
def _get_cpu_device_map() -> dict[int, jax.Device]:
"""Returns a map from a device id to a matching device."""
@ -185,23 +177,14 @@ def _deserialize(serialized: bytes) -> Any:
def _make_specs_for_serialized_specs(
devices: DeviceList,
) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]:
) -> api.ShapeDtypeStruct:
"""Makes output specs for serialized specs."""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
replicated_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec()
)
return (
api.ShapeDtypeStruct(
shape=(), dtype=np.int32, sharding=replicated_sharding
),
api.ShapeDtypeStruct(
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
dtype=np.uint8,
sharding=replicated_sharding,
),
return api.ShapeDtypeStruct(
shape=(), dtype=np.dtypes.StringDType(), sharding=replicated_sharding # type: ignore
)
@ -209,49 +192,49 @@ def _serialize_specs(
specs_treedef: tree_util.PyTreeDef,
specs_leaves: tuple[api.ShapeDtypeStruct, ...],
devices: DeviceList,
) -> tuple[jax.Array, ...]:
"""Serializes the output specs into a tuple of arrays.
) -> jax.Array:
"""Serializes the output specs into a jax.Array of string type.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python. See serialize() for details.
"""
s = _serialize((specs_treedef, specs_leaves))
assert (
len(s) <= _MAX_SERIALIZED_SPECS_SIZE
), f"Too large serialized spec size: {len(s)}"
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
if not hasattr(np.dtypes, "StringDType"):
raise TypeError(
"Serializing Colocated Python requires StringDType. Please use"
" numpy to 2.0.0 or later, or explicityly provide an output spec"
" function."
)
s_bytes = _serialize((specs_treedef, specs_leaves))
s_str = base64.b64encode(s_bytes).decode("ascii")
s_np_array = np.array(s_str, dtype=np.dtypes.StringDType()) # type: ignore
# TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making
# jax.Array via make_array_from_single_device_arrays. We should then use a
# sharding that spans all the execution devices - not just the addressable
# ones.
addressable_devices = devices.addressable_device_list
mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",))
replicated_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec()
)
len_array = jax.make_array_from_callback(
shape=(),
sharding=replicated_sharding,
data_callback=lambda _: np.array(len(s), dtype=np.int32),
out_arrays = [
jax.device_put(s_np_array, device) for device in addressable_devices
]
return jax.make_array_from_single_device_arrays(
arrays=out_arrays, sharding=replicated_sharding, shape=(),
)
data_array = jax.make_array_from_callback(
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
sharding=replicated_sharding,
data_callback=lambda _: np.frombuffer(
s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)),
dtype=np.uint8,
),
)
return len_array, data_array
def _deserialize_specs(
serialized_specs: tuple[jax.Array, ...],
serialized_specs: jax.Array,
) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
"""Deserializes the specs from the serialized specs.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python. See serialize() for details.
"""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
len_array, data_array = serialized_specs
length = int(len_array.addressable_shards[0].data)
data = np.asarray(data_array.addressable_shards[0].data).tobytes()
return _deserialize(data[:length])
data_array = serialized_specs.addressable_shards[0].data
data = base64.b64decode(data_array.item().encode("ascii"))
return _deserialize(data)

View File

@ -66,6 +66,14 @@ _count_colocated_python_specialization_cache_miss = jtu.count_events(
class ColocatedPythonTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
self.skipTest(
"Serialization in Colocated Python needs StringDType, and thus"
" requires NumPy 2.0.0 or later"
)
def testMakeColocatedPythonProgram(self):
def add_one(x):
return x + 1
@ -382,8 +390,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
del colocated_python._testing_global_state
def testStringProcessing(self):
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
self.skipTest("StringDType requires NumPy 2.0.0 or later")
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if len(cpu_devices) < 2:
self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")
@ -425,8 +431,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
)
def testBinaryDataProcessing(self):
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
self.skipTest("StringDType requires NumPy 2.0.0 or later")
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if len(cpu_devices) < 1:
self.skipTest("Need at least one CPU devices")