[JAX] Implement an initial object API for colocated Python

Colocated Python adds `colocated_python_class`. This API wraps a user-defined
class for automatic remoting of object construction/destruction and method calls:

* An object will be initialized on the backend. At least for now,
initialization is deferred until the first method is called; at this point,
colocated Python knows what devices the objects should be accessible and thus
it can construct the object(s).

* When an object method is called, the method call runs as a colocated Python
function call on the backend.

* When the object is destroyed (either by reaching a zero reference count or
through Python GC), destruction also runs as a colocated Python function call
and destroys all objects from the backend.

This change provides an intial API implementation. Main limitations are as
follows:

* The methods of a colocated Python class does not support specialization.
Calling it requires at least one argument.

* Colocated Python objects cannot reference or interact with each other on the
controller or on the colocated Python backend.

These limitations will be lifted as the object API implementation is improved.

PiperOrigin-RevId: 729629265
This commit is contained in:
Hyeontaek Lim 2025-02-21 12:57:49 -08:00 committed by jax authors
parent 6c83d43635
commit 96b7dbabdc
6 changed files with 413 additions and 4 deletions

View File

@ -1239,6 +1239,8 @@ pytype_library(
"experimental/colocated_python/api.py",
"experimental/colocated_python/func.py",
"experimental/colocated_python/func_backend.py",
"experimental/colocated_python/obj.py",
"experimental/colocated_python/obj_backend.py",
"experimental/colocated_python/serialization.py",
],
visibility = ["//visibility:public"],

View File

@ -20,4 +20,5 @@
from jax.experimental.colocated_python.api import (
colocated_cpu_devices as colocated_cpu_devices,
colocated_python as colocated_python,
colocated_python_class as colocated_python_class,
)

View File

@ -16,11 +16,12 @@
from __future__ import annotations
import collections
from typing import Any, Callable, Sequence
from typing import Any, Callable, Sequence, Type
import jax
from jax._src import api_util
from jax.experimental.colocated_python.func import make_callable
from jax.experimental.colocated_python.obj import wrap_class
def colocated_cpu_devices(
@ -48,8 +49,13 @@ def colocated_cpu_devices(
return colocated_cpu_devices
def colocated_python(fun: Callable[..., Any]) -> Any:
"""Executes the given Python function on the same device as the arguments."""
def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]:
"""Executes the given Python function on the same devices as the arguments."""
return make_callable(
fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun)
)
def colocated_python_class(cls: Type[object]) -> Type[object]:
"""Executes the given Python class methods on the same devices as the arguments."""
return wrap_class(cls, api_util.fun_sourceinfo(cls))

View File

