mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a colocated Python program. This Python binding will be used internally in the colocated Python API implementation. The API does not yet compile the program into an executable, which will be added separately. PiperOrigin-RevId: 700443656
This commit is contained in:
parent
6763fcfb4e
commit
bbaec6ea59
@ -1198,5 +1198,6 @@ pytype_library(
|
||||
":util",
|
||||
":xla_bridge",
|
||||
"//jax/_src/lib",
|
||||
"//jax/extend:ifrt_programs",
|
||||
] + py_deps("numpy") + py_deps("cloudpickle"),
|
||||
)
|
||||
|
@ -28,7 +28,8 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import wraps
|
||||
from jax.experimental.colocated_python import func_backend
|
||||
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs
|
||||
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
|
||||
from jax.extend.ifrt_programs import ifrt_programs
|
||||
|
||||
ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]
|
||||
|
||||
@ -141,8 +142,13 @@ def _compile_to_executable(
|
||||
devices: xc.DeviceList,
|
||||
) -> Callable[..., Any]:
|
||||
"""Compiles a Python function into a runtime executable."""
|
||||
# TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an
|
||||
# executable.
|
||||
pickled_function = _serialize(fun)
|
||||
program = ifrt_programs.make_colocated_python_program(
|
||||
name, pickled_function, devices, in_specs_leaves, out_specs_leaves
|
||||
)
|
||||
# TODO(hyeontaek): Compile the program and use the executable.
|
||||
del program
|
||||
|
||||
del name
|
||||
del in_specs_leaves
|
||||
del out_specs_leaves
|
||||
|
@ -1387,6 +1387,7 @@ jax_multiplatform_test(
|
||||
srcs = ["colocated_python_test.py"],
|
||||
deps = [
|
||||
"//jax:experimental_colocated_python",
|
||||
"//jax/extend:ifrt_programs",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,6 +22,8 @@ from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
|
||||
from jax.experimental import colocated_python
|
||||
from jax.experimental.colocated_python import func as colocated_python_func
|
||||
from jax.experimental.colocated_python import serialization
|
||||
from jax.extend.ifrt_programs import ifrt_programs
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -77,8 +79,22 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_extension_version < 290:
|
||||
self.skipTest("Requires xla_extension_version >= 290")
|
||||
if xla_extension_version < 298:
|
||||
self.skipTest("Requires xla_extension_version >= 298")
|
||||
|
||||
def testMakeColocatedPythonProgram(self):
|
||||
def add_one(x):
|
||||
return x + 1
|
||||
|
||||
cpu_devices = _colocated_cpu_devices(jax.local_devices())
|
||||
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
|
||||
aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)
|
||||
|
||||
pickled_function = serialization._serialize(add_one)
|
||||
program = ifrt_programs.make_colocated_python_program(
|
||||
"add_one", pickled_function, [cpu_devices[0]], [aval], [aval]
|
||||
)
|
||||
del program
|
||||
|
||||
def testSimpleFunction(self):
|
||||
@colocated_python.colocated_python
|
||||
|
Loading…
x
Reference in New Issue
Block a user