[JAX] Add Python binding for building a colocated Python program

This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
This commit is contained in:
Hyeontaek Lim 2024-11-26 13:30:31 -08:00 committed by jax authors
parent 6763fcfb4e
commit bbaec6ea59
4 changed files with 29 additions and 5 deletions

View File

@ -1198,5 +1198,6 @@ pytype_library(
":util",
":xla_bridge",
"//jax/_src/lib",
"//jax/extend:ifrt_programs",
] + py_deps("numpy") + py_deps("cloudpickle"),
)

View File

@ -28,7 +28,8 @@ from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
from jax.extend.ifrt_programs import ifrt_programs
ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]
@ -141,8 +142,13 @@ def _compile_to_executable(
devices: xc.DeviceList,
) -> Callable[..., Any]:
"""Compiles a Python function into a runtime executable."""
# TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an
# executable.
pickled_function = _serialize(fun)
program = ifrt_programs.make_colocated_python_program(
name, pickled_function, devices, in_specs_leaves, out_specs_leaves
)
# TODO(hyeontaek): Compile the program and use the executable.
del program
del name
del in_specs_leaves
del out_specs_leaves

View File

@ -1387,6 +1387,7 @@ jax_multiplatform_test(
srcs = ["colocated_python_test.py"],
deps = [
"//jax:experimental_colocated_python",
"//jax/extend:ifrt_programs",
],
)

View File

@ -22,6 +22,8 @@ from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax.experimental import colocated_python
from jax.experimental.colocated_python import func as colocated_python_func
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
import jax.numpy as jnp
import numpy as np
@ -77,8 +79,22 @@ class ColocatedPythonTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_extension_version < 290:
self.skipTest("Requires xla_extension_version >= 290")
if xla_extension_version < 298:
self.skipTest("Requires xla_extension_version >= 298")
def testMakeColocatedPythonProgram(self):
def add_one(x):
return x + 1
cpu_devices = _colocated_cpu_devices(jax.local_devices())
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)
pickled_function = serialization._serialize(add_one)
program = ifrt_programs.make_colocated_python_program(
"add_one", pickled_function, [cpu_devices[0]], [aval], [aval]
)
del program
def testSimpleFunction(self):
@colocated_python.colocated_python