@ -0,0 +1,174 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python object API implementation."""
from __future__ import annotations
import inspect
import random
import threading
from typing import Any, Callable, Type
import jax
from jax._src import api_util
from jax._src import tree_util
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func
from jax.experimental.colocated_python import obj_backend
class _InstanceRegistry:
"""Registry of object instances."""
def __init__(self) -> None:
self._lock = threading.Lock()
self._storage: dict[int, set[jax.Device]] = {}
def new_instance(self) -> int:
"""Returns a new unique identifier for an instance on the controller."""
uid = random.getrandbits(63)
with self._lock:
assert uid not in self._storage
self._storage[uid] = set()
return uid
def update_devices(self, uid: int, device_set: set[jax.Device]) -> None:
"""Updates the set of devices on which it is live."""
with self._lock:
self._storage[uid] |= device_set
def pop_instance(self, uid: int) -> set[jax.Device]:
"""Removes the instance and returns the set of devices on which it is live."""
with self._lock:
return self._storage.pop(uid)
SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry()
@jax.util.cache(max_size=4096)
def _update_instance_devices(
uid: int, shardings: tuple[jax.sharding.Sharding, ...]
) -> None:
"""Caching version of _InstanceRegistry.update_devices()."""
device_set = set()
for sharding in shardings:
device_set |= sharding.device_set
SINGLETON_INSTANCE_REGISTRY.update_devices(uid, device_set)
def _make_method(
cls: Type[object],
cls_sourceinfo: str | None,
uid: int,
init_args: tuple[Any, ...],
init_kwargs: dict[str, Any],
method_name: str,
original_method: Callable[..., Any],
):
# Initializer to use when the object is not present in the backend.
def initializer() -> object:
return cls(*init_args, **init_kwargs)
# Method to call on the backend.
def method(*args, **kwargs):
obj = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(uid, initializer)
return getattr(obj, method_name)(*args, **kwargs)
# Colocated Python callable for the controller.
callable = func.make_callable(
method,
cls_sourceinfo,
api_util.fun_signature(original_method),
)
# Outer wrapper of the method for the controller. It tracks
@api_boundary
def method_wrapper(*args, **kwargs):
if not args:
raise NotImplementedError(
'Method calls with no arguments are not yet supported.'
)
# TODO(hyeontaek): Instead of inspecting argument shardings, get shardings
# from final specialization of the function. This may require lowering
# `_update_instance_devices` into the function API.
args_leaves = tree_util.tree_leaves((args, kwargs))
shardings_leaves = tuple(func._get_spec(x).sharding for x in args_leaves)
_update_instance_devices(uid, shardings_leaves)
return callable(*args, **kwargs)
method_wrapper = wraps(original_method)(method_wrapper)
return method_wrapper
def wrap_class(
cls: Type[object],
cls_sourceinfo: str | None,
) -> Type[object]:
class WrappedClass:
@wraps(cls.__init__)
def __init__(self, *init_args, **init_kwargs) -> None:
uid = self._colocated_python_uid = (
SINGLETON_INSTANCE_REGISTRY.new_instance()
)
for attr_name in dir(cls):
original_member = getattr(cls, attr_name)
if not inspect.isfunction(original_member):
continue
# WrappedClass defines lazy initialization and colocated deletion logic.
# WrappedClass is not serializable even if the original class may be
# serializable.
if attr_name in ('__init__', '__del__', '__reduce__', '__reduce_ex__'):
continue
method = _make_method(
cls,
cls_sourceinfo,
uid,
init_args,
init_kwargs,
attr_name,
original_member,
)
# TODO(hyeontaek): Support method specialization similar to function
# specialization.
setattr(self, attr_name, method)
def __del__(self) -> None:
uid = self._colocated_python_uid
devices = SINGLETON_INSTANCE_REGISTRY.pop_instance(uid)
if devices:
def remove_object() -> None:
obj_backend.SINGLETON_OBJECT_STORE.remove(uid)
# TODO(hyeontaek): Request "best-effort" non-SPMD execution that tries
# to run this function on any healthy processes instead of failing when
# any process of the execution is unhealthy.
destructor = func.make_callable(
remove_object,
cls_sourceinfo,
None,
)
destructor = destructor.specialize( # type: ignore[attribute-error]
devices=devices
)
destructor()
WrappedClass.__name__ = cls.__name__
WrappedClass.__doc__ = cls.__doc__
return WrappedClass

View File

@ -0,0 +1,77 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend for colocated_python.obj."""
from __future__ import annotations
import dataclasses
import threading
from typing import Any, Callable
@dataclasses.dataclass(frozen=True)
class _ObjectState:
is_being_initialized: bool
exc: Exception | None = None
obj: Any = None
class _ObjectStore:
"""Stores live objects."""
def __init__(self) -> None:
self._lock = threading.Condition()
self._storage: dict[int, _ObjectState] = {}
def get_or_create(self, uid: int, initializer: Callable[[], Any]) -> Any:
"""Returns the object associated with the given uid, or creates it if it does not exist."""
with self._lock:
if uid in self._storage:
while True:
state = self._storage[uid]
if state.is_being_initialized:
# Another thread is initializing the object. Wait for it to finish.
self._lock.wait()
else:
break
if state.exc is not None:
raise state.exc
return state.obj
self._storage[uid] = _ObjectState(is_being_initialized=True)
try:
obj = initializer()
except Exception as exc:
with self._lock:
self._storage[uid] = _ObjectState(is_being_initialized=False, exc=exc)
self._lock.notify_all()
raise
with self._lock:
self._storage[uid] = _ObjectState(is_being_initialized=False, obj=obj)
self._lock.notify_all()
return obj
def remove(self, uid: int) -> None:
"""Removes the object associated with the given uid."""
with self._lock:
state = self._storage.pop(uid)
# The object will be deleted without holding the lock.
del state
SINGLETON_OBJECT_STORE = _ObjectStore()

View File

