[JAX] Introduce DeviceList backed by C++ xla::ifrt::DeviceList

This change adds `xla_client.DeviceList`, implemented by the C++ class
`jax::PyDeviceList`. `jax::PyDeviceList` implements the features of
`pxla._DeviceAssignment` as a functional drop-in replacement. Internally,
`jax::PyDeviceList` holds an `xla::ifrt::DeviceList`, so IFRT APIs can use it
directly without constructing a new copy of a potentially large device list.
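
For illustration, the new object can be used like the tuple-based assignment
it replaces (a minimal sketch exercising only the behavior this change relies
on; construction from a device tuple mirrors `_create_da_object` in the diff):

  from jax._src.lib import xla_client as xc
  import jax

  devices = tuple(jax.devices())
  dl = xc.DeviceList(devices)  # wraps an xla::ifrt::DeviceList internally
  assert len(dl) == len(devices)
  assert dl[0] == devices[0]
  assert tuple(dl) == devices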

`pxla._DeviceAssignment`'s interface is changed slightly to encourage avoiding
conversion to a tuple, as sketched below.
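
A hedged sketch of the intended usage, with names as in the diff below:

  # Preferred: rely on the sequence protocol instead of materializing a tuple.
  num_devices = len(da_object)
  first_device = da_object[0]
  addressable = da_object.addressable_device_list
  # Only at the remaining boundaries that still require a tuple:
  device_assignment = tuple(da_object)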

Note that for backward compatibility (and fast `xla_client.Device`
conversion), `jax::PyDeviceList` still keeps a Python tuple whose elements may
be any Python objects that match the `xla_client.Device` interface via duck
typing. This duck typing support will be removed once such use cases are
deprecated. Eventually, we may be able to avoid type conversion entirely and
remove the shadow copy of the device list in JAX.
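
As an illustration of the duck typing, any object exposing what the fallback
implementation below touches (a `process_index` attribute and a `client` with
a `process_index()` method) can appear in the tuple; `FakeDevice` and
`FakeClient` here are hypothetical:

  class FakeClient:
    def process_index(self) -> int:
      return 0

  class FakeDevice:  # duck-types the xla_client.Device interface used here
    def __init__(self, process_index: int, client: FakeClient):
      self.process_index = process_index
      self.client = client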

PiperOrigin-RevId: 555317152

@@ -24,7 +24,7 @@ from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
-from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
+from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar)
 import warnings

 import numpy as np
@@ -58,6 +58,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -1789,7 +1790,9 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   out_shardings = semantic_out_shardings.shardings
   global_in_avals = closed_jaxpr.in_avals
   global_out_avals = closed_jaxpr.out_avals
-  device_assignment = da_object.device_assignment
+  # TODO(yashkatariya): Make device_assignment directly usable in the downstream
+  # code without tuple conversion.
+  device_assignment = tuple(da_object)

   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
   if logger.isEnabledFor(log_priority):
@@ -1862,37 +1865,52 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
           nreps, tuple_args, lowering_result.shape_poly_state)


-@dataclasses.dataclass(frozen=True)
-class _DeviceAssignment:
-  device_assignment: tuple[xc.Device, ...]
+if xla_extension_version >= 181:
+  _DeviceAssignment = xc.DeviceList
+else:
+  @dataclasses.dataclass(frozen=True)
+  class _DeviceAssignment:  # type: ignore
+    _device_assignment: tuple[xc.Device, ...]

-  @cached_property
-  def _hash(self):
-    return hash(self.device_assignment)
+    @cached_property
+    def _hash(self) -> int:
+      return hash(self._device_assignment)

-  def __hash__(self):
-    return self._hash
+    def __hash__(self) -> int:
+      return self._hash

-  def __eq__(self, other):
-    if not isinstance(other, _DeviceAssignment):
-      return False
-    if id(self) == id(other):
-      return True
-    return (self.device_assignment == other.device_assignment)
+    def __eq__(self, other: Any) -> bool:
+      if not isinstance(other, _DeviceAssignment):
+        return False
+      if id(self) == id(other):
+        return True
+      return (self._device_assignment == other._device_assignment)

-  @cached_property
-  def is_fully_addressable(self):
-    return len(self.device_assignment) == len(self.addressable_device_assignment)
+    def __len__(self) -> int:
+      return len(self._device_assignment)

-  @cached_property
-  def addressable_device_assignment(self):
-    return [d for d in self.device_assignment
-            if d.process_index == d.client.process_index()]
+    def __getitem__(self, index: Any) -> Any:
+      return self._device_assignment[index]
+
+    def __iter__(self) -> Iterator[xc.Device]:
+      return iter(self._device_assignment)
+
+    @cached_property
+    def is_fully_addressable(self) -> bool:
+      return len(self._device_assignment) == len(
+          self.addressable_device_list._device_assignment
+      )
+
+    @cached_property
+    def addressable_device_list(self) -> _DeviceAssignment:  # type: ignore
+      return _create_da_object(
+          tuple(d for d in self._device_assignment
+                if d.process_index == d.client.process_index()))


 @lru_cache(maxsize=2048)
-def _create_da_object(
-    device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:
+def _create_da_object(  # pytype: disable=invalid-annotation
+    device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:  # type: ignore
   return _DeviceAssignment(device_assignment)
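
The `@lru_cache` wrapper is what makes the `id(self) == id(other)` fast path
in `__eq__` effective; a minimal sketch, assuming two call sites pass equal
device tuples:

  da1 = _create_da_object(tuple(jax.devices()))
  da2 = _create_da_object(tuple(jax.devices()))
  assert da1 is da2  # equal tuple keys hit the cache, returning one object
  assert da1 == da2  # __eq__ short-circuits on the identity check
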
@@ -1961,7 +1979,7 @@ def lower_sharding_computation(
   da_object = _create_da_object(tuple(device_assignment))

-  if not da_object.is_fully_addressable:
+  if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.jax_spmd_mode != 'allow_all':
       raise RuntimeError(
           "Running operations on `Array`s that are not fully addressable by this "
@@ -2262,15 +2280,13 @@ def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
 def _get_input_indices(
     avals: Sequence[ShapedArray],
     shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    da_object: _DeviceAssignment | Sequence[xc.Device],
+    da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
 ) -> Sequence[tuple[Index | None, ...]]:

   input_indices = []
-  if isinstance(da_object, _DeviceAssignment):
-    num_addressable_devices = len(da_object.addressable_device_assignment)
-  else:
-    num_addressable_devices = len(
-        [d for d in da_object if d.process_index == d.client.process_index()])
+  if not isinstance(da_object, _DeviceAssignment):
+    da_object = _create_da_object(tuple(da_object))
+  num_addressable_devices = len(da_object.addressable_device_list)

   for aval, sharding in zip(avals, shardings):
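
The rewritten body is behavior-preserving: normalizing a plain device sequence
to a `_DeviceAssignment` first lets both input forms share the cached
`addressable_device_list`. A minimal sketch under the fallback implementation
above:

  devices = tuple(jax.devices())
  manual = len([d for d in devices
                if d.process_index == d.client.process_index()])
  da = _create_da_object(devices)
  assert len(da.addressable_device_list) == manual
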
@@ -2426,16 +2442,13 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
                         _allow_propagation_to_outputs, host_callbacks, backend,
                         da, pmap_nreps, compiler_options_keys,
                         compiler_options_values):
-  device_assignment = da.device_assignment if isinstance(
-      da, _DeviceAssignment) else da
-
   # TODO(phawkins): One would normally just write:
   # dev = np.array(device_assignment)
   # The formulation below is substantially faster if there are many devices.
   # If we were to optimize __getattr__ on xc.Device we might not need this
   # workaround.
-  dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
-      np.arange(len(device_assignment))
+  dev = np.vectorize(lambda i: da[i], otypes=[object])(
+      np.arange(len(da))
   )
   if pmap_nreps > 1:
     num_replicas, num_partitions = pmap_nreps, 1
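
The retained TODO(phawkins) comment attributes the slowness of
`np.array(device_assignment)` to `__getattr__` on `xc.Device`; `np.vectorize`
with `otypes=[object]` sidesteps it by filling the object array through plain
indexing. A self-contained sketch with a hypothetical stand-in for
`xc.Device`:

  import numpy as np

  class Slot:  # hypothetical stand-in for xc.Device
    pass

  da = [Slot() for _ in range(8)]
  dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da)))
  assert dev.dtype == object and dev.shape == (8,)
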
@@ -2493,7 +2506,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
 @dataclasses.dataclass
 class UnloadedMeshExecutable:
   xla_executable: Any
-  device_assignment: _DeviceAssignment | Sequence[xc.Device]
+  device_assignment: _DeviceAssignment | Sequence[xc.Device]  # type: ignore
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
   input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
@@ -2550,7 +2563,7 @@ class UnloadedMeshExecutable:
                keepalive: Any,
                kept_var_idx: set[int],
                backend: xb.XlaBackend,
-               device_assignment: _DeviceAssignment | Sequence[xc.Device],
+               device_assignment: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
                committed: bool,
                pmap_nreps: int = 1,
                jaxpr_debug_info: core.JaxprDebugInfo | None = None,
@@ -2563,8 +2576,10 @@ class UnloadedMeshExecutable:
         compiler_options.keys()) if compiler_options is not None else None
     compiler_options_values = tuple(
         compiler_options.values()) if compiler_options is not None else None
-    da = device_assignment if isinstance(
-        device_assignment, _DeviceAssignment) else tuple(device_assignment)
+    if isinstance(device_assignment, _DeviceAssignment):
+      da = device_assignment
+    else:
+      da = _create_da_object(tuple(device_assignment))
     del device_assignment

     allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)
@@ -2605,8 +2620,9 @@ class UnloadedMeshExecutable:
     elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
           and pmap_nreps == 1):
       assert mesh is None
-      device_assignment = da.device_assignment if isinstance(  # type: ignore
-          da, _DeviceAssignment) else da
+      # TODO(yashkatariya): Make da directly usable in the downstream code
+      # without tuple conversion.
+      device_assignment = tuple(da)
       out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
           xla_executable, device_assignment, len(global_out_avals))  # type: ignore
       orig_out_shardings = out_shardings
@@ -2707,9 +2723,9 @@ class MeshExecutable(stages.XlaExecutable):
         backend, da_object, committed, kept_var_idx, 1)

     out_shardings = _out_shardings_for_trivial(
-        jaxpr, consts, in_shardings, da_object.device_assignment)
+        jaxpr, consts, in_shardings, da_object)
     indices = _get_input_indices(global_out_avals, out_shardings, da_object)
-    local_device_assignment = da_object.addressable_device_assignment
+    local_device_assignment = da_object.addressable_device_list
     handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
     handle_outs = global_avals_to_results_handler(
         global_out_avals, out_shardings, committed,
@@ -2889,7 +2905,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
     jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
     da_object, committed, kept_var_idx, pmap_nreps):
   out_shardings = _out_shardings_for_trivial(
-      jaxpr, consts, in_shardings, da_object.device_assignment)
+      jaxpr, consts, in_shardings, da_object)

   input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
   handle_outs = global_avals_to_results_handler(
@@ -2898,7 +2914,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
   # Use the standard out_handler.
   unsafe_call = backend.compile_replicated(
       is_trivial=True, jaxpr=jaxpr, consts=consts,
-      device_assignment=da_object.device_assignment, in_avals=global_in_avals,
+      device_assignment=da_object, in_avals=global_in_avals,
       in_indices=input_indices, in_shardings=in_shardings,
       kept_var_idx=kept_var_idx, out_handler=handle_outs,
       out_shardings=out_shardings, pmap_nreps=pmap_nreps)