mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Introduce DeviceList
backed by C++ xla::ifrt::DeviceList
This change adds `xla_client.DeviceList` that is implemented in C++ `jax::PyDeviceList`. `jax::PyDeviceList` implements the features of `pxla._DeviceAssignment` as a functional drop-in replacement. `jax::PyDeviceList` internally has `xla::ifrt::DeviceList`, which will be used when using IFRT APIs without having to construct a new copy of a potentially large device list. `pxla._DeviceAssignment`'s interface is changed slightly to encourage avoiding conversion to tuple. Note that for the backward compatibility (and fast `xla_client.Device` conversion), `jax::PyDeviceList` still uses a Python tuple whose element can be any Python object matches `xla_client.Device` interface with duck typing. This duck typing support will be removed when such use case is deprecated. Eventually, we can try to avoid any type conversion to remove a shadow copy of device list in JAX. PiperOrigin-RevId: 555317152
This commit is contained in:
parent
22a005c2a3
commit
97b96bbd4b
@ -24,7 +24,7 @@ from functools import partial, lru_cache, cached_property
|
||||
import itertools as it
|
||||
import logging
|
||||
import math
|
||||
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
|
||||
from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar)
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -58,6 +58,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -1789,7 +1790,9 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
out_shardings = semantic_out_shardings.shardings
|
||||
global_in_avals = closed_jaxpr.in_avals
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
device_assignment = da_object.device_assignment
|
||||
# TODO(yashkatariya): Make device_assignment directly usable in the downstream
|
||||
# code without tuple conversion.
|
||||
device_assignment = tuple(da_object)
|
||||
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
if logger.isEnabledFor(log_priority):
|
||||
@ -1862,37 +1865,52 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
nreps, tuple_args, lowering_result.shape_poly_state)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeviceAssignment:
|
||||
device_assignment: tuple[xc.Device, ...]
|
||||
if xla_extension_version >= 181:
|
||||
_DeviceAssignment = xc.DeviceList
|
||||
else:
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeviceAssignment: # type: ignore
|
||||
_device_assignment: tuple[xc.Device, ...]
|
||||
|
||||
@cached_property
|
||||
def _hash(self):
|
||||
return hash(self.device_assignment)
|
||||
@cached_property
|
||||
def _hash(self) -> int:
|
||||
return hash(self._device_assignment)
|
||||
|
||||
def __hash__(self):
|
||||
return self._hash
|
||||
def __hash__(self) -> int:
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, _DeviceAssignment):
|
||||
return False
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
return (self.device_assignment == other.device_assignment)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, _DeviceAssignment):
|
||||
return False
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
return (self._device_assignment == other._device_assignment)
|
||||
|
||||
@cached_property
|
||||
def is_fully_addressable(self):
|
||||
return len(self.device_assignment) == len(self.addressable_device_assignment)
|
||||
def __len__(self) -> int:
|
||||
return len(self._device_assignment)
|
||||
|
||||
@cached_property
|
||||
def addressable_device_assignment(self):
|
||||
return [d for d in self.device_assignment
|
||||
if d.process_index == d.client.process_index()]
|
||||
def __getitem__(self, index: Any) -> Any:
|
||||
return self._device_assignment[index]
|
||||
|
||||
def __iter__(self) -> Iterator[xc.Device]:
|
||||
return iter(self._device_assignment)
|
||||
|
||||
@cached_property
|
||||
def is_fully_addressable(self) -> bool:
|
||||
return len(self._device_assignment) == len(
|
||||
self.addressable_device_list._device_assignment
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def addressable_device_list(self) -> _DeviceAssignment: # type: ignore
|
||||
return _create_da_object(
|
||||
tuple(d for d in self._device_assignment
|
||||
if d.process_index == d.client.process_index()))
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _create_da_object(
|
||||
device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:
|
||||
def _create_da_object( # pytype: disable=invalid-annotation
|
||||
device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment: # type: ignore
|
||||
return _DeviceAssignment(device_assignment)
|
||||
|
||||
|
||||
@ -1961,7 +1979,7 @@ def lower_sharding_computation(
|
||||
|
||||
da_object = _create_da_object(tuple(device_assignment))
|
||||
|
||||
if not da_object.is_fully_addressable:
|
||||
if not da_object.is_fully_addressable: # type: ignore
|
||||
if inline and config.jax_spmd_mode != 'allow_all':
|
||||
raise RuntimeError(
|
||||
"Running operations on `Array`s that are not fully addressable by this "
|
||||
@ -2262,15 +2280,13 @@ def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
|
||||
def _get_input_indices(
|
||||
avals: Sequence[ShapedArray],
|
||||
shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
||||
da_object: _DeviceAssignment | Sequence[xc.Device],
|
||||
da_object: _DeviceAssignment | Sequence[xc.Device], # type: ignore
|
||||
) -> Sequence[tuple[Index | None, ...]]:
|
||||
|
||||
input_indices = []
|
||||
if isinstance(da_object, _DeviceAssignment):
|
||||
num_addressable_devices = len(da_object.addressable_device_assignment)
|
||||
else:
|
||||
num_addressable_devices = len(
|
||||
[d for d in da_object if d.process_index == d.client.process_index()])
|
||||
if not isinstance(da_object, _DeviceAssignment):
|
||||
da_object = _create_da_object(tuple(da_object))
|
||||
num_addressable_devices = len(da_object.addressable_device_list)
|
||||
|
||||
for aval, sharding in zip(avals, shardings):
|
||||
if aval is core.abstract_token:
|
||||
@ -2426,16 +2442,13 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
_allow_propagation_to_outputs, host_callbacks, backend,
|
||||
da, pmap_nreps, compiler_options_keys,
|
||||
compiler_options_values):
|
||||
device_assignment = da.device_assignment if isinstance(
|
||||
da, _DeviceAssignment) else da
|
||||
|
||||
# TODO(phawkins): One would normally just write:
|
||||
# dev = np.array(device_assignment)
|
||||
# The formulation below is substantially faster if there are many devices.
|
||||
# If we were to optimize __getattr__ on xc.Device we might not need this
|
||||
# workaround.
|
||||
dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
|
||||
np.arange(len(device_assignment))
|
||||
dev = np.vectorize(lambda i: da[i], otypes=[object])(
|
||||
np.arange(len(da))
|
||||
)
|
||||
if pmap_nreps > 1:
|
||||
num_replicas, num_partitions = pmap_nreps, 1
|
||||
@ -2493,7 +2506,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
@dataclasses.dataclass
|
||||
class UnloadedMeshExecutable:
|
||||
xla_executable: Any
|
||||
device_assignment: _DeviceAssignment | Sequence[xc.Device]
|
||||
device_assignment: _DeviceAssignment | Sequence[xc.Device] # type: ignore
|
||||
backend: xb.XlaBackend
|
||||
input_avals: Sequence[ShapedArray]
|
||||
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
|
||||
@ -2550,7 +2563,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive: Any,
|
||||
kept_var_idx: set[int],
|
||||
backend: xb.XlaBackend,
|
||||
device_assignment: _DeviceAssignment | Sequence[xc.Device],
|
||||
device_assignment: _DeviceAssignment | Sequence[xc.Device], # type: ignore
|
||||
committed: bool,
|
||||
pmap_nreps: int = 1,
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None = None,
|
||||
@ -2563,8 +2576,10 @@ class UnloadedMeshExecutable:
|
||||
compiler_options.keys()) if compiler_options is not None else None
|
||||
compiler_options_values = tuple(
|
||||
compiler_options.values()) if compiler_options is not None else None
|
||||
da = device_assignment if isinstance(
|
||||
device_assignment, _DeviceAssignment) else tuple(device_assignment)
|
||||
if isinstance(device_assignment, _DeviceAssignment):
|
||||
da = device_assignment
|
||||
else:
|
||||
da = _create_da_object(tuple(device_assignment))
|
||||
del device_assignment
|
||||
allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)
|
||||
|
||||
@ -2605,8 +2620,9 @@ class UnloadedMeshExecutable:
|
||||
elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
|
||||
and pmap_nreps == 1):
|
||||
assert mesh is None
|
||||
device_assignment = da.device_assignment if isinstance( # type: ignore
|
||||
da, _DeviceAssignment) else da
|
||||
# TODO(yashkatariya): Make da directly usable in the downstream code
|
||||
# without tuple conversion.
|
||||
device_assignment = tuple(da)
|
||||
out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
|
||||
xla_executable, device_assignment, len(global_out_avals)) # type: ignore
|
||||
orig_out_shardings = out_shardings
|
||||
@ -2707,9 +2723,9 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
backend, da_object, committed, kept_var_idx, 1)
|
||||
|
||||
out_shardings = _out_shardings_for_trivial(
|
||||
jaxpr, consts, in_shardings, da_object.device_assignment)
|
||||
jaxpr, consts, in_shardings, da_object)
|
||||
indices = _get_input_indices(global_out_avals, out_shardings, da_object)
|
||||
local_device_assignment = da_object.addressable_device_assignment
|
||||
local_device_assignment = da_object.addressable_device_list
|
||||
handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
|
||||
handle_outs = global_avals_to_results_handler(
|
||||
global_out_avals, out_shardings, committed,
|
||||
@ -2889,7 +2905,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
|
||||
jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
|
||||
da_object, committed, kept_var_idx, pmap_nreps):
|
||||
out_shardings = _out_shardings_for_trivial(
|
||||
jaxpr, consts, in_shardings, da_object.device_assignment)
|
||||
jaxpr, consts, in_shardings, da_object)
|
||||
|
||||
input_indices = _get_input_indices(global_in_avals, in_shardings, da_object) # type: ignore
|
||||
handle_outs = global_avals_to_results_handler(
|
||||
@ -2898,7 +2914,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
|
||||
# Use the standard out_handler.
|
||||
unsafe_call = backend.compile_replicated(
|
||||
is_trivial=True, jaxpr=jaxpr, consts=consts,
|
||||
device_assignment=da_object.device_assignment, in_avals=global_in_avals,
|
||||
device_assignment=da_object, in_avals=global_in_avals,
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=kept_var_idx, out_handler=handle_outs,
|
||||
out_shardings=out_shardings, pmap_nreps=pmap_nreps)
|
||||
|
Loading…
x
Reference in New Issue
Block a user