@ -101,7 +101,7 @@ class ColocatedPythonTest(jtu.JaxTestCase):
self.assertEqual(out, np.array(2))
self.assertEqual(count(), 1)
def testSimpleFunctioWithTree(self):
def testSimpleFunctionWithTree(self):
@colocated_python.colocated_python
def add_one(x):
return jax.tree.map(lambda x: x + 1, x)
@ -496,6 +496,155 @@ class ColocatedPythonTest(jtu.JaxTestCase):
make_zero = make_zero.specialize(devices=cpu_devices)
jax.block_until_ready(make_zero())
def testObjectLifecycle(self):
cpu_devices = _colocated_cpu_devices(jax.local_devices())
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
@colocated_python.colocated_python_class
class Object:
def __init__(self) -> None:
colocated_python._testing_initialized = True
def __del__(self) -> None:
colocated_python._testing_destroyed = True
# TODO(hyeontaek): Support method calls with no arguments and remove
# `x` parameter.
def echo(self, x: jax.Array) -> jax.Array:
return x
@colocated_python.colocated_python
def check_initialized() -> jax.Array:
initialized = getattr(colocated_python, "_testing_initialized", False)
return jax.device_put(np.array(initialized), sharding)
@colocated_python.colocated_python
def check_destroyed() -> jax.Array:
destroyed = getattr(colocated_python, "_testing_destroyed", False)
return jax.device_put(np.array(destroyed), sharding)
@colocated_python.colocated_python
def cleanup():
if "_testing_initialized" in colocated_python.__dict__:
del colocated_python._testing_initialized
if "_testing_destroyed" in colocated_python.__dict__:
del colocated_python._testing_destroyed
check_initialized = check_initialized.specialize(devices=cpu_devices[:1])
check_destroyed = check_destroyed.specialize(devices=cpu_devices[:1])
cleanup = cleanup.specialize(devices=cpu_devices[:1])
try:
# Object initialization is deferred until the first method call.
obj = Object()
self.assertEqual(jax.device_get(check_initialized()), False)
self.assertEqual(jax.device_get(check_destroyed()), False)
# If the object is destroyed without any method calls, the object is
# destroyed without initialization.
del obj
self.assertEqual(jax.device_get(check_initialized()), False)
self.assertEqual(jax.device_get(check_destroyed()), False)
finally:
cleanup()
try:
# Object initialization is deferred until the first method call.
obj = Object()
self.assertEqual(jax.device_get(check_initialized()), False)
self.assertEqual(jax.device_get(check_destroyed()), False)
# The first method call on a process triggers object initialization there.
x = np.array(1)
x = jax.device_put(x, sharding)
obj.echo(x)
self.assertEqual(jax.device_get(check_initialized()), True)
self.assertEqual(jax.device_get(check_destroyed()), False)
del obj
self.assertEqual(jax.device_get(check_initialized()), True)
self.assertEqual(jax.device_get(check_destroyed()), True)
finally:
cleanup()
def testStatefulObject(self):
cpu_devices = _colocated_cpu_devices(jax.local_devices())
@colocated_python.colocated_python_class
class Value:
def __init__(self, initial_value: np.ndarray) -> None:
self.value = initial_value
def add(self, x: jax.Array) -> jax.Array:
self.value += np.asarray(x)
return jax.device_put(self.value, x.sharding)
# TODO(hyeontaek): Support method calls with no arguments and remove
# `x` parameter.
def fetch(self, x: jax.Array) -> jax.Array:
return jax.device_put(self.value, x.sharding)
value = Value(np.array(5))
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
out = jax.device_get(value.add(x))
self.assertEqual(out, np.array(6))
out = jax.device_get(value.add(x))
self.assertEqual(out, np.array(7))
out = jax.device_get(value.fetch(x))
self.assertEqual(out, np.array(7))
def testObjectWithCapturedSharding(self):
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if len(cpu_devices) < 2:
self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")
mesh = jax.sharding.Mesh(cpu_devices[0:2], "x")
sharding1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
sharding2 = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("x")
)
@colocated_python.colocated_python_class
class Value:
def __init__(self, initial_value: np.ndarray) -> None:
self.value = initial_value
# Captured shardings in the closure.
self.sharding1 = sharding1
self.sharding2 = sharding2
def add_sharding1(self, x: jax.Array) -> jax.Array:
self.value += np.asarray(x)
return jax.device_put(self.value, self.sharding1)
def add_sharding2(self, x: jax.Array) -> jax.Array:
self.value += np.asarray(x)
return jax.device_put(self.value, self.sharding2)
value = Value(np.array([5, 15]))
x = np.array([1])
x = jax.device_put(
x, jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
)
out = value.add_sharding1(x)
self.assertEqual(out.sharding, sharding1)
out = jax.device_get(out)
self.assertArraysEqual(out, np.array([6, 16]))
out = value.add_sharding2(x)
self.assertEqual(out.sharding, sharding2)
out = jax.device_get(out)
self.assertArraysEqual(out, np.array([7, 17]))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())