diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py
index 8e2883ea4..effca1fe7 100644
--- a/jax/experimental/colocated_python/func.py
+++ b/jax/experimental/colocated_python/func.py
@@ -201,7 +201,7 @@ def _make_output_specs_and_push_result_fun(
   devices = specialization.devices
 
-  def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
+  def lowered_fun(*args, **kwargs) -> jax.Array:
     result = info.fun(*args, **kwargs)
     result_leaves, out_treedef = tree_util.tree_flatten(result)
     out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py
index bfd5ec2e6..1ca29ab12 100644
--- a/jax/experimental/colocated_python/serialization.py
+++ b/jax/experimental/colocated_python/serialization.py
@@ -13,11 +13,9 @@
 # limitations under the License.
 """Colocated Python serialization utilities."""
 
-# TODO(jmudigonda): Use a string-typed array for output structure when it
-# becomes available. Using a fixed uint8 array is only for prototyping.
-
 from __future__ import annotations
 
+import base64
 import collections
 import functools
 import io
@@ -37,12 +35,6 @@ import numpy as np
 
 DeviceList = xc.DeviceList
 
-# Hard-coded limit for serialized specs size.
-# TODO(jmudigonda): Use a string-typed array for output structure when it
-# becomes available. Using a fixed uint8 array is only for prototyping.
-_MAX_SERIALIZED_SPECS_SIZE = 1048576
-
-
 @jax.util.cache(max_size=None)
 def _get_cpu_device_map() -> dict[int, jax.Device]:
   """Returns a map from a device id to a matching device."""
@@ -185,23 +177,14 @@ def _deserialize(serialized: bytes) -> Any:
 
 
 def _make_specs_for_serialized_specs(
     devices: DeviceList,
-) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]:
+) -> api.ShapeDtypeStruct:
   """Makes output specs for serialized specs."""
-  # TODO(jmudigonda): Use a string-typed array for output structure when it
-  # becomes available. Using a fixed uint8 array is only for prototyping.
   mesh = jax.sharding.Mesh(tuple(devices), ("x",))
   replicated_sharding = jax.sharding.NamedSharding(
       mesh, jax.sharding.PartitionSpec()
   )
-  return (
-      api.ShapeDtypeStruct(
-          shape=(), dtype=np.int32, sharding=replicated_sharding
-      ),
-      api.ShapeDtypeStruct(
-          shape=(_MAX_SERIALIZED_SPECS_SIZE,),
-          dtype=np.uint8,
-          sharding=replicated_sharding,
-      ),
+  return api.ShapeDtypeStruct(
+      shape=(), dtype=np.dtypes.StringDType(), sharding=replicated_sharding  # type: ignore
   )
 
 
@@ -209,49 +192,49 @@ def _serialize_specs(
     specs_treedef: tree_util.PyTreeDef,
     specs_leaves: tuple[api.ShapeDtypeStruct, ...],
     devices: DeviceList,
-) -> tuple[jax.Array, ...]:
-  """Serializes the output specs into a tuple of arrays.
+) -> jax.Array:
+  """Serializes the output specs into a jax.Array of string type.
 
   DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
   colocated_python. See serialize() for details.
   """
-  s = _serialize((specs_treedef, specs_leaves))
-  assert (
-      len(s) <= _MAX_SERIALIZED_SPECS_SIZE
-  ), f"Too large serialized spec size: {len(s)}"
-  # TODO(jmudigonda): Use a string-typed array for output structure when it
-  # becomes available. Using a fixed uint8 array is only for prototyping.
-  mesh = jax.sharding.Mesh(tuple(devices), ("x",))
+  if not hasattr(np.dtypes, "StringDType"):
+    raise TypeError(
+        "Serializing Colocated Python requires StringDType. Please use"
+        " NumPy 2.0.0 or later, or explicitly provide an output spec"
+        " function."
+    )
+
+  s_bytes = _serialize((specs_treedef, specs_leaves))
+  s_str = base64.b64encode(s_bytes).decode("ascii")
+  s_np_array = np.array(s_str, dtype=np.dtypes.StringDType())  # type: ignore
+
+  # TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making
+  # jax.Array via make_array_from_single_device_arrays. We should then use a
+  # sharding that spans all the execution devices - not just the addressable
+  # ones.
+  addressable_devices = devices.addressable_device_list
+  mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",))
   replicated_sharding = jax.sharding.NamedSharding(
       mesh, jax.sharding.PartitionSpec()
   )
-  len_array = jax.make_array_from_callback(
-      shape=(),
-      sharding=replicated_sharding,
-      data_callback=lambda _: np.array(len(s), dtype=np.int32),
+
+  out_arrays = [
+      jax.device_put(s_np_array, device) for device in addressable_devices
+  ]
+  return jax.make_array_from_single_device_arrays(
+      arrays=out_arrays, sharding=replicated_sharding, shape=(),
   )
-  data_array = jax.make_array_from_callback(
-      shape=(_MAX_SERIALIZED_SPECS_SIZE,),
-      sharding=replicated_sharding,
-      data_callback=lambda _: np.frombuffer(
-          s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)),
-          dtype=np.uint8,
-      ),
-  )
-  return len_array, data_array
 
 
 def _deserialize_specs(
-    serialized_specs: tuple[jax.Array, ...],
+    serialized_specs: jax.Array,
 ) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
   """Deserializes the specs from the serialized specs.
 
   DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
   colocated_python. See serialize() for details.
   """
-  # TODO(jmudigonda): Use a string-typed array for output structure when it
-  # becomes available. Using a fixed uint8 array is only for prototyping.
-  len_array, data_array = serialized_specs
-  length = int(len_array.addressable_shards[0].data)
-  data = np.asarray(data_array.addressable_shards[0].data).tobytes()
-  return _deserialize(data[:length])
+  data_array = serialized_specs.addressable_shards[0].data
+  data = base64.b64decode(data_array.item().encode("ascii"))
+  return _deserialize(data)
diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py
index d6abe8bec..52d494904 100644
--- a/tests/colocated_python_test.py
+++ b/tests/colocated_python_test.py
@@ -66,6 +66,14 @@ _count_colocated_python_specialization_cache_miss = jtu.count_events(
 
 class ColocatedPythonTest(jtu.JaxTestCase):
 
+  def setUp(self):
+    super().setUp()
+    if np.lib.NumpyVersion(np.__version__) < "2.0.0":
+      self.skipTest(
+          "Serialization in Colocated Python needs StringDType, and thus"
+          " requires NumPy 2.0.0 or later"
+      )
+
   def testMakeColocatedPythonProgram(self):
     def add_one(x):
       return x + 1
@@ -382,8 +390,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
     del colocated_python._testing_global_state
 
   def testStringProcessing(self):
-    if np.lib.NumpyVersion(np.__version__) < "2.0.0":
-      self.skipTest("StringDType requires NumPy 2.0.0 or later")
     cpu_devices = _colocated_cpu_devices(jax.local_devices())
     if len(cpu_devices) < 2:
       self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")
@@ -425,8 +431,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
     )
 
   def testBinaryDataProcessing(self):
-    if np.lib.NumpyVersion(np.__version__) < "2.0.0":
-      self.skipTest("StringDType requires NumPy 2.0.0 or later")
     cpu_devices = _colocated_cpu_devices(jax.local_devices())
     if len(cpu_devices) < 1:
       self.skipTest("Need at least one CPU devices")
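Note on the encoding used by this change: the serialized (PyTreeDef, ShapeDtypeStruct leaves) payload is arbitrary bytes, while StringDType holds text, so the bytes are base64-encoded into an ASCII string before being stored in a 0-d string array. The snippet below is a minimal standalone sketch of that round trip, assuming NumPy 2.0 or later; pickle and the plain specs dict are stand-ins for colocated_python's internal _serialize/_deserialize helpers and its real spec payload, and the device placement step (jax.device_put plus make_array_from_single_device_arrays) is omitted.

# Standalone sketch of the bytes -> base64 -> StringDType round trip.
# Requires NumPy >= 2.0. `pickle` and the `specs` dict below are stand-ins
# for colocated_python's internal _serialize helper and its real payload.
import base64
import pickle

import numpy as np

specs = {"out": {"shape": (4,), "dtype": "float32"}}  # hypothetical payload

# Serialize: arbitrary bytes -> ASCII text -> 0-d variable-width string array.
encoded = base64.b64encode(pickle.dumps(specs)).decode("ascii")
spec_array = np.array(encoded, dtype=np.dtypes.StringDType())

# Deserialize: read the scalar string back and reverse the encoding.
decoded = pickle.loads(base64.b64decode(spec_array.item().encode("ascii")))
assert decoded == specs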