Adds support for string and binary data processing in Colocated Python.

PiperOrigin-RevId: 727048049
jax authors 2025-02-14 13:38:41 -08:00
parent 36d7f8530b
commit 9b6b569f3c

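Note on the binary-data path: NumPy's variable-width StringDType stores text rather than raw bytes, so the new testBinaryDataProcessing test below ships a binary payload through a string array by base64-encoding it into ASCII. A minimal standalone sketch of that packing scheme, without JAX and with illustrative variable names:

import base64
import struct

import numpy as np

# Pack two little-endian int32 values into raw bytes, then base64-encode them
# so the payload can be stored in a text-only StringDType array.
payload = struct.pack("<ii", 1001, 1002)
encoded = base64.b64encode(payload).decode("ascii")
string_array = np.array([encoded], dtype=np.dtypes.StringDType())

# Reverse the transformation: base64-decode the stored string and unpack.
decoded_bytes = base64.b64decode(string_array[0].encode("ascii"))
assert struct.unpack("<ii", decoded_bytes) == (1001, 1002)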

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import struct
import tempfile
import threading
import time
@@ -23,6 +25,7 @@ from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import colocated_python
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
@@ -37,6 +40,7 @@ try:
except (ModuleNotFoundError, ImportError):
  raise unittest.SkipTest("tests depend on cloudpickle library")


def _colocated_cpu_devices(
    devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
@@ -378,6 +382,99 @@ class ColocatedPythonTest(jtu.JaxTestCase):
      if "_testing_global_state" in colocated_python.__dict__:
        del colocated_python._testing_global_state

  def testStringProcessing(self):
    if xla_extension_version < 315:
      self.skipTest(
          "String support for colocated Python requires xla_extension_version"
          " >= 315"
      )
    cpu_devices = _colocated_cpu_devices(jax.local_devices())
    if len(cpu_devices) < 2:
      self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")

    @colocated_python.colocated_python
    def f(x):
      out_arrays = []
      upper_caser = np.vectorize(
          lambda x: x.upper(), otypes=[np.dtypes.StringDType()]
      )
      for shard in x.addressable_shards:
        np_array = jax.device_get(shard.data)
        out_np_array = upper_caser(np_array)
        out_arrays.append(jax.device_put(out_np_array, device=shard.device))
      return jax.make_array_from_single_device_arrays(
          sharding=x.sharding, shape=x.shape, arrays=out_arrays
      )

    # Make a string array.
    numpy_string_array = np.array(
        [["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType()  # type: ignore
    )
    mesh = jax.sharding.Mesh(
        np.array(cpu_devices[:2]).reshape((2, 1)), ("x", "y")
    )
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x"))
    x = jax.device_put(numpy_string_array, device=sharding)

    # Run the colocated Python function with the string array as input.
    out = f(x)
    out = jax.device_get(out)

    # Should have gotten the strings with all upper case letters.
    np.testing.assert_equal(
        out,
        np.array(
            [["ABCD", "EFGH"], ["IJKL", "MNOP"]], dtype=np.dtypes.StringDType()
        ),
    )

  def testBinaryDataProcessing(self):
    if xla_extension_version < 315:
      self.skipTest(
          "String support for colocated Python requires xla_extension_version"
          " >= 315"
      )
    cpu_devices = _colocated_cpu_devices(jax.local_devices())
    if len(cpu_devices) < 1:
      self.skipTest("Need at least one CPU device")

    @colocated_python.colocated_python
    def f(x):
      out_arrays = []
      for shard in x.addressable_shards:
        np_array = jax.device_get(shard.data)
        input_ints = struct.unpack(
            "<ii", base64.b64decode(np_array[0].encode("ascii"))
        )
        output_string = base64.b64encode(
            struct.pack("<ii", input_ints[0] + 1, input_ints[1] + 1)
        ).decode("ascii")
        out_np_array = np.array([output_string], dtype=np.dtypes.StringDType())
        out_arrays.append(jax.device_put(out_np_array, device=shard.device))
      out = jax.make_array_from_single_device_arrays(
          sharding=x.sharding, shape=x.shape, arrays=out_arrays
      )
      return out

    # Make the input array with binary data that packs two integers as an
    # ASCII string.
    input_string = base64.b64encode(struct.pack("<ii", 1001, 1002)).decode(
        "ascii"
    )
    numpy_string_array = np.array([input_string], dtype=np.dtypes.StringDType())
    sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
    x = jax.device_put(numpy_string_array, device=sharding)

    out = f(x)
    out = jax.device_get(out)

    # Should have gotten the binary data with the incremented integers as an
    # ASCII string.
    out_ints = struct.unpack("<ii", base64.b64decode(out[0].encode("ascii")))
    self.assertEqual(out_ints[0], 1002)
    self.assertEqual(out_ints[1], 1003)

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())