[JAX] Add the function API of jax.experimental.colocated_python

This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX has needed its own solution for distributed Python code
execution, which fragments user code across the two runtime architectures.
`colocated_python` is an attempt to define a single device model and a
portable API so that the user can write code once and run it on both runtime
architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. A separate change will also express serialized functions as an
IFRT `CustomCallProgram`.

The API is at an early stage of development. Please proceed with caution when
using it.
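
A minimal usage sketch, based on the tests included in this change (the exact
call pattern may change while the API is experimental, and
`colocated_cpu_devices` requires xla_extension_version >= 290):

    import jax
    import numpy as np
    from jax.experimental import colocated_python

    # Find CPU devices colocated with the given devices.
    cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())

    @colocated_python.colocated_python
    def add_one(x):
      return x + 1

    x = jax.device_put(np.array(1), cpu_devices[0])
    out = add_one(x)  # First call is synchronous: output specs are discovered.
    out = add_one(x)  # Later calls with the same input specs run asynchronously.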

PiperOrigin-RevId: 690705899
Hyeontaek Lim 2024-10-28 12:17:34 -07:00 committed by jax authors
parent 9fd1ef2784
commit 77797f434d
8 changed files with 1010 additions and 0 deletions

jax/BUILD

@@ -1159,3 +1159,24 @@ pytype_library(
visibility = ["//visibility:public"],
deps = [":jax"],
)
pytype_library(
name = "experimental_colocated_python",
srcs = [
"experimental/colocated_python/__init__.py",
"experimental/colocated_python/api.py",
"experimental/colocated_python/func.py",
"experimental/colocated_python/func_backend.py",
"experimental/colocated_python/serialization.py",
],
visibility = ["//visibility:public"],
deps = [
":api_util",
":jax",
":traceback_util",
":tree_util",
":util",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy") + py_deps("cloudpickle"),
)

jax/experimental/colocated_python/__init__.py

@@ -0,0 +1,23 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python API."""
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
# pylint: disable=useless-import-alias
from jax.experimental.colocated_python.api import (
colocated_cpu_devices as colocated_cpu_devices,
colocated_python as colocated_python,
)

jax/experimental/colocated_python/api.py

@@ -0,0 +1,59 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python top-level API."""
from __future__ import annotations
import collections
from typing import Any, Callable, Sequence
import jax
from jax._src import api_util
from jax._src.lib import xla_extension_version
from jax.experimental.colocated_python.func import make_callable
def colocated_cpu_devices(
devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
"""Finds CPU devices colocated with the given devices."""
if xla_extension_version < 290:
raise NotImplementedError("Requires xla_extension_version >= 290")
cpu_devices_by_colocation_id = collections.defaultdict(list)
for device in devices[0].backend._get_all_devices(): # pylint: disable=protected-access
if device.device_kind == "cpu":
cpu_devices_by_colocation_id[device.colocation_id].append(device)
if not cpu_devices_by_colocation_id:
raise ValueError("No CPU devices found")
colocated_cpu_devices = []
for device in devices:
matches = cpu_devices_by_colocation_id[device.colocation_id]
if not matches:
raise ValueError(f"Device {device} has no colocated devices")
elif len(matches) > 1:
raise ValueError(
f"Ambiguous colocated devices; device {device} has"
f" {len(matches)} colocated devices: f{matches}"
)
colocated_cpu_devices.append(matches[0])
return colocated_cpu_devices
def colocated_python(fun: Callable[..., Any]) -> Any:
"""Executes the given Python function on the same device as the arguments."""
return make_callable(
fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun)
)

jax/experimental/colocated_python/func.py

@@ -0,0 +1,417 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python function API implementation."""
from __future__ import annotations
import dataclasses
import inspect
import random
import threading
from typing import Any, Callable, Sequence
import jax
from jax._src import api
from jax._src import tree_util
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs
ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]
@dataclasses.dataclass(frozen=True, slots=True)
class FunctionInfo:
"""User function wrapped by colocated_python."""
fun: Callable[..., Any]
fun_sourceinfo: str | None
fun_signature: inspect.Signature | None
@dataclasses.dataclass(frozen=True, slots=True)
class Specialization:
"""Specialization for a colocated_python function."""
in_specs_treedef: tree_util.PyTreeDef | None = None
in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None
out_specs_treedef: tree_util.PyTreeDef | None = None
out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
devices: xc.DeviceList | None = None
def update(
self,
*,
in_specs_treedef: tree_util.PyTreeDef | None = None,
in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
out_specs_treedef: tree_util.PyTreeDef | None = None,
out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
devices: Sequence[jax.Device] | xc.DeviceList | None = None,
) -> Any:
"""Creates a new specialization with overrides."""
if in_specs_treedef is None:
in_specs_treedef = self.in_specs_treedef
elif self.in_specs_treedef is not None:
raise ValueError("in_specs already specified")
if in_specs_leaves is None:
in_specs_leaves = self.in_specs_leaves
elif self.in_specs_leaves is not None:
raise ValueError("in_specs already specified")
if out_specs_fn is None:
out_specs_fn = self.out_specs_fn
elif self.out_specs_fn is not None:
raise ValueError("out_specs_fn already specified")
if out_specs_treedef is None:
out_specs_treedef = self.out_specs_treedef
elif self.out_specs_treedef is not None:
raise ValueError("out_specs already specified")
if out_specs_leaves is None:
out_specs_leaves = self.out_specs_leaves
elif self.out_specs_leaves is not None:
raise ValueError("out_specs already specified")
if devices is None:
devices = self.devices
elif self.devices is not None:
raise ValueError("devices already specified")
elif not isinstance(devices, xc.DeviceList):
devices = xc.DeviceList(tuple(devices))
return Specialization(
in_specs_treedef,
in_specs_leaves,
out_specs_fn,
out_specs_treedef,
out_specs_leaves,
devices,
)
def _get_spec(x: Any) -> api.ShapeDtypeStruct:
"""Extracts a spec for a value, which must be a JAX Array."""
# TODO(hyeontaek): Allow Python values and automatically apply `shard_arg`
# with a suitable sharding and layout.
if not isinstance(x, jax.Array):
raise ValueError(
"colocated_python only supports jax.Array as input and output, but got"
f" {type(x)}."
)
return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
"""Returns a representative device list from function call arguments."""
device_list_set: set[xc.DeviceList] = set()
for x in args:
sharding = getattr(x, "sharding", None)
if sharding is not None:
device_list_set.add(x.sharding._internal_device_list)
if not device_list_set:
return None
if len(device_list_set) != 1:
raise ValueError(
"All arguments must use the same device list, but got"
f" multiple device lists: {device_list_set}."
)
return device_list_set.pop()
def _compile_to_executable(
name: str,
fun: Callable[..., Any],
in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
devices: xc.DeviceList,
) -> Callable[..., Any]:
"""Compiles a Python function into a runtime executable."""
# TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an
# executable.
del name
del in_specs_leaves
del out_specs_leaves
del devices
return fun
def _make_output_specs_and_push_result_fun(
info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
"""Creates a function that computes output specs and pushes the result to the result store."""
assert specialization.in_specs_treedef is not None
assert specialization.in_specs_leaves is not None
assert specialization.out_specs_treedef is None
assert specialization.out_specs_leaves is None
assert specialization.devices is not None
devices = specialization.devices
def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
result = info.fun(*args, **kwargs)
out_leaves, out_treedef = tree_util.tree_flatten(result)
out_spec_leaves = tuple(_get_spec(x) for x in out_leaves)
func_backend.SINGLETON_RESULT_STORE.push(uid, out_leaves)
return _serialize_specs(out_treedef, out_spec_leaves, devices)
out_specs_leaves, _ = tree_util.tree_flatten(
_make_specs_for_serialized_specs(specialization.devices),
)
name = getattr(info.fun, "__name__", "unknown")
name = f"{name}_output_specs_and_push_result"
return _compile_to_executable(
name=name,
fun=lowered_fun,
in_specs_leaves=specialization.in_specs_leaves,
out_specs_leaves=tuple(out_specs_leaves),
devices=specialization.devices,
)
def _make_pop_result_fun(
info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
"""Makes a function that pops results from the result store."""
assert specialization.out_specs_treedef is not None
assert specialization.out_specs_leaves is not None
assert specialization.devices is not None
out_specs_treedef = specialization.out_specs_treedef
def lowered_fun() -> Any:
flat_result = func_backend.SINGLETON_RESULT_STORE.pop(uid)
return tree_util.tree_unflatten(out_specs_treedef, flat_result)
in_specs, _ = tree_util.tree_flatten((
# args
(),
# kwargs
(),
))
name = getattr(info.fun, "__name__", "unknown")
name = f"{name}_pop_result"
return _compile_to_executable(
name=name,
fun=lowered_fun,
in_specs_leaves=tuple(in_specs),
out_specs_leaves=specialization.out_specs_leaves,
devices=specialization.devices,
)
def _make_async_execution_fun(
info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
"""Makes a function that asynchronously executes the function."""
assert specialization.in_specs_treedef is not None
assert specialization.in_specs_leaves is not None
assert specialization.out_specs_treedef is not None
assert specialization.out_specs_leaves is not None
assert specialization.devices is not None
name = getattr(info.fun, "__name__", "unknown")
return _compile_to_executable(
name=name,
fun=info.fun,
in_specs_leaves=specialization.in_specs_leaves,
out_specs_leaves=specialization.out_specs_leaves,
devices=specialization.devices,
)
@jax.util.cache(max_size=None)
def _get_specialized_func(
info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
"""Returns a specialized function for the given specialization."""
assert specialization.in_specs_treedef is not None
assert specialization.in_specs_leaves is not None
assert specialization.devices is not None
uid = random.getrandbits(63)
mutex = threading.Lock()
# Asynchronous execution function that has known output_specs.
async_execution_func = None
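# `specialized_func` below takes one of three paths on the first call
# (subsequent calls reuse the cached `async_execution_func`):
# 1. Output specs unknown and no `out_specs_fn`: run a spec-discovering
#    execution synchronously, stash the results in the result store, and pop
#    them with a second trivial execution.
# 2. Output specs unknown but `out_specs_fn` given: compute the output specs
#    from the input specs and build the asynchronous execution function.
# 3. Output specs already known: build the asynchronous execution function.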
def specialized_func(*args, **kwargs) -> Any:
"""Specialized function to be executed with given args and kwargs."""
nonlocal specialization, async_execution_func
with mutex:
if async_execution_func is None:
if specialization.out_specs_treedef is None:
if specialization.out_specs_fn is None:
serialized_out_specs = _make_output_specs_and_push_result_fun(
info, specialization, uid
)(*args, **kwargs)
# Waits for the output_specs. This may block.
out_specs_treedef, out_specs_leaves = _deserialize_specs(
serialized_out_specs
)
# Subsequent calls would use async_execution_func with discovered
# output_specs.
specialization = specialization.update(
out_specs_treedef=out_specs_treedef,
out_specs_leaves=out_specs_leaves,
)
async_execution_func = _make_async_execution_fun(
info, specialization
)
return _make_pop_result_fun(info, specialization, uid)()
else:
# Compute out_specs using out_specs_fn and inputs.
out_specs = specialization.out_specs_fn(*args, **kwargs)
# Type checking is ignored to silence mypy error: Incompatible types
# in assignment (expression has type "list[Any]", variable has type
# "tuple[ShapeDtypeStruct, ...]") [assignment]
out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( # type: ignore[assignment]
out_specs
)
specialization = specialization.update(
out_specs_treedef=out_specs_treedef,
out_specs_leaves=tuple(out_specs_leaves),
)
async_execution_func = _make_async_execution_fun(
info, specialization
)
# Fall-through.
else:
async_execution_func = _make_async_execution_fun(info, specialization)
# Fall-through.
return async_execution_func(*args, **kwargs)
return specialized_func
def make_callable(
fun: Callable[..., Any],
fun_sourceinfo: str | None,
fun_signature: inspect.Signature | None,
) -> Callable[..., Any]:
"""Makes a colocated Python callable."""
return _make_callable(
FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
)
def _make_callable(
info: FunctionInfo,
specialization: Specialization,
) -> Callable[..., Any]:
"""Internal implementation of make_callable."""
def specialize(
in_specs: ShapeDtypeStructTree | None = None,
out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
devices: Sequence[jax.Device] | None = None,
) -> Callable[..., Any]:
"""Returns a colocated Python callable with extra specialization.
Args:
in_specs: Optionally specifies the expected input specs. Input specs are
expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a
function call.
out_specs_fn: Optionally specifies a function that computes the output
specs from input specs. If unspecified, colocated_python will compute
the output specs during the very first execution, and this execution
will be synchronous.
devices: Optionally specifies the devices to execute the function on. Must
be provided if in_specs has no leaves because devices cannot be inferred
from input specs or arguments.
Returns:
A colocated Python callable with extra specialization.
"""
# TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
# `out_specs_fn(in_specs)` returns at least one leaf that we can use for
# inferring `devices`.
if in_specs is None:
in_specs_leaves, in_specs_treedef = None, None
else:
in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs)
in_specs_leaves = tuple(in_specs_leaves_list)
return _make_callable(
info,
specialization.update(
in_specs_treedef=in_specs_treedef,
in_specs_leaves=in_specs_leaves,
out_specs_fn=out_specs_fn,
devices=devices,
),
)
@api_boundary
def __call__(*args, **kwargs) -> Any:
"""Executes the function.
If the output specs are not known, the very first execution will be
synchronous.
"""
args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))
in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)
if specialization.in_specs_treedef is None:
# Allow input polymorphism by applying input_specs specialization
# temporarily for this call.
return _make_callable(
info,
specialization.update(
in_specs_treedef=in_specs_treedef,
in_specs_leaves=in_specs_leaves,
),
)(*args, **kwargs)
if specialization.devices is None:
devices = _infer_devices_from_args(args_leaves)
if devices is None:
raise ValueError(
"No devices found. colocated_python function without input"
" arguments must be first specialized with devices."
)
# Allow device polymorphism by applying devices specialization temporarily
# for this call.
return _make_callable(info, specialization.update(devices=devices))(
*args, **kwargs
)
# Assertion is added to silence mypy error: Unsupported operand types for !=
# ("PyTreeDef" and "None") [operator]
assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef)
# If input_specs is known, verify that it matches actual inputs.
if (specialization.in_specs_treedef != in_specs_treedef
or specialization.in_specs_leaves != in_specs_leaves):
raise ValueError(
"Input specs in specialization and input specs of arguments must have"
" the same pytree structure, but they have the following structural"
" differences:\n"
+ ("\n".join(
f" - {tree_util.keystr(path)} is a {thing1} in value 1 and"
f" a {thing2} in value 2, so {explanation}.\n"
for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef(
specialization.in_specs_treedef, in_specs_treedef
))))
return _get_specialized_func(info, specialization)(*args, **kwargs)
__call__ = wraps(info.fun)(__call__)
__call__.specialize = specialize
return __call__
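
As a usage sketch mirroring the tests in this change (`cpu_devices` as in the
earlier sketch), `specialize` can supply devices up front, which is required
when a function takes no array arguments, or an `out_specs_fn`, which avoids
the synchronous spec-discovering first call:

    import jax
    import jax.numpy as jnp
    from jax.experimental import colocated_python

    @colocated_python.colocated_python
    def make_zero():
      return jnp.array(0)

    # No inputs to infer devices from, so devices must be given up front.
    make_zero = make_zero.specialize(devices=cpu_devices[:1])
    out = make_zero()

    @colocated_python.colocated_python
    def add_one(x):
      return jax.tree.map(lambda v: v + 1, x)

    # out_specs_fn maps input specs to output specs; for add_one they are
    # identical, so even the first call can execute asynchronously.
    add_one = add_one.specialize(out_specs_fn=lambda x: x)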

jax/experimental/colocated_python/func_backend.py

@@ -0,0 +1,44 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend for colocated_python.func."""
from __future__ import annotations
import threading
from typing import Sequence
import jax
class _ResultStore:
"""Temporarily stores results from synchronous execution of functions."""
def __init__(self) -> None:
self._lock = threading.Lock()
self._storage: dict[int, Sequence[jax.Array]] = {}
def push(self, uid: int, out: Sequence[jax.Array]) -> None:
with self._lock:
if uid in self._storage:
raise ValueError(f"uid {uid} already exists")
self._storage[uid] = out
def pop(self, uid: int) -> Sequence[jax.Array]:
with self._lock:
if uid not in self._storage:
raise ValueError(f"uid {uid} does not exist")
return self._storage.pop(uid)
SINGLETON_RESULT_STORE = _ResultStore()
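
A small sketch of the handshake that func.py performs through this store (the
uid shown is arbitrary; func.py draws a random 63-bit uid):

    import jax.numpy as jnp
    from jax.experimental.colocated_python import func_backend

    uid = 12345  # func.py uses random.getrandbits(63).
    func_backend.SINGLETON_RESULT_STORE.push(uid, [jnp.array(2)])
    out = func_backend.SINGLETON_RESULT_STORE.pop(uid)  # [Array(2)]
    # Pushing the same uid twice, or popping a missing uid, raises ValueError.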

jax/experimental/colocated_python/serialization.py

@@ -0,0 +1,228 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python serialization utilities."""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
from __future__ import annotations
import collections
import io
from typing import Any, Callable, Sequence
try:
import cloudpickle # type: ignore[import-not-found]
except ImportError:
cloudpickle = None
import jax
from jax._src import api
from jax._src import tree_util
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
import numpy as np
DeviceList = xc.DeviceList
# Hard-coded limit for serialized specs size.
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
_MAX_SERIALIZED_SPECS_SIZE = 1048576
@jax.util.cache(max_size=None)
def _get_cpu_device_map() -> dict[int, jax.Device]:
"""Returns a map from a device id to a matching device."""
cpu_device_map: dict[int, jax.Device] = {}
# TODO(hyeontaek): We should look up CPU devices for a specific CPU backend.
# When deserializing a device on the controller, the backend should be the one
# associated with colocated_python. When deserializing on the colocated_python
# executor, it should be the CPU backend visible to the user function running
# under colocated_python.
for backend in xb.backends().values():
for d in backend._get_all_devices(): # pylint: disable=protected-access
if d.device_kind == "cpu":
if d.id in cpu_device_map:
raise ValueError(
f"Multiple CPU devices with id {d.id} found:"
f" {cpu_device_map[d.id]} and {d}"
)
cpu_device_map[d.id] = d
return cpu_device_map
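# The `_reduce_*` functions below are custom pickle reducers: at serialization
# time they replace device objects with plain integer device ids, and at
# deserialization time they rebind those ids to local CPU devices through
# `_get_cpu_device_map`, so that meshes, device lists, and shardings can be
# reconstructed on the destination process.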
def _reduce_mesh(
mesh: jax.sharding.Mesh,
) -> tuple[Callable[..., jax.sharding.Mesh], Any]:
def make_mesh(
mesh_device_ids: np.ndarray, axis_names: Any
) -> jax.sharding.Mesh:
cpu_device_map = _get_cpu_device_map()
mesh_devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
mesh_device_ids
)
return jax.sharding.Mesh(mesh_devices, axis_names)
mesh_device_ids = np.vectorize(lambda d: d.id, otypes=[int])(mesh.devices)
return make_mesh, (mesh_device_ids, mesh.axis_names)
def _reduce_device_list(
device_list: DeviceList,
) -> tuple[Callable[..., DeviceList], Any]:
def make_device_list(device_ids: Sequence[int]) -> DeviceList:
cpu_device_map = _get_cpu_device_map()
devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
device_ids
)
return DeviceList(devices)
device_ids = [d.id for d in device_list]
return make_device_list, (device_ids,)
def _reduce_single_device_sharding(
sharding: jax.sharding.SingleDeviceSharding,
) -> tuple[Callable[..., jax.sharding.SingleDeviceSharding], Any]:
def make_single_device_sharding(device_id: int):
cpu_device_map = _get_cpu_device_map()
return jax.sharding.SingleDeviceSharding(cpu_device_map[device_id])
return make_single_device_sharding, (sharding.device_set.pop().id,)
def _serialize(obj: Any) -> bytes:
"""Serializes callables and input/output spec objects.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python.
This module contains utility functions used internally for implementing
`colocated_python` when it ships callables and input/output specs through
IFRT. The pickled data is produced and consumed in an ephemeral fashion
without any persistence, and it does not expect any version compatibility
(which cloudpickle does not guarantee). Furthermore, serialization and
deserialization are expected to be done on machine(s) that are controlled by a
single tenant, which allows unpickling done during deserialization to be
trusted.
Raises:
ModuleNotFoundError: If cloudpickle is not available.
"""
if cloudpickle is None:
raise ModuleNotFoundError('No module named "cloudpickle"')
class _CustomPickler(cloudpickle.Pickler):
dispatch_table = collections.ChainMap(
{jax.sharding.Mesh: _reduce_mesh},
{DeviceList: _reduce_device_list},
{jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding},
cloudpickle.CloudPickler.dispatch_table, # pylint: disable=attribute-error
)
dispatch = dispatch_table
with io.BytesIO() as file:
_CustomPickler(file).dump(obj)
return file.getvalue()
def _deserialize(serialized: bytes) -> Any:
"""Deserializes callables and input/output spec objects.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python. See serialize() for details.
Raises:
ModuleNotFoundError: If cloudpickle is not available.
"""
if cloudpickle is None:
raise ModuleNotFoundError('No module named "cloudpickle"')
return cloudpickle.loads(serialized)
def _make_specs_for_serialized_specs(
devices: DeviceList,
) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]:
"""Makes output specs for serialized specs."""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
replicated_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec()
)
return (
api.ShapeDtypeStruct(
shape=(), dtype=np.int32, sharding=replicated_sharding
),
api.ShapeDtypeStruct(
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
dtype=np.uint8,
sharding=replicated_sharding,
),
)
def _serialize_specs(
specs_treedef: tree_util.PyTreeDef,
specs_leaves: tuple[api.ShapeDtypeStruct, ...],
devices: DeviceList,
) -> tuple[jax.Array, ...]:
"""Serializes the output specs into a tuple of arrays.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python. See serialize() for details.
"""
s = _serialize((specs_treedef, specs_leaves))
assert (
len(s) <= _MAX_SERIALIZED_SPECS_SIZE
), f"Too large serialized spec size: {len(s)}"
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
mesh = jax.sharding.Mesh(tuple(devices), ("x",))
replicated_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec()
)
len_array = jax.make_array_from_callback(
shape=(),
sharding=replicated_sharding,
data_callback=lambda _: np.array(len(s), dtype=np.int32),
)
data_array = jax.make_array_from_callback(
shape=(_MAX_SERIALIZED_SPECS_SIZE,),
sharding=replicated_sharding,
data_callback=lambda _: np.frombuffer(
s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)),
dtype=np.uint8,
),
)
return len_array, data_array
def _deserialize_specs(
serialized_specs: tuple[jax.Array, ...],
) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
"""Deserializes the specs from the serialized specs.
DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
colocated_python. See serialize() for details.
"""
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
len_array, data_array = serialized_specs
length = int(len_array.addressable_shards[0].data)
data = np.asarray(data_array.addressable_shards[0].data).tobytes()
return _deserialize(data[:length])
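
Output specs thus travel as a pickle round trip packed into a fixed-size uint8
array plus an int32 length scalar. A sketch of just the pickle layer (internal
helpers, assuming cloudpickle is installed):

    import numpy as np
    from jax._src import api, tree_util
    from jax.experimental.colocated_python.serialization import (
        _deserialize, _serialize)

    specs = {"y": api.ShapeDtypeStruct(shape=(4,), dtype=np.float32)}
    leaves, treedef = tree_util.tree_flatten(specs)
    data = _serialize((treedef, tuple(leaves)))  # bytes
    treedef2, leaves2 = _deserialize(data)
    assert treedef2 == treedef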

tests/BUILD

@@ -1345,6 +1345,14 @@ jax_multiplatform_test(
],
)
jax_multiplatform_test(
name = "colocated_python_test",
srcs = ["colocated_python_test.py"],
deps = [
"//jax:experimental_colocated_python",
],
)
jax_multiplatform_test(
name = "experimental_rnn_test",
srcs = ["experimental_rnn_test.py"],

tests/colocated_python_test.py

@@ -0,0 +1,210 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Iterator, Sequence
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax.experimental import colocated_python
from jax.experimental.colocated_python import func as colocated_python_func
import jax.numpy as jnp
import numpy as np
config.parse_flags_with_absl()
def _colocated_cpu_devices(
devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
"""Returns CPU devices colocated with the given devices."""
# TODO(hyeontaek): Use `colocated_python.colocated_cpu_devices(devices)` once
# PjRt-IFRT prepares CPU devices by its own.
cpu_backend_devices = jax.local_devices(backend="cpu")
device_index_map = {device.id: i for i, device in enumerate(jax.devices())}
return [
cpu_backend_devices[device_index_map[device.id]] for device in devices
]
@contextlib.contextmanager
def _count_colocated_python_specialization_cache_miss() -> Iterator[list[int]]:
"""Counts the number of cache misses for colocated_python specialization."""
original_get_specialized_func = colocated_python_func._get_specialized_func
count = [0]
@jax.util.cache(max_size=None)
def get_specialized_func(*args, **kwargs):
count[0] += 1
return original_get_specialized_func(*args, **kwargs)
colocated_python_func._get_specialized_func = get_specialized_func
try:
yield count
finally:
colocated_python_func._get_specialized_func = original_get_specialized_func
_exit_stack = contextlib.ExitStack()
def setUpModule():
# TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT
# prepares CPU devices by its own.
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
def tearDownModule():
_exit_stack.close()
class ColocatedPythonTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_extension_version < 290:
self.skipTest("Requires xla_extension_version >= 290")
def testSimpleFunction(self):
@colocated_python.colocated_python
def add_one(x):
return x + 1
cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
with _count_colocated_python_specialization_cache_miss() as count:
out = add_one(x)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
out = add_one(x)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
def testSimpleFunctionWithTree(self):
@colocated_python.colocated_python
def add_one(x):
return jax.tree.map(lambda x: x + 1, x)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = [np.array(1), (np.array(2), {"v": np.array(3)})]
x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))
with _count_colocated_python_specialization_cache_miss() as count:
out = add_one(x)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 1)
out = add_one(x)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 1)
def testEmptyInputFailsWithoutSpecialization(self):
@colocated_python.colocated_python
def make_zero():
return jnp.array(0)
with self.assertRaisesRegex(
ValueError,
"No devices found. colocated_python function without input arguments"
" must be first specialized with devices.",
):
_ = make_zero()
def testEmptyInputWithDevicesSpecialization(self):
@colocated_python.colocated_python
def make_zero():
return jnp.array(0)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
with _count_colocated_python_specialization_cache_miss() as count:
make_zero = make_zero.specialize(devices=cpu_devices[:1])
out = make_zero()
self.assertEqual(out, np.array(0))
self.assertEqual(count[0], 1)
out = make_zero()
self.assertEqual(out, np.array(0))
self.assertEqual(count[0], 1)
def testInputPolymorphismWithoutOutSpecsFn(self):
@colocated_python.colocated_python
def add_one(x):
return jax.tree.map(lambda x: x + 1, x)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
with _count_colocated_python_specialization_cache_miss() as count:
out = add_one(x)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
out = add_one(x)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
# Different input tree structure and dtype/shape.
x = [np.array(1), (np.array(2), {"v": np.array(3)})]
x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))
out = add_one(x)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
out = add_one(x)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
def testInputPolymorphismAllowedWithOutSpecsFn(self):
@colocated_python.colocated_python
def add_one(x):
return jax.tree.map(lambda x: x + 1, x)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
x = np.array(1)
x = jax.device_put(x, cpu_devices[0])
with _count_colocated_python_specialization_cache_miss() as count:
add_one = add_one.specialize(out_specs_fn=lambda x: x)
out = add_one(x)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
out = add_one(x)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
# Different input tree structure and dtype/shape.
x = [np.array(1), (np.array(2), {"v": jnp.array(3)})]
x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))
out = add_one(x)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
out = add_one(x)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())