mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Add the function API of jax.experimental.colocated_python

This change adds an experimental API `jax.experimental.colocated_python`. The ultimate goal of this API is to provide a runtime-agnostic way to wrap Python code that runs close to (or on) accelerator hosts. Multi-controller JAX can trivially achieve this colocated Python code execution today, while single-controller JAX needs its own solution for distributed Python code execution, which creates fragmentation of user code across the two runtime architectures. `colocated_python` is an attempt to define a single device model and a portable API that lets the user write code once and run it on both runtime architectures.

This change includes an implementation of the function API portion of `jax.experimental.colocated_python`. A (stateful) object API will be added separately. A separate change will also express serialized functions as an IFRT `CustomCallProgram`.

The API is currently in an early development stage. Please proceed with caution when using it.

PiperOrigin-RevId: 690705899
This commit is contained in: parent 9fd1ef2784, commit 77797f434d
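For orientation, a minimal usage sketch distilled from the tests added in this change (see tests/colocated_python_test.py below); it assumes CPU devices colocated with the default backend's devices are available:

import jax
import numpy as np
from jax.experimental import colocated_python

# Wrap a plain Python function; it runs close to (or on) the hosts that
# own the devices of its arguments.
@colocated_python.colocated_python
def add_one(x):
  return x + 1

# Find CPU devices colocated with the accelerator devices, place the
# input there, and call the wrapped function.
cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())
x = jax.device_put(np.array(1), cpu_devices[0])
out = add_one(x)  # The very first call is synchronous: output specs are
                  # discovered by running the function once.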
jax/BUILD (21 lines changed)
@@ -1159,3 +1159,24 @@ pytype_library(
    visibility = ["//visibility:public"],
    deps = [":jax"],
)

pytype_library(
    name = "experimental_colocated_python",
    srcs = [
        "experimental/colocated_python/__init__.py",
        "experimental/colocated_python/api.py",
        "experimental/colocated_python/func.py",
        "experimental/colocated_python/func_backend.py",
        "experimental/colocated_python/serialization.py",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":api_util",
        ":jax",
        ":traceback_util",
        ":tree_util",
        ":util",
        ":xla_bridge",
        "//jax/_src/lib",
    ] + py_deps("numpy") + py_deps("cloudpickle"),
)
jax/experimental/colocated_python/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python API."""

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

# pylint: disable=useless-import-alias
from jax.experimental.colocated_python.api import (
    colocated_cpu_devices as colocated_cpu_devices,
    colocated_python as colocated_python,
)
jax/experimental/colocated_python/api.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python top-level API."""

from __future__ import annotations

import collections
from typing import Any, Callable, Sequence

import jax
from jax._src import api_util
from jax._src.lib import xla_extension_version
from jax.experimental.colocated_python.func import make_callable


def colocated_cpu_devices(
    devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
  """Finds CPU devices colocated with the given devices."""
  if xla_extension_version < 290:
    raise NotImplementedError("Requires xla_extension_version >= 290")

  cpu_devices_by_colocation_id = collections.defaultdict(list)
  for device in devices[0].backend._get_all_devices():  # pylint: disable=protected-access
    if device.device_kind == "cpu":
      cpu_devices_by_colocation_id[device.colocation_id].append(device)
  if not cpu_devices_by_colocation_id:
    raise ValueError("No CPU devices found")

  colocated_cpu_devices = []
  for device in devices:
    matches = cpu_devices_by_colocation_id[device.colocation_id]
    if not matches:
      raise ValueError(f"Device {device} has no colocated devices")
    elif len(matches) > 1:
      raise ValueError(
          f"Ambiguous colocated devices; device {device} has"
          f" {len(matches)} colocated devices: {matches}"
      )
    colocated_cpu_devices.append(matches[0])
  return colocated_cpu_devices


def colocated_python(fun: Callable[..., Any]) -> Any:
  """Executes the given Python function on the same device as the arguments."""
  return make_callable(
      fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun)
  )
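A short sketch of the matching semantics above, assuming a hypothetical topology where each accelerator device shares a colocation_id with exactly one CPU device:

import jax
from jax.experimental import colocated_python

accel_devices = jax.local_devices()  # e.g. TPU or GPU devices
cpu_devices = colocated_python.colocated_cpu_devices(accel_devices)
# cpu_devices[i] shares a host with accel_devices[i]. A ValueError is
# raised if a device has no colocated CPU device, or more than one.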
jax/experimental/colocated_python/func.py (new file, 417 lines)
@@ -0,0 +1,417 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python function API implementation."""

from __future__ import annotations

import dataclasses
import inspect
import random
import threading
from typing import Any, Callable, Sequence

import jax
from jax._src import api
from jax._src import tree_util
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import (
    _deserialize_specs,
    _make_specs_for_serialized_specs,
    _serialize_specs,
)

ShapeDtypeStructTree = Any  # PyTree[api.ShapeDtypeStruct]


@dataclasses.dataclass(frozen=True, slots=True)
class FunctionInfo:
  """User function wrapped by colocated_python."""

  fun: Callable[..., Any]
  fun_sourceinfo: str | None
  fun_signature: inspect.Signature | None


@dataclasses.dataclass(frozen=True, slots=True)
class Specialization:
  """Specialization for a colocated_python function."""

  in_specs_treedef: tree_util.PyTreeDef | None = None
  in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None
  out_specs_treedef: tree_util.PyTreeDef | None = None
  out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  devices: xc.DeviceList | None = None

  def update(
      self,
      *,
      in_specs_treedef: tree_util.PyTreeDef | None = None,
      in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
      out_specs_treedef: tree_util.PyTreeDef | None = None,
      out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      devices: Sequence[jax.Device] | xc.DeviceList | None = None,
  ) -> Any:
    """Creates a new specialization with overrides."""
    if in_specs_treedef is None:
      in_specs_treedef = self.in_specs_treedef
    elif self.in_specs_treedef is not None:
      raise ValueError("in_specs already specified")
    if in_specs_leaves is None:
      in_specs_leaves = self.in_specs_leaves
    elif self.in_specs_leaves is not None:
      raise ValueError("in_specs already specified")

    if out_specs_fn is None:
      out_specs_fn = self.out_specs_fn
    elif self.out_specs_fn is not None:
      raise ValueError("out_specs_fn already specified")

    if out_specs_treedef is None:
      out_specs_treedef = self.out_specs_treedef
    elif self.out_specs_treedef is not None:
      raise ValueError("out_specs already specified")
    if out_specs_leaves is None:
      out_specs_leaves = self.out_specs_leaves
    elif self.out_specs_leaves is not None:
      raise ValueError("out_specs already specified")

    if devices is None:
      devices = self.devices
    elif self.devices is not None:
      raise ValueError("devices already specified")
    elif not isinstance(devices, xc.DeviceList):
      devices = xc.DeviceList(tuple(devices))

    return Specialization(
        in_specs_treedef,
        in_specs_leaves,
        out_specs_fn,
        out_specs_treedef,
        out_specs_leaves,
        devices,
    )


def _get_spec(x: Any) -> api.ShapeDtypeStruct:
  """Extracts a spec for a value, which must be a JAX Array."""
  # TODO(hyeontaek): Allow Python values and automatically apply `shard_arg`
  # with a suitable sharding and layout.
  if not isinstance(x, jax.Array):
    raise ValueError(
        "colocated_python only supports jax.Array as input and output, but got"
        f" {type(x)}."
    )
  return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)


def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
  """Returns a representative device list from function call arguments."""
  device_list_set: set[xc.DeviceList] = set()
  for x in args:
    sharding = getattr(x, "sharding", None)
    if sharding is not None:
      device_list_set.add(x.sharding._internal_device_list)
  if not device_list_set:
    return None
  if len(device_list_set) != 1:
    raise ValueError(
        "All arguments must use the same device list, but got"
        f" multiple device lists: {device_list_set}."
    )
  return device_list_set.pop()


def _compile_to_executable(
    name: str,
    fun: Callable[..., Any],
    in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    devices: xc.DeviceList,
) -> Callable[..., Any]:
  """Compiles a Python function into a runtime executable."""
  # TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an
  # executable.
  del name
  del in_specs_leaves
  del out_specs_leaves
  del devices
  return fun


def _make_output_specs_and_push_result_fun(
    info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
  """Creates a function that computes output specs and pushes the result to the result store."""
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.out_specs_treedef is None
  assert specialization.out_specs_leaves is None
  assert specialization.devices is not None

  devices = specialization.devices

  def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
    result = info.fun(*args, **kwargs)
    out_leaves, out_treedef = tree_util.tree_flatten(result)
    out_spec_leaves = tuple(_get_spec(x) for x in out_leaves)
    func_backend.SINGLETON_RESULT_STORE.push(uid, out_leaves)
    return _serialize_specs(out_treedef, out_spec_leaves, devices)

  out_specs_leaves, _ = tree_util.tree_flatten(
      _make_specs_for_serialized_specs(specialization.devices),
  )
  name = getattr(info.fun, "__name__", "unknown")
  name = f"{name}_output_specs_and_push_result"
  return _compile_to_executable(
      name=name,
      fun=lowered_fun,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_leaves=tuple(out_specs_leaves),
      devices=specialization.devices,
  )


def _make_pop_result_fun(
    info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
  """Makes a function that pops results from the result store."""
  assert specialization.out_specs_treedef is not None
  assert specialization.out_specs_leaves is not None
  assert specialization.devices is not None

  out_specs_treedef = specialization.out_specs_treedef

  def lowered_fun() -> Any:
    flat_result = func_backend.SINGLETON_RESULT_STORE.pop(uid)
    return tree_util.tree_unflatten(out_specs_treedef, flat_result)

  in_specs, _ = tree_util.tree_flatten((
      # args
      (),
      # kwargs
      (),
  ))
  name = getattr(info.fun, "__name__", "unknown")
  name = f"{name}_pop_result"
  return _compile_to_executable(
      name=name,
      fun=lowered_fun,
      in_specs_leaves=tuple(in_specs),
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )


def _make_async_execution_fun(
    info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
  """Makes a function that asynchronously executes the function."""
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.out_specs_treedef is not None
  assert specialization.out_specs_leaves is not None
  assert specialization.devices is not None

  name = getattr(info.fun, "__name__", "unknown")
  return _compile_to_executable(
      name=name,
      fun=info.fun,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )


@jax.util.cache(max_size=None)
def _get_specialized_func(
    info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
  """Returns a specialized function for the given specialization."""
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.devices is not None
  uid = random.getrandbits(63)

  mutex = threading.Lock()
  # Asynchronous execution function that has known output_specs.
  async_execution_func = None

  def specialized_func(*args, **kwargs) -> Any:
    """Specialized function to be executed with given args and kwargs."""
    nonlocal specialization, async_execution_func
    with mutex:
      if async_execution_func is None:
        if specialization.out_specs_treedef is None:
          if specialization.out_specs_fn is None:
            serialized_out_specs = _make_output_specs_and_push_result_fun(
                info, specialization, uid
            )(*args, **kwargs)

            # Waits for the output_specs. This may block.
            out_specs_treedef, out_specs_leaves = _deserialize_specs(
                serialized_out_specs
            )

            # Subsequent calls would use async_execution_func with discovered
            # output_specs.
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=out_specs_leaves,
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )

            return _make_pop_result_fun(info, specialization, uid)()
          else:
            # Compute out_specs using out_specs_fn and inputs.
            out_specs = specialization.out_specs_fn(*args, **kwargs)
            # Type checking is ignored to silence mypy error: Incompatible types
            # in assignment (expression has type "list[Any]", variable has type
            # "tuple[ShapeDtypeStruct, ...]") [assignment]
            out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(  # type: ignore[assignment]
                out_specs
            )
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=tuple(out_specs_leaves),
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )
            # Fall-through.
        else:
          async_execution_func = _make_async_execution_fun(info, specialization)
          # Fall-through.

      return async_execution_func(*args, **kwargs)

  return specialized_func


def make_callable(
    fun: Callable[..., Any],
    fun_sourceinfo: str | None,
    fun_signature: inspect.Signature | None,
) -> Callable[..., Any]:
  """Makes a colocated Python callable."""
  return _make_callable(
      FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
  )


def _make_callable(
    info: FunctionInfo,
    specialization: Specialization,
) -> Callable[..., Any]:
  """Internal implementation of make_callable."""

  def specialize(
      in_specs: ShapeDtypeStructTree | None = None,
      out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
      devices: Sequence[jax.Device] | None = None,
  ) -> Callable[..., Any]:
    """Returns a colocated Python callable with extra specialization.

    Args:
      in_specs: Optionally specifies the expected input specs. Input specs are
        expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a
        function call.
      out_specs_fn: Optionally specifies a function that computes the output
        specs from input specs. If unspecified, colocated_python will compute
        the output specs during the very first execution, and this execution
        will be synchronous.
      devices: Optionally specifies the devices to execute the function on.
        Must be provided if in_specs has no leaves because devices cannot be
        inferred from input specs or arguments.

    Returns:
      A colocated Python callable with extra specialization.
    """
    # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
    # `out_specs_fn(in_specs)` returns at least one leaf that we can use for
    # inferring `devices`.
    if in_specs is None:
      in_specs_leaves, in_specs_treedef = None, None
    else:
      in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs)
      in_specs_leaves = tuple(in_specs_leaves_list)
    return _make_callable(
        info,
        specialization.update(
            in_specs_treedef=in_specs_treedef,
            in_specs_leaves=in_specs_leaves,
            out_specs_fn=out_specs_fn,
            devices=devices,
        ),
    )

  @api_boundary
  def __call__(*args, **kwargs) -> Any:
    """Executes the function.

    If the output specs are not known, the very first execution will be
    synchronous.
    """
    args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))

    in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)
    if specialization.in_specs_treedef is None:
      # Allow input polymorphism by applying input_specs specialization
      # temporarily for this call.
      return _make_callable(
          info,
          specialization.update(
              in_specs_treedef=in_specs_treedef,
              in_specs_leaves=in_specs_leaves,
          ),
      )(*args, **kwargs)

    if specialization.devices is None:
      devices = _infer_devices_from_args(args_leaves)
      if devices is None:
        raise ValueError(
            "No devices found. colocated_python function without input"
            " arguments must be first specialized with devices."
        )
      # Allow device polymorphism by applying devices specialization
      # temporarily for this call.
      return _make_callable(info, specialization.update(devices=devices))(
          *args, **kwargs
      )

    # Assertion is added to silence mypy error: Unsupported operand types for !=
    # ("PyTreeDef" and "None") [operator]
    assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef)

    # If input_specs is known, verify that it matches actual inputs.
    if (specialization.in_specs_treedef != in_specs_treedef
        or specialization.in_specs_leaves != in_specs_leaves):
      raise ValueError(
          "Input specs in specialization and input specs of arguments must have"
          " the same pytree structure, but they have the following structural"
          " differences:\n"
          + ("\n".join(
              f"   - {tree_util.keystr(path)} is a {thing1} in value 1 and"
              f" a {thing2} in value 2, so {explanation}.\n"
              for path, thing1, thing2, explanation
              in tree_util.equality_errors_pytreedef(
                  specialization.in_specs_treedef, in_specs_treedef
              ))))

    return _get_specialized_func(info, specialization)(*args, **kwargs)

  __call__ = wraps(info.fun)(__call__)
  __call__.specialize = specialize
  return __call__
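The `specialize` method attached to each colocated Python callable is the user-facing entry point to the three specialization axes above (`in_specs`, `out_specs_fn`, `devices`). A short sketch distilled from the tests in this change; it assumes colocated CPU devices are available (the tests below use a local `_colocated_cpu_devices` workaround instead):

import jax
import jax.numpy as jnp
from jax.experimental import colocated_python

@colocated_python.colocated_python
def make_zero():
  return jnp.array(0)

cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())

# A function with no array arguments cannot infer devices from its inputs,
# so it must be specialized with devices before the first call.
make_zero = make_zero.specialize(devices=cpu_devices[:1])
out = make_zero()

@colocated_python.colocated_python
def add_one(x):
  return jax.tree.map(lambda v: v + 1, x)

# Supplying out_specs_fn avoids the synchronous first call: output specs
# are computed from the input specs instead of discovered by running once.
add_one = add_one.specialize(out_specs_fn=lambda x: x)
out = add_one(jax.device_put(jnp.array(1), cpu_devices[0]))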
jax/experimental/colocated_python/func_backend.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend for colocated_python.func."""

from __future__ import annotations

import threading
from typing import Sequence

import jax


class _ResultStore:
  """Temporarily stores results from synchronous execution of functions."""

  def __init__(self) -> None:
    self._lock = threading.Lock()
    self._storage: dict[int, Sequence[jax.Array]] = {}

  def push(self, uid: int, out: Sequence[jax.Array]) -> None:
    with self._lock:
      if uid in self._storage:
        raise ValueError(f"uid {uid} already exists")
      self._storage[uid] = out

  def pop(self, uid: int) -> Sequence[jax.Array]:
    with self._lock:
      if uid not in self._storage:
        raise ValueError(f"uid {uid} does not exist")
      return self._storage.pop(uid)


SINGLETON_RESULT_STORE = _ResultStore()
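A minimal sketch of the store's push/pop pairing, with an illustrative uid (func.py draws real ones with random.getrandbits(63)):

import jax.numpy as jnp
from jax.experimental.colocated_python import func_backend

uid = 12345  # Illustrative only.
# The "output specs" function pushes the actual result under a uid...
func_backend.SINGLETON_RESULT_STORE.push(uid, [jnp.array(1)])
# ...and the "pop result" function later retrieves and removes it.
result = func_backend.SINGLETON_RESULT_STORE.pop(uid)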
jax/experimental/colocated_python/serialization.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python serialization utilities."""

# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.

from __future__ import annotations

import collections
import io
from typing import Any, Callable, Sequence

try:
  import cloudpickle  # type: ignore[import-not-found]
except ImportError:
  cloudpickle = None

import jax
from jax._src import api
from jax._src import tree_util
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
import numpy as np

DeviceList = xc.DeviceList

# Hard-coded limit for serialized specs size.
# TODO(jmudigonda): Use a string-typed array for output structure when it
# becomes available. Using a fixed uint8 array is only for prototyping.
_MAX_SERIALIZED_SPECS_SIZE = 1048576


@jax.util.cache(max_size=None)
def _get_cpu_device_map() -> dict[int, jax.Device]:
  """Returns a map from a device id to a matching device."""
  cpu_device_map: dict[int, jax.Device] = {}
  # TODO(hyeontaek): We should look up CPU devices for a specific CPU backend.
  # When deserializing a device on the controller, the backend should be the
  # one associated with colocated_python. When deserializing on the
  # colocated_python executor, it should be the CPU backend visible to the
  # user function running under colocated_python.
  for backend in xb.backends().values():
    for d in backend._get_all_devices():  # pylint: disable=protected-access
      if d.device_kind == "cpu":
        if d.id in cpu_device_map:
          raise ValueError(
              f"Multiple CPU devices with id {d.id} found:"
              f" {cpu_device_map[d.id]} and {d}"
          )
        cpu_device_map[d.id] = d
  return cpu_device_map


def _reduce_mesh(
    mesh: jax.sharding.Mesh,
) -> tuple[Callable[..., jax.sharding.Mesh], Any]:
  def make_mesh(
      mesh_device_ids: np.ndarray, axis_names: Any
  ) -> jax.sharding.Mesh:
    cpu_device_map = _get_cpu_device_map()
    mesh_devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
        mesh_device_ids
    )
    return jax.sharding.Mesh(mesh_devices, axis_names)

  mesh_device_ids = np.vectorize(lambda d: d.id, otypes=[int])(mesh.devices)
  return make_mesh, (mesh_device_ids, mesh.axis_names)


def _reduce_device_list(
    device_list: DeviceList,
) -> tuple[Callable[..., DeviceList], Any]:
  def make_device_list(device_ids: Sequence[int]) -> DeviceList:
    cpu_device_map = _get_cpu_device_map()
    devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
        device_ids
    )
    return DeviceList(devices)

  device_ids = [d.id for d in device_list]
  return make_device_list, (device_ids,)


def _reduce_single_device_sharding(
    sharding: jax.sharding.SingleDeviceSharding,
) -> tuple[Callable[..., jax.sharding.SingleDeviceSharding], Any]:

  def make_single_device_sharding(device_id: int):
    cpu_device_map = _get_cpu_device_map()
    return jax.sharding.SingleDeviceSharding(cpu_device_map[device_id])

  return make_single_device_sharding, (sharding.device_set.pop().id,)


def _serialize(obj: Any) -> bytes:
  """Serializes callables and input/output spec objects.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python.

  This module contains utility functions used internally for implementing
  `colocated_python` when it ships callables and input/output specs through
  IFRT. The pickled data is produced and consumed in an ephemeral fashion
  without any persistence, and it does not expect any version compatibility
  (which cloudpickle does not guarantee). Furthermore, serialization and
  deserialization is expected to be done on machine(s) that are controlled by
  a single tenant, which allows unpickling done during deserialization to be
  trusted.

  Raises:
    ModuleNotFoundError: If cloudpickle is not available.
  """
  if cloudpickle is None:
    raise ModuleNotFoundError('No module named "cloudpickle"')

  class _CustomPickler(cloudpickle.Pickler):
    dispatch_table = collections.ChainMap(
        {jax.sharding.Mesh: _reduce_mesh},
        {DeviceList: _reduce_device_list},
        {jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding},
        cloudpickle.CloudPickler.dispatch_table,  # pylint: disable=attribute-error
    )
    dispatch = dispatch_table

  with io.BytesIO() as file:
    _CustomPickler(file).dump(obj)
    return file.getvalue()


def _deserialize(serialized: bytes) -> Any:
  """Deserializes callables and input/output spec objects.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python. See _serialize() for details.

  Raises:
    ModuleNotFoundError: If cloudpickle is not available.
  """
  if cloudpickle is None:
    raise ModuleNotFoundError('No module named "cloudpickle"')

  return cloudpickle.loads(serialized)


def _make_specs_for_serialized_specs(
    devices: DeviceList,
) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]:
  """Makes output specs for serialized specs."""
  # TODO(jmudigonda): Use a string-typed array for output structure when it
  # becomes available. Using a fixed uint8 array is only for prototyping.
  mesh = jax.sharding.Mesh(tuple(devices), ("x",))
  replicated_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec()
  )
  return (
      api.ShapeDtypeStruct(
          shape=(), dtype=np.int32, sharding=replicated_sharding
      ),
      api.ShapeDtypeStruct(
          shape=(_MAX_SERIALIZED_SPECS_SIZE,),
          dtype=np.uint8,
          sharding=replicated_sharding,
      ),
  )


def _serialize_specs(
    specs_treedef: tree_util.PyTreeDef,
    specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    devices: DeviceList,
) -> tuple[jax.Array, ...]:
  """Serializes the output specs into a tuple of arrays.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python. See _serialize() for details.
  """
  s = _serialize((specs_treedef, specs_leaves))
  assert (
      len(s) <= _MAX_SERIALIZED_SPECS_SIZE
  ), f"Too large serialized spec size: {len(s)}"
  # TODO(jmudigonda): Use a string-typed array for output structure when it
  # becomes available. Using a fixed uint8 array is only for prototyping.
  mesh = jax.sharding.Mesh(tuple(devices), ("x",))
  replicated_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec()
  )
  len_array = jax.make_array_from_callback(
      shape=(),
      sharding=replicated_sharding,
      data_callback=lambda _: np.array(len(s), dtype=np.int32),
  )
  data_array = jax.make_array_from_callback(
      shape=(_MAX_SERIALIZED_SPECS_SIZE,),
      sharding=replicated_sharding,
      data_callback=lambda _: np.frombuffer(
          s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)),
          dtype=np.uint8,
      ),
  )
  return len_array, data_array


def _deserialize_specs(
    serialized_specs: tuple[jax.Array, ...],
) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
  """Deserializes the specs from the serialized specs.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python. See _serialize() for details.
  """
  # TODO(jmudigonda): Use a string-typed array for output structure when it
  # becomes available. Using a fixed uint8 array is only for prototyping.
  len_array, data_array = serialized_specs
  length = int(len_array.addressable_shards[0].data)
  data = np.asarray(data_array.addressable_shards[0].data).tobytes()
  return _deserialize(data[:length])
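A round-trip sketch of the two-array encoding above (a replicated scalar int32 length plus a fixed-size uint8 payload). These are private helpers, so this is only to illustrate the data flow; it assumes cloudpickle is installed and a CPU backend is available:

import jax
import numpy as np
from jax._src import tree_util
from jax._src.lib import xla_client as xc
from jax.experimental.colocated_python import serialization

cpu_device = jax.local_devices(backend="cpu")[0]
devices = xc.DeviceList((cpu_device,))
sharding = jax.sharding.SingleDeviceSharding(cpu_device)
spec = jax.ShapeDtypeStruct(shape=(4,), dtype=np.float32, sharding=sharding)
specs_leaves, specs_treedef = tree_util.tree_flatten(spec)

# Encode the specs into (length, payload) arrays...
len_array, data_array = serialization._serialize_specs(
    specs_treedef, tuple(specs_leaves), devices
)
# ...and decode them back into (treedef, leaves).
treedef, leaves = serialization._deserialize_specs((len_array, data_array))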
tests/BUILD
@@ -1345,6 +1345,14 @@ jax_multiplatform_test(
    ],
)

jax_multiplatform_test(
    name = "colocated_python_test",
    srcs = ["colocated_python_test.py"],
    deps = [
        "//jax:experimental_colocated_python",
    ],
)

jax_multiplatform_test(
    name = "experimental_rnn_test",
    srcs = ["experimental_rnn_test.py"],
tests/colocated_python_test.py (new file, 210 lines)
@@ -0,0 +1,210 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import Sequence

from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
from jax.experimental import colocated_python
from jax.experimental.colocated_python import func as colocated_python_func
import jax.numpy as jnp
import numpy as np

config.parse_flags_with_absl()


def _colocated_cpu_devices(
    devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
  """Returns CPU devices colocated with the given devices."""
  # TODO(hyeontaek): Use `colocated_python.colocated_cpu_devices(devices)` once
  # PjRt-IFRT prepares CPU devices by its own.
  cpu_backend_devices = jax.local_devices(backend="cpu")
  device_index_map = {device.id: i for i, device in enumerate(jax.devices())}
  return [
      cpu_backend_devices[device_index_map[device.id]] for device in devices
  ]


@contextlib.contextmanager
def _count_colocated_python_specialization_cache_miss() -> list[int]:
  """Counts the number of cache misses for colocated_python specialization."""
  original_get_specialized_func = colocated_python_func._get_specialized_func
  count = [0]

  @jax.util.cache(max_size=None)
  def get_specialized_func(*args, **kwargs):
    count[0] += 1
    return original_get_specialized_func(*args, **kwargs)

  colocated_python_func._get_specialized_func = get_specialized_func
  try:
    yield count
  finally:
    colocated_python_func._get_specialized_func = original_get_specialized_func


_exit_stack = contextlib.ExitStack()


def setUpModule():
  # TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT
  # prepares CPU devices by its own.
  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))


def tearDownModule():
  _exit_stack.close()


class ColocatedPythonTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if xla_extension_version < 290:
      self.skipTest("Requires xla_extension_version >= 290")

  def testSimpleFunction(self):
    @colocated_python.colocated_python
    def add_one(x):
      return x + 1

    cpu_devices = _colocated_cpu_devices(jax.local_devices())
    x = np.array(1)
    x = jax.device_put(x, cpu_devices[0])

    with _count_colocated_python_specialization_cache_miss() as count:
      out = add_one(x)
      self.assertEqual(out, np.array(2))
      self.assertEqual(count[0], 1)

      out = add_one(x)
      self.assertEqual(out, np.array(2))
      self.assertEqual(count[0], 1)

  def testSimpleFunctionWithTree(self):
    @colocated_python.colocated_python
    def add_one(x):
      return jax.tree.map(lambda x: x + 1, x)

    cpu_devices = _colocated_cpu_devices(jax.local_devices())
    x = [np.array(1), (np.array(2), {"v": np.array(3)})]
    x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))

    with _count_colocated_python_specialization_cache_miss() as count:
      out = add_one(x)
      self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
      self.assertEqual(count[0], 1)

      out = add_one(x)
      self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
      self.assertEqual(count[0], 1)

  def testEmptyInputFailsWithoutSpecialization(self):
    @colocated_python.colocated_python
    def make_zero():
      return jnp.array(0)

    with self.assertRaisesRegex(
        ValueError,
        "No devices found. colocated_python function without input arguments"
        " must be first specialized with devices.",
    ):
      _ = make_zero()

  def testEmptyInputWithDevicesSpecialization(self):
    @colocated_python.colocated_python
    def make_zero():
      return jnp.array(0)

    cpu_devices = _colocated_cpu_devices(jax.local_devices())

    with _count_colocated_python_specialization_cache_miss() as count:
      make_zero = make_zero.specialize(devices=cpu_devices[:1])
      out = make_zero()
      self.assertEqual(out, np.array(0))
      self.assertEqual(count[0], 1)

      out = make_zero()
      self.assertEqual(out, np.array(0))
      self.assertEqual(count[0], 1)

  def testInputPolymorphismWithoutOutSpecsFn(self):
    @colocated_python.colocated_python
    def add_one(x):
      return jax.tree.map(lambda x: x + 1, x)

    cpu_devices = _colocated_cpu_devices(jax.local_devices())
    x = np.array(1)
    x = jax.device_put(x, cpu_devices[0])

    with _count_colocated_python_specialization_cache_miss() as count:
      out = add_one(x)
      self.assertEqual(out, np.array(2))
      self.assertEqual(count[0], 1)

      out = add_one(x)
      self.assertEqual(out, np.array(2))
      self.assertEqual(count[0], 1)

      # Different input tree structure and dtype/shape.
      x = [np.array(1), (np.array(2), {"v": np.array(3)})]
      x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))

      out = add_one(x)
      self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
      self.assertEqual(count[0], 2)

      out = add_one(x)
      self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
      self.assertEqual(count[0], 2)

  def testInputPolymorphismAllowedWithOutSpecsFn(self):
    @colocated_python.colocated_python
    def add_one(x):
      return jax.tree.map(lambda x: x + 1, x)

    cpu_devices = _colocated_cpu_devices(jax.local_devices())
    x = np.array(1)
    x = jax.device_put(x, cpu_devices[0])

    with _count_colocated_python_specialization_cache_miss() as count:
      add_one = add_one.specialize(out_specs_fn=lambda x: x)
      out = add_one(x)
      self.assertEqual(out, np.array(2))
      self.assertEqual(count[0], 1)

      out = add_one(x)
      self.assertEqual(out, np.array(2))
      self.assertEqual(count[0], 1)

      # Different input tree structure and dtype/shape.
      x = [np.array(1), (np.array(2), {"v": jnp.array(3)})]
      x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))

      out = add_one(x)
      self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
      self.assertEqual(count[0], 2)

      out = add_one(x)
      self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
      self.assertEqual(count[0], 2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
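Since the test module ends with absltest.main(testLoader=jtu.JaxTestLoader()) and setUpModule provisions eight host-platform devices, it can presumably be run directly on a CPU-only machine:

python tests/colocated_python_test.py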