[JAX] Introduce DeviceList backed by C++ xla::ifrt::DeviceList

This change adds `xla_client.DeviceList`, implemented by the C++ class
`jax::PyDeviceList`. `jax::PyDeviceList` implements the features of
`pxla._DeviceAssignment` as a functional drop-in replacement. Internally,
`jax::PyDeviceList` holds an `xla::ifrt::DeviceList`, which can be passed to
IFRT APIs without constructing a new copy of a potentially large device list.
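
As a rough sketch (assuming a jaxlib new enough that
`xla_extension_version >= 181`), a `DeviceList` is constructed once from a
device tuple and then reused:

  # Sketch only; xc.DeviceList requires xla_extension_version >= 181.
  import jax
  from jax._src.lib import xla_client as xc

  # The C++ side keeps an xla::ifrt::DeviceList alongside the Python tuple,
  # so later IFRT calls do not need to re-copy the devices.
  device_list = xc.DeviceList(tuple(jax.devices()))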

The interface of `pxla._DeviceAssignment` changes slightly to encourage
callers to avoid converting it to a tuple.
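
For illustration, code that previously read the `.device_assignment` tuple
attribute can now use the sequence protocol directly, converting to a tuple
only at boundaries that still require one (a sketch, not exhaustive):

  # Sketch of the sequence-style interface introduced by this change.
  import jax
  from jax._src.lib import xla_client as xc

  da = xc.DeviceList(tuple(jax.devices()))
  num_devices = len(da)                # was: len(da.device_assignment)
  first_device = da[0]                 # was: da.device_assignment[0]
  local = da.addressable_device_list   # was: da.addressable_device_assignment
  as_tuple = tuple(da)                 # explicit conversion, only where needed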

Note that for backward compatibility (and fast `xla_client.Device`
conversion), `jax::PyDeviceList` still holds a Python tuple whose elements may
be any Python objects that match the `xla_client.Device` interface via duck
typing. This duck typing support will be removed once such use cases are
deprecated. Eventually, we can try to avoid any type conversion so that JAX
keeps no shadow copy of the device list.
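
For example, a hypothetical duck-typed device only needs the attributes this
file actually touches, here `process_index` and `client.process_index()`
(`FakeClient`/`FakeDevice` below are illustrative, not part of any API):

  # Hypothetical duck-typed stand-in for an xla_client.Device.
  class FakeClient:
    def process_index(self) -> int:
      return 0

  class FakeDevice:
    def __init__(self, device_id: int):
      self.id = device_id
      self.process_index = 0     # mirrors xla_client.Device.process_index
      self.client = FakeClient()

  devices = tuple(FakeDevice(i) for i in range(4))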

PiperOrigin-RevId: 555317152
Commit: 97b96bbd4b (parent: 22a005c2a3)
Author: Hyeontaek Lim
Date: 2023-08-09 16:57:28 -07:00
Committed by: jax authors

--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py

@@ -24,7 +24,7 @@ from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
-from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
+from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar)
 import warnings

 import numpy as np
@@ -58,6 +58,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -1789,7 +1790,9 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   out_shardings = semantic_out_shardings.shardings
   global_in_avals = closed_jaxpr.in_avals
   global_out_avals = closed_jaxpr.out_avals
-  device_assignment = da_object.device_assignment
+  # TODO(yashkatariya): Make device_assignment directly usable in the downstream
+  # code without tuple conversion.
+  device_assignment = tuple(da_object)

   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
   if logger.isEnabledFor(log_priority):
@@ -1862,37 +1865,52 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       nreps, tuple_args, lowering_result.shape_poly_state)


-@dataclasses.dataclass(frozen=True)
-class _DeviceAssignment:
-  device_assignment: tuple[xc.Device, ...]
-
-  @cached_property
-  def _hash(self):
-    return hash(self.device_assignment)
-
-  def __hash__(self):
-    return self._hash
-
-  def __eq__(self, other):
-    if not isinstance(other, _DeviceAssignment):
-      return False
-    if id(self) == id(other):
-      return True
-    return (self.device_assignment == other.device_assignment)
-
-  @cached_property
-  def is_fully_addressable(self):
-    return len(self.device_assignment) == len(self.addressable_device_assignment)
-
-  @cached_property
-  def addressable_device_assignment(self):
-    return [d for d in self.device_assignment
-            if d.process_index == d.client.process_index()]
+if xla_extension_version >= 181:
+  _DeviceAssignment = xc.DeviceList
+else:
+  @dataclasses.dataclass(frozen=True)
+  class _DeviceAssignment:  # type: ignore
+    _device_assignment: tuple[xc.Device, ...]
+
+    @cached_property
+    def _hash(self) -> int:
+      return hash(self._device_assignment)
+
+    def __hash__(self) -> int:
+      return self._hash
+
+    def __eq__(self, other: Any) -> bool:
+      if not isinstance(other, _DeviceAssignment):
+        return False
+      if id(self) == id(other):
+        return True
+      return (self._device_assignment == other._device_assignment)
+
+    def __len__(self) -> int:
+      return len(self._device_assignment)
+
+    def __getitem__(self, index: Any) -> Any:
+      return self._device_assignment[index]
+
+    def __iter__(self) -> Iterator[xc.Device]:
+      return iter(self._device_assignment)
+
+    @cached_property
+    def is_fully_addressable(self) -> bool:
+      return len(self._device_assignment) == len(
+          self.addressable_device_list._device_assignment
+      )
+
+    @cached_property
+    def addressable_device_list(self) -> _DeviceAssignment:  # type: ignore
+      return _create_da_object(
+          tuple(d for d in self._device_assignment
+                if d.process_index == d.client.process_index()))


 @lru_cache(maxsize=2048)
-def _create_da_object(
-    device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:
+def _create_da_object(  # pytype: disable=invalid-annotation
+    device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:  # type: ignore
   return _DeviceAssignment(device_assignment)
@@ -1961,7 +1979,7 @@ def lower_sharding_computation(
   da_object = _create_da_object(tuple(device_assignment))

-  if not da_object.is_fully_addressable:
+  if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.jax_spmd_mode != 'allow_all':
       raise RuntimeError(
           "Running operations on `Array`s that are not fully addressable by this "
@@ -2262,15 +2280,13 @@ def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
 def _get_input_indices(
     avals: Sequence[ShapedArray],
     shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    da_object: _DeviceAssignment | Sequence[xc.Device],
+    da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
 ) -> Sequence[tuple[Index | None, ...]]:

   input_indices = []
-  if isinstance(da_object, _DeviceAssignment):
-    num_addressable_devices = len(da_object.addressable_device_assignment)
-  else:
-    num_addressable_devices = len(
-        [d for d in da_object if d.process_index == d.client.process_index()])
+  if not isinstance(da_object, _DeviceAssignment):
+    da_object = _create_da_object(tuple(da_object))
+  num_addressable_devices = len(da_object.addressable_device_list)

   for aval, sharding in zip(avals, shardings):
     if aval is core.abstract_token:
@@ -2426,16 +2442,13 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
                         _allow_propagation_to_outputs, host_callbacks, backend,
                         da, pmap_nreps, compiler_options_keys,
                         compiler_options_values):
-  device_assignment = da.device_assignment if isinstance(
-      da, _DeviceAssignment) else da
-
   # TODO(phawkins): One would normally just write:
   # dev = np.array(device_assignment)
   # The formulation below is substantially faster if there are many devices.
   # If we were to optimize __getattr__ on xc.Device we might not need this
   # workaround.
-  dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
-      np.arange(len(device_assignment))
+  dev = np.vectorize(lambda i: da[i], otypes=[object])(
+      np.arange(len(da))
   )
   if pmap_nreps > 1:
     num_replicas, num_partitions = pmap_nreps, 1
@@ -2493,7 +2506,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
 @dataclasses.dataclass
 class UnloadedMeshExecutable:
   xla_executable: Any
-  device_assignment: _DeviceAssignment | Sequence[xc.Device]
+  device_assignment: _DeviceAssignment | Sequence[xc.Device]  # type: ignore
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
   input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
@@ -2550,7 +2563,7 @@ class UnloadedMeshExecutable:
                keepalive: Any,
                kept_var_idx: set[int],
                backend: xb.XlaBackend,
-               device_assignment: _DeviceAssignment | Sequence[xc.Device],
+               device_assignment: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
                committed: bool,
                pmap_nreps: int = 1,
                jaxpr_debug_info: core.JaxprDebugInfo | None = None,
@@ -2563,8 +2576,10 @@ class UnloadedMeshExecutable:
         compiler_options.keys()) if compiler_options is not None else None
     compiler_options_values = tuple(
         compiler_options.values()) if compiler_options is not None else None
-    da = device_assignment if isinstance(
-        device_assignment, _DeviceAssignment) else tuple(device_assignment)
+    if isinstance(device_assignment, _DeviceAssignment):
+      da = device_assignment
+    else:
+      da = _create_da_object(tuple(device_assignment))
     del device_assignment

     allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)
@@ -2605,8 +2620,9 @@ class UnloadedMeshExecutable:
     elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
           and pmap_nreps == 1):
       assert mesh is None
-      device_assignment = da.device_assignment if isinstance(  # type: ignore
-          da, _DeviceAssignment) else da
+      # TODO(yashkatariya): Make da directly usable in the downstream code
+      # without tuple conversion.
+      device_assignment = tuple(da)
       out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
           xla_executable, device_assignment, len(global_out_avals))  # type: ignore
       orig_out_shardings = out_shardings
@@ -2707,9 +2723,9 @@ class MeshExecutable(stages.XlaExecutable):
           backend, da_object, committed, kept_var_idx, 1)

     out_shardings = _out_shardings_for_trivial(
-        jaxpr, consts, in_shardings, da_object.device_assignment)
+        jaxpr, consts, in_shardings, da_object)
     indices = _get_input_indices(global_out_avals, out_shardings, da_object)
-    local_device_assignment = da_object.addressable_device_assignment
+    local_device_assignment = da_object.addressable_device_list
     handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
     handle_outs = global_avals_to_results_handler(
         global_out_avals, out_shardings, committed,
@@ -2889,7 +2905,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
     jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
     da_object, committed, kept_var_idx, pmap_nreps):
   out_shardings = _out_shardings_for_trivial(
-      jaxpr, consts, in_shardings, da_object.device_assignment)
+      jaxpr, consts, in_shardings, da_object)

   input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
   handle_outs = global_avals_to_results_handler(
@@ -2898,7 +2914,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
   # Use the standard out_handler.
   unsafe_call = backend.compile_replicated(
       is_trivial=True, jaxpr=jaxpr, consts=consts,
-      device_assignment=da_object.device_assignment, in_avals=global_in_avals,
+      device_assignment=da_object, in_avals=global_in_avals,
       in_indices=input_indices, in_shardings=in_shardings,
       kept_var_idx=kept_var_idx, out_handler=handle_outs,
       out_shardings=out_shardings, pmap_nreps=pmap_nreps)