mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Implement an initial object API for colocated Python
Colocated Python adds `colocated_python_class`. This API wraps a user-defined class for automatic remoting of object construction/destruction and method calls: * An object will be initialized on the backend. At least for now, initialization is deferred until the first method is called; at this point, colocated Python knows on which devices the objects should be accessible and thus it can construct the object(s). * When an object method is called, the method call runs as a colocated Python function call on the backend. * When the object is destroyed (either by reaching a zero reference count or through Python GC), destruction also runs as a colocated Python function call and destroys all objects from the backend. This change provides an initial API implementation. Main limitations are as follows: * The methods of a colocated Python class do not support specialization. Calling them requires at least one argument. * Colocated Python objects cannot reference or interact with each other on the controller or on the colocated Python backend. These limitations will be lifted as the object API implementation is improved. PiperOrigin-RevId: 729629265
This commit is contained in:
parent
6c83d43635
commit
96b7dbabdc
@ -1239,6 +1239,8 @@ pytype_library(
|
||||
"experimental/colocated_python/api.py",
|
||||
"experimental/colocated_python/func.py",
|
||||
"experimental/colocated_python/func_backend.py",
|
||||
"experimental/colocated_python/obj.py",
|
||||
"experimental/colocated_python/obj_backend.py",
|
||||
"experimental/colocated_python/serialization.py",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -20,4 +20,5 @@
|
||||
from jax.experimental.colocated_python.api import (
|
||||
colocated_cpu_devices as colocated_cpu_devices,
|
||||
colocated_python as colocated_python,
|
||||
colocated_python_class as colocated_python_class,
|
||||
)
|
||||
|
@ -16,11 +16,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from typing import Any, Callable, Sequence
|
||||
from typing import Any, Callable, Sequence, Type
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax.experimental.colocated_python.func import make_callable
|
||||
from jax.experimental.colocated_python.obj import wrap_class
|
||||
|
||||
|
||||
def colocated_cpu_devices(
|
||||
@ -48,8 +49,13 @@ def colocated_cpu_devices(
|
||||
return colocated_cpu_devices
|
||||
|
||||
|
||||
def colocated_python(fun: Callable[..., Any]) -> Any:
|
||||
"""Executes the given Python function on the same device as the arguments."""
|
||||
def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]:
  """Executes the given Python function on the same devices as the arguments.

  Args:
    fun: The Python function to wrap.

  Returns:
    The callable produced by `make_callable` for `fun`, built from the
    function's source info and signature.
  """
  source_info = api_util.fun_sourceinfo(fun)
  signature = api_util.fun_signature(fun)
  return make_callable(fun, source_info, signature)
|
||||
|
||||
|
||||
def colocated_python_class(cls: Type[object]) -> Type[object]:
  """Executes the given Python class methods on the same devices as the arguments.

  Args:
    cls: The user-defined class to wrap.

  Returns:
    The class produced by `wrap_class` for `cls`.
  """
  source_info = api_util.fun_sourceinfo(cls)
  return wrap_class(cls, source_info)
|
||||
|
174
jax/experimental/colocated_python/obj.py
Normal file
174
jax/experimental/colocated_python/obj.py
Normal file
@ -0,0 +1,174 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Colocated Python object API implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import random
|
||||
import threading
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax._src import tree_util
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import wraps
|
||||
from jax.experimental.colocated_python import func
|
||||
from jax.experimental.colocated_python import obj_backend
|
||||
|
||||
|
||||
class _InstanceRegistry:
  """Registry of object instances.

  Maps the unique identifier of each registered instance to the set of devices
  on which that instance is live. Every access to the mapping is made while
  holding an internal lock.
  """

  def __init__(self) -> None:
    self._lock = threading.Lock()
    self._storage: dict[int, set[jax.Device]] = {}

  def new_instance(self) -> int:
    """Returns a new unique identifier for an instance on the controller.

    The identifier is a random 63-bit integer and is registered with an empty
    device set. If the drawn identifier is already registered, a new one is
    drawn instead, so an existing entry is never overwritten (this does not
    rely on `assert`, which is stripped when Python runs with `-O`).
    """
    with self._lock:
      uid = random.getrandbits(63)
      # Collisions are extremely unlikely, but redrawing is cheap and keeps the
      # device set of the already-registered instance intact.
      while uid in self._storage:
        uid = random.getrandbits(63)
      self._storage[uid] = set()
    return uid

  def update_devices(self, uid: int, device_set: set[jax.Device]) -> None:
    """Adds `device_set` to the set of devices on which instance `uid` is live.

    Raises:
      KeyError: If `uid` is not registered.
    """
    with self._lock:
      self._storage[uid] |= device_set

  def pop_instance(self, uid: int) -> set[jax.Device]:
    """Removes the instance and returns the set of devices on which it is live.

    Raises:
      KeyError: If `uid` is not registered.
    """
    with self._lock:
      return self._storage.pop(uid)


SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry()
|
||||
|
||||
|
||||
@jax.util.cache(max_size=4096)
def _update_instance_devices(
    uid: int, shardings: tuple[jax.sharding.Sharding, ...]
) -> None:
  """Caching version of _InstanceRegistry.update_devices().

  Collects the devices of all given shardings into one set and records them
  for instance `uid` in the singleton registry. The function is wrapped in a
  cache, so a repeated call with the same arguments is presumably served from
  the cache without touching the registry again.
  """
  devices = set().union(*(sharding.device_set for sharding in shardings))
  SINGLETON_INSTANCE_REGISTRY.update_devices(uid, devices)
|
||||
|
||||
|
||||
def _make_method(
    cls: Type[object],
    cls_sourceinfo: str | None,
    uid: int,
    init_args: tuple[Any, ...],
    init_kwargs: dict[str, Any],
    method_name: str,
    original_method: Callable[..., Any],
):
  """Builds the controller-side wrapper for one method of a colocated class.

  Args:
    cls: The original (unwrapped) user class.
    cls_sourceinfo: Source info of the class, forwarded to
      `func.make_callable`.
    uid: Identifier of the instance that the method belongs to.
    init_args: Positional arguments used to construct the object.
    init_kwargs: Keyword arguments used to construct the object.
    method_name: Name of the method to look up on the constructed object.
    original_method: The original method; supplies the signature and the
      metadata copied onto the returned wrapper.

  Returns:
    A function that records the devices of its arguments for `uid` and then
    invokes the method through a colocated Python callable. It raises
    `NotImplementedError` when called without positional arguments.
  """

  # Constructs the object when it is not yet present in the backend.
  def create_instance() -> object:
    return cls(*init_args, **init_kwargs)

  # Function that runs on the backend: fetch (or lazily create) the object and
  # invoke the requested method on it.
  def method(*args, **kwargs):
    store = obj_backend.SINGLETON_OBJECT_STORE
    target = store.get_or_create(uid, create_instance)
    bound_method = getattr(target, method_name)
    return bound_method(*args, **kwargs)

  # Colocated Python callable for the controller.
  remote_call = func.make_callable(
      method,
      cls_sourceinfo,
      api_util.fun_signature(original_method),
  )

  # Outer wrapper of the method for the controller. It tracks the devices on
  # which the instance is used before dispatching the call.
  @wraps(original_method)
  @api_boundary
  def method_wrapper(*args, **kwargs):
    if not args:
      raise NotImplementedError(
          'Method calls with no arguments are not yet supported.'
      )
    # TODO(hyeontaek): Instead of inspecting argument shardings, get shardings
    # from final specialization of the function. This may require lowering
    # `_update_instance_devices` into the function API.
    leaves = tree_util.tree_leaves((args, kwargs))
    leaf_shardings = tuple(func._get_spec(leaf).sharding for leaf in leaves)
    _update_instance_devices(uid, leaf_shardings)
    return remote_call(*args, **kwargs)

  return method_wrapper
|
||||
|
||||
|
||||
def wrap_class(
    cls: Type[object],
    cls_sourceinfo: str | None,
) -> Type[object]:
  """Creates the controller-side proxy class for `cls`.

  Constructing the returned class does not construct `cls`. It registers a new
  instance id and installs, on the instance, one wrapper (built by
  `_make_method`) for every plain-function attribute of `cls`, apart from a
  few special methods. Deleting an instance unregisters it and, if any devices
  were recorded for it, runs a colocated callable on those devices that removes
  the object from the backend object store.

  Args:
    cls: The user-defined class to wrap.
    cls_sourceinfo: Source info of the class, forwarded to `_make_method` and
      `func.make_callable`.

  Returns:
    The wrapper class, carrying the name and docstring of `cls`.
  """
  # WrappedClass defines lazy initialization and colocated deletion logic.
  # WrappedClass is not serializable even if the original class may be
  # serializable.
  not_forwarded = ('__init__', '__del__', '__reduce__', '__reduce_ex__')

  class WrappedClass:

    @wraps(cls.__init__)
    def __init__(self, *init_args, **init_kwargs) -> None:
      uid = SINGLETON_INSTANCE_REGISTRY.new_instance()
      self._colocated_python_uid = uid

      for attr_name in dir(cls):
        original_member = getattr(cls, attr_name)
        if not inspect.isfunction(original_member) or attr_name in not_forwarded:
          continue

        # TODO(hyeontaek): Support method specialization similar to function
        # specialization.
        setattr(
            self,
            attr_name,
            _make_method(
                cls,
                cls_sourceinfo,
                uid,
                init_args,
                init_kwargs,
                attr_name,
                original_member,
            ),
        )

    def __del__(self) -> None:
      uid = self._colocated_python_uid
      devices = SINGLETON_INSTANCE_REGISTRY.pop_instance(uid)
      if not devices:
        # No devices were recorded for this instance, so there is nothing to
        # remove from the backend.
        return

      def remove_object() -> None:
        obj_backend.SINGLETON_OBJECT_STORE.remove(uid)

      # TODO(hyeontaek): Request "best-effort" non-SPMD execution that tries
      # to run this function on any healthy processes instead of failing when
      # any process of the execution is unhealthy.
      destructor = func.make_callable(
          remove_object,
          cls_sourceinfo,
          None,
      )
      destructor = destructor.specialize(  # type: ignore[attribute-error]
          devices=devices
      )
      destructor()

  WrappedClass.__name__ = cls.__name__
  WrappedClass.__doc__ = cls.__doc__
  return WrappedClass
|
77
jax/experimental/colocated_python/obj_backend.py
Normal file
77
jax/experimental/colocated_python/obj_backend.py
Normal file
@ -0,0 +1,77 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Backend for colocated_python.obj."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import threading
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _ObjectState:
|
||||
is_being_initialized: bool
|
||||
exc: Exception | None = None
|
||||
obj: Any = None
|
||||
|
||||
|
||||
class _ObjectStore:
|
||||
"""Stores live objects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Condition()
|
||||
self._storage: dict[int, _ObjectState] = {}
|
||||
|
||||
def get_or_create(self, uid: int, initializer: Callable[[], Any]) -> Any:
|
||||
"""Returns the object associated with the given uid, or creates it if it does not exist."""
|
||||
with self._lock:
|
||||
if uid in self._storage:
|
||||
while True:
|
||||
state = self._storage[uid]
|
||||
if state.is_being_initialized:
|
||||
# Another thread is initializing the object. Wait for it to finish.
|
||||
self._lock.wait()
|
||||
else:
|
||||
break
|
||||
|
||||
if state.exc is not None:
|
||||
raise state.exc
|
||||
return state.obj
|
||||
|
||||
self._storage[uid] = _ObjectState(is_being_initialized=True)
|
||||
|
||||
try:
|
||||
obj = initializer()
|
||||
except Exception as exc:
|
||||
with self._lock:
|
||||
self._storage[uid] = _ObjectState(is_being_initialized=False, exc=exc)
|
||||
self._lock.notify_all()
|
||||
raise
|
||||
|
||||
with self._lock:
|
||||
self._storage[uid] = _ObjectState(is_being_initialized=False, obj=obj)
|
||||
self._lock.notify_all()
|
||||
return obj
|
||||
|
||||
def remove(self, uid: int) -> None:
|
||||
"""Removes the object associated with the given uid."""
|
||||
with self._lock:
|
||||
state = self._storage.pop(uid)
|
||||
|
||||
# The object will be deleted without holding the lock.
|
||||
del state
|
||||
|
||||
|
||||
SINGLETON_OBJECT_STORE = _ObjectStore()
|
@ -101,7 +101,7 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, np.array(2))
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def testSimpleFunctioWithTree(self):
|
||||
def testSimpleFunctionWithTree(self):
|
||||
@colocated_python.colocated_python
|
||||
def add_one(x):
|
||||
return jax.tree.map(lambda x: x + 1, x)
|
||||
@ -496,6 +496,155 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
make_zero = make_zero.specialize(devices=cpu_devices)
|
||||
jax.block_until_ready(make_zero())
|
||||
|
||||
def testObjectLifecycle(self):
  """Checks deferred initialization and colocated destruction of an object."""
  cpu_devices = _colocated_cpu_devices(jax.local_devices())
  sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])

  @colocated_python.colocated_python_class
  class Object:

    def __init__(self) -> None:
      colocated_python._testing_initialized = True

    def __del__(self) -> None:
      colocated_python._testing_destroyed = True

    # TODO(hyeontaek): Support method calls with no arguments and remove
    # `x` parameter.
    def echo(self, x: jax.Array) -> jax.Array:
      return x

  @colocated_python.colocated_python
  def check_initialized() -> jax.Array:
    initialized = getattr(colocated_python, "_testing_initialized", False)
    return jax.device_put(np.array(initialized), sharding)

  @colocated_python.colocated_python
  def check_destroyed() -> jax.Array:
    destroyed = getattr(colocated_python, "_testing_destroyed", False)
    return jax.device_put(np.array(destroyed), sharding)

  @colocated_python.colocated_python
  def cleanup():
    if "_testing_initialized" in colocated_python.__dict__:
      del colocated_python._testing_initialized
    if "_testing_destroyed" in colocated_python.__dict__:
      del colocated_python._testing_destroyed

  first_device = cpu_devices[:1]
  check_initialized = check_initialized.specialize(devices=first_device)
  check_destroyed = check_destroyed.specialize(devices=first_device)
  cleanup = cleanup.specialize(devices=first_device)

  def expect_flags(initialized: bool, destroyed: bool) -> None:
    # Reads both marker flags back and compares them with the expectation.
    self.assertEqual(jax.device_get(check_initialized()), initialized)
    self.assertEqual(jax.device_get(check_destroyed()), destroyed)

  # Scenario 1: the object is dropped before any method call.
  try:
    # Object initialization is deferred until the first method call.
    obj = Object()
    expect_flags(False, False)

    # If the object is destroyed without any method calls, the object is
    # destroyed without initialization.
    del obj
    expect_flags(False, False)
  finally:
    cleanup()

  # Scenario 2: the object is used once and then dropped.
  try:
    # Object initialization is deferred until the first method call.
    obj = Object()
    expect_flags(False, False)

    # The first method call on a process triggers object initialization there.
    x = jax.device_put(np.array(1), sharding)
    obj.echo(x)
    expect_flags(True, False)

    del obj
    expect_flags(True, True)
  finally:
    cleanup()
|
||||
|
||||
def testStatefulObject(self):
  """Checks that object state persists across successive method calls."""
  cpu_devices = _colocated_cpu_devices(jax.local_devices())

  @colocated_python.colocated_python_class
  class Value:

    def __init__(self, initial_value: np.ndarray) -> None:
      self.value = initial_value

    def add(self, x: jax.Array) -> jax.Array:
      self.value += np.asarray(x)
      return jax.device_put(self.value, x.sharding)

    # TODO(hyeontaek): Support method calls with no arguments and remove
    # `x` parameter.
    def fetch(self, x: jax.Array) -> jax.Array:
      return jax.device_put(self.value, x.sharding)

  counter = Value(np.array(5))
  one = jax.device_put(np.array(1), cpu_devices[0])

  # Each `add` accumulates into the same object.
  first_sum = jax.device_get(counter.add(one))
  self.assertEqual(first_sum, np.array(6))

  second_sum = jax.device_get(counter.add(one))
  self.assertEqual(second_sum, np.array(7))

  # `fetch` returns the stored value without changing it.
  fetched = jax.device_get(counter.fetch(one))
  self.assertEqual(fetched, np.array(7))
|
||||
|
||||
def testObjectWithCapturedSharding(self):
  """Checks that methods can return arrays placed with captured shardings."""
  cpu_devices = _colocated_cpu_devices(jax.local_devices())
  if len(cpu_devices) < 2:
    self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")

  mesh = jax.sharding.Mesh(cpu_devices[0:2], "x")
  sharding1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
  sharding2 = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec("x")
  )

  @colocated_python.colocated_python_class
  class Value:

    def __init__(self, initial_value: np.ndarray) -> None:
      self.value = initial_value
      # Captured shardings in the closure.
      self.sharding1 = sharding1
      self.sharding2 = sharding2

    def add_sharding1(self, x: jax.Array) -> jax.Array:
      self.value += np.asarray(x)
      return jax.device_put(self.value, self.sharding1)

    def add_sharding2(self, x: jax.Array) -> jax.Array:
      self.value += np.asarray(x)
      return jax.device_put(self.value, self.sharding2)

  accumulator = Value(np.array([5, 15]))

  input_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec()
  )
  increment = jax.device_put(np.array([1]), input_sharding)

  # The result of the first call carries the first captured sharding.
  result = accumulator.add_sharding1(increment)
  self.assertEqual(result.sharding, sharding1)
  self.assertArraysEqual(jax.device_get(result), np.array([6, 16]))

  # The result of the second call carries the second captured sharding.
  result = accumulator.add_sharding2(increment)
  self.assertEqual(result.sharding, sharding2)
  self.assertArraysEqual(jax.device_get(result), np.array([7, 17]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user