mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Adds support for string and binary data processing in Colocated Python.
PiperOrigin-RevId: 727048049
This commit is contained in:
parent
36d7f8530b
commit
9b6b569f3c
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import struct
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
@ -23,6 +25,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.experimental import colocated_python
|
||||
from jax.experimental.colocated_python import serialization
|
||||
from jax.extend.ifrt_programs import ifrt_programs
|
||||
@ -37,6 +40,7 @@ try:
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
raise unittest.SkipTest("tests depend on cloudpickle library")
|
||||
|
||||
|
||||
def _colocated_cpu_devices(
|
||||
devices: Sequence[jax.Device],
|
||||
) -> Sequence[jax.Device]:
|
||||
@ -378,6 +382,99 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
if "_testing_global_state" in colocated_python.__dict__:
|
||||
del colocated_python._testing_global_state
|
||||
|
||||
def testStringProcessing(self):
|
||||
if xla_extension_version < 315:
|
||||
self.skipTest(
|
||||
"String support for colocated Python requires xla_extension_version"
|
||||
" >= 315"
|
||||
)
|
||||
cpu_devices = _colocated_cpu_devices(jax.local_devices())
|
||||
if len(cpu_devices) < 2:
|
||||
self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")
|
||||
|
||||
@colocated_python.colocated_python
|
||||
def f(x):
|
||||
out_arrays = []
|
||||
upper_caser = np.vectorize(
|
||||
lambda x: x.upper(), otypes=[np.dtypes.StringDType()]
|
||||
)
|
||||
for shard in x.addressable_shards:
|
||||
np_array = jax.device_get(shard.data)
|
||||
out_np_array = upper_caser(np_array)
|
||||
out_arrays.append(jax.device_put(out_np_array, device=shard.device))
|
||||
return jax.make_array_from_single_device_arrays(
|
||||
sharding=x.sharding, shape=x.shape, arrays=out_arrays
|
||||
)
|
||||
|
||||
# Make a string array.
|
||||
numpy_string_array = np.array(
|
||||
[["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType() # type: ignore
|
||||
)
|
||||
mesh = jax.sharding.Mesh(
|
||||
np.array(cpu_devices[:2]).reshape((2, 1)), ("x", "y")
|
||||
)
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x"))
|
||||
x = jax.device_put(numpy_string_array, device=sharding)
|
||||
|
||||
# Run the colocated Python function with the string array as input.
|
||||
out = f(x)
|
||||
out = jax.device_get(out)
|
||||
|
||||
# Should have gotten the strings with all upper case letters.
|
||||
np.testing.assert_equal(
|
||||
out,
|
||||
np.array(
|
||||
[["ABCD", "EFGH"], ["IJKL", "MNOP"]], dtype=np.dtypes.StringDType()
|
||||
),
|
||||
)
|
||||
|
||||
def testBinaryDataProcessing(self):
|
||||
if xla_extension_version < 315:
|
||||
self.skipTest(
|
||||
"String support for colocated Python requires xla_extension_version"
|
||||
" >= 315"
|
||||
)
|
||||
cpu_devices = _colocated_cpu_devices(jax.local_devices())
|
||||
if len(cpu_devices) < 1:
|
||||
self.skipTest("Need at least one CPU devices")
|
||||
|
||||
@colocated_python.colocated_python
|
||||
def f(x):
|
||||
out_arrays = []
|
||||
for shard in x.addressable_shards:
|
||||
np_array = jax.device_get(shard.data)
|
||||
input_ints = struct.unpack(
|
||||
"<ii", base64.b64decode(np_array[0].encode("ascii"))
|
||||
)
|
||||
output_string = base64.b64encode(
|
||||
struct.pack("<ii", input_ints[0] + 1, input_ints[1] + 1)
|
||||
).decode("ascii")
|
||||
out_np_array = np.array([output_string], dtype=np.dtypes.StringDType())
|
||||
out_arrays.append(jax.device_put(out_np_array, device=shard.device))
|
||||
|
||||
out = jax.make_array_from_single_device_arrays(
|
||||
sharding=x.sharding, shape=x.shape, arrays=out_arrays
|
||||
)
|
||||
return out
|
||||
|
||||
# Make the input array with the binary data that packs two integers as ascii
|
||||
# string.
|
||||
input_string = base64.b64encode(struct.pack("<ii", 1001, 1002)).decode(
|
||||
"ascii"
|
||||
)
|
||||
numpy_string_array = np.array([input_string], dtype=np.dtypes.StringDType())
|
||||
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
|
||||
x = jax.device_put(numpy_string_array, device=sharding)
|
||||
|
||||
out = f(x)
|
||||
out = jax.device_get(out)
|
||||
|
||||
# Should have gotten the binary data with the incremented integers as a
|
||||
# ascii string.
|
||||
out_ints = struct.unpack("<ii", base64.b64decode(out[0].encode("ascii")))
|
||||
self.assertEqual(out_ints[0], 1002)
|
||||
self.assertEqual(out_ints[1], 1003)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user