Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
[JAX] Introduce DeviceList backed by C++ xla::ifrt::DeviceList
This change adds `xla_client.DeviceList`, implemented in C++ as `jax::PyDeviceList`. `jax::PyDeviceList` implements the features of `pxla._DeviceAssignment` as a functional drop-in replacement. Internally, `jax::PyDeviceList` holds an `xla::ifrt::DeviceList`, so IFRT APIs can be used without constructing a new copy of a potentially large device list. `pxla._DeviceAssignment`'s interface is changed slightly to discourage conversion to a tuple. Note that for backward compatibility (and fast `xla_client.Device` conversion), `jax::PyDeviceList` still uses a Python tuple whose elements can be any Python objects that match the `xla_client.Device` interface via duck typing. This duck-typing support will be removed once such use cases are deprecated. Eventually, we can try to avoid any type conversion and remove the shadow copy of the device list in JAX.

PiperOrigin-RevId: 555317152
parent 22a005c2a3
commit 97b96bbd4b
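The duck-typing caveat above matters for the fallback path in the diff below: the only device API the Python `_DeviceAssignment` relies on is a `process_index` attribute and a `client` whose `process_index()` method it can compare against. A minimal sketch of that implied interface, using hypothetical stand-in classes (the real `xla_client.Device` is a C++ type; these names are illustrative only):

# Hypothetical stand-ins for illustration; not the real xla_client API.
class FakeClient:
  def __init__(self, process_index: int):
    self._process_index = process_index

  def process_index(self) -> int:
    return self._process_index


class FakeDevice:
  def __init__(self, device_id: int, process_index: int, client: FakeClient):
    self.id = device_id
    self.process_index = process_index  # attribute on devices
    self.client = client                # process_index() is a method on clients


# A device is addressable exactly when it belongs to the current process.
client = FakeClient(process_index=0)
devices = (FakeDevice(0, 0, client), FakeDevice(1, 1, client))
addressable = [d for d in devices
               if d.process_index == d.client.process_index()]
assert [d.id for d in addressable] == [0]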
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -24,7 +24,7 @@ from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
-from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
+from typing import (Any, Callable, NamedTuple, Iterator, Optional, Union, cast, TypeVar)
 import warnings

 import numpy as np
@@ -58,6 +58,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -1789,7 +1790,9 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   out_shardings = semantic_out_shardings.shardings
   global_in_avals = closed_jaxpr.in_avals
   global_out_avals = closed_jaxpr.out_avals
-  device_assignment = da_object.device_assignment
+  # TODO(yashkatariya): Make device_assignment directly usable in the downstream
+  # code without tuple conversion.
+  device_assignment = tuple(da_object)

   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
   if logger.isEnabledFor(log_priority):
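`tuple(da_object)` in this hunk works on both the C++ `xc.DeviceList` and the Python fallback because each is iterable; the fallback gains `__iter__`, `__len__`, and `__getitem__` in the next hunk precisely so call sites like this one stay agnostic about which type they hold. A tiny sketch of the minimal protocol `tuple()` needs (stub class illustrative, not the real one):

class IterableDA:  # illustrative stub
  def __init__(self, devices):
    self._devices = devices

  def __iter__(self):
    return iter(self._devices)


assert tuple(IterableDA(("cpu:0", "cpu:1"))) == ("cpu:0", "cpu:1")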
@@ -1862,37 +1865,52 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                          nreps, tuple_args, lowering_result.shape_poly_state)


-@dataclasses.dataclass(frozen=True)
-class _DeviceAssignment:
-  device_assignment: tuple[xc.Device, ...]
-
-  @cached_property
-  def _hash(self):
-    return hash(self.device_assignment)
-
-  def __hash__(self):
-    return self._hash
-
-  def __eq__(self, other):
-    if not isinstance(other, _DeviceAssignment):
-      return False
-    if id(self) == id(other):
-      return True
-    return (self.device_assignment == other.device_assignment)
-
-  @cached_property
-  def is_fully_addressable(self):
-    return len(self.device_assignment) == len(self.addressable_device_assignment)
-
-  @cached_property
-  def addressable_device_assignment(self):
-    return [d for d in self.device_assignment
-            if d.process_index == d.client.process_index()]
+if xla_extension_version >= 181:
+  _DeviceAssignment = xc.DeviceList
+else:
+  @dataclasses.dataclass(frozen=True)
+  class _DeviceAssignment:  # type: ignore
+    _device_assignment: tuple[xc.Device, ...]
+
+    @cached_property
+    def _hash(self) -> int:
+      return hash(self._device_assignment)
+
+    def __hash__(self) -> int:
+      return self._hash
+
+    def __eq__(self, other: Any) -> bool:
+      if not isinstance(other, _DeviceAssignment):
+        return False
+      if id(self) == id(other):
+        return True
+      return (self._device_assignment == other._device_assignment)
+
+    def __len__(self) -> int:
+      return len(self._device_assignment)
+
+    def __getitem__(self, index: Any) -> Any:
+      return self._device_assignment[index]
+
+    def __iter__(self) -> Iterator[xc.Device]:
+      return iter(self._device_assignment)
+
+    @cached_property
+    def is_fully_addressable(self) -> bool:
+      return len(self._device_assignment) == len(
+          self.addressable_device_list._device_assignment
+      )
+
+    @cached_property
+    def addressable_device_list(self) -> _DeviceAssignment:  # type: ignore
+      return _create_da_object(
+          tuple(d for d in self._device_assignment
+                if d.process_index == d.client.process_index()))


 @lru_cache(maxsize=2048)
-def _create_da_object(
-    device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:
+def _create_da_object(  # pytype: disable=invalid-annotation
+    device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:  # type: ignore
   return _DeviceAssignment(device_assignment)

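Two details of the fallback class are easy to miss. `_create_da_object` is wrapped in `lru_cache`, so equal device tuples map to the same `_DeviceAssignment` instance and the `id()` fast path in `__eq__` fires without walking a potentially large device list. And `functools.cached_property` stores its result directly in the instance `__dict__`, which is why it works on a frozen dataclass. A self-contained sketch of the caching behaviour (simplified names; not JAX's actual code):

# DA stands in for the fallback _DeviceAssignment; names are simplified.
import dataclasses
from functools import lru_cache, cached_property
from typing import Any


@dataclasses.dataclass(frozen=True)
class DA:
  _devices: tuple[Any, ...]

  @cached_property
  def _hash(self) -> int:  # computed once, stored in the instance __dict__
    return hash(self._devices)

  def __hash__(self) -> int:
    return self._hash

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, DA):
      return False
    if id(self) == id(other):  # fast path: cached instances are shared
      return True
    return self._devices == other._devices


@lru_cache(maxsize=2048)
def create_da(devices: tuple[Any, ...]) -> DA:
  return DA(devices)


# Equal tuples hit the cache and come back as the very same object:
assert create_da(("cpu:0", "cpu:1")) is create_da(("cpu:0", "cpu:1"))

This is why construction is routed through `_create_da_object` rather than calling `_DeviceAssignment(...)` directly.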
@@ -1961,7 +1979,7 @@ def lower_sharding_computation(

   da_object = _create_da_object(tuple(device_assignment))

-  if not da_object.is_fully_addressable:
+  if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.jax_spmd_mode != 'allow_all':
       raise RuntimeError(
           "Running operations on `Array`s that are not fully addressable by this "
@@ -2262,15 +2280,13 @@ def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
 def _get_input_indices(
     avals: Sequence[ShapedArray],
     shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    da_object: _DeviceAssignment | Sequence[xc.Device],
+    da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
 ) -> Sequence[tuple[Index | None, ...]]:

   input_indices = []
-  if isinstance(da_object, _DeviceAssignment):
-    num_addressable_devices = len(da_object.addressable_device_assignment)
-  else:
-    num_addressable_devices = len(
-        [d for d in da_object if d.process_index == d.client.process_index()])
+  if not isinstance(da_object, _DeviceAssignment):
+    da_object = _create_da_object(tuple(da_object))
+  num_addressable_devices = len(da_object.addressable_device_list)

   for aval, sharding in zip(avals, shardings):
     if aval is core.abstract_token:
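The `_get_input_indices` rewrite replaces a two-branch computation with normalize-at-the-boundary: a plain device sequence is wrapped once on entry, and the rest of the function handles a single canonical type. A generic sketch of the pattern (stub names are illustrative; the real code wraps with `_create_da_object`):

from functools import lru_cache
from typing import Any, Sequence, Union


class DeviceListStub:  # illustrative stand-in for _DeviceAssignment
  def __init__(self, devices: tuple):
    self.devices = devices

  def __len__(self) -> int:
    return len(self.devices)


@lru_cache(maxsize=2048)
def wrap(devices: tuple) -> DeviceListStub:
  return DeviceListStub(devices)


def num_devices(da: Union[DeviceListStub, Sequence[Any]]) -> int:
  if not isinstance(da, DeviceListStub):
    da = wrap(tuple(da))  # canonicalize once at the boundary
  return len(da)          # downstream code sees one type


assert num_devices(["cpu:0", "cpu:1"]) == 2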
@@ -2426,16 +2442,13 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
                         _allow_propagation_to_outputs, host_callbacks, backend,
                         da, pmap_nreps, compiler_options_keys,
                         compiler_options_values):
-  device_assignment = da.device_assignment if isinstance(
-      da, _DeviceAssignment) else da
-
   # TODO(phawkins): One would normally just write:
   # dev = np.array(device_assignment)
   # The formulation below is substantially faster if there are many devices.
   # If we were to optimize __getattr__ on xc.Device we might not need this
   # workaround.
-  dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
-      np.arange(len(device_assignment))
+  dev = np.vectorize(lambda i: da[i], otypes=[object])(
+      np.arange(len(da))
   )
   if pmap_nreps > 1:
     num_replicas, num_partitions = pmap_nreps, 1
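The TODO(phawkins) comment explains the shape of this hunk: `np.array(device_assignment)` would probe each element (attribute lookups such as `__array__`) to infer shape and dtype, which the comment suggests is slow on `xc.Device` objects, whereas vectorizing a plain index lookup touches each device exactly once. The new `__len__` and `__getitem__` on `_DeviceAssignment` are what let the hunk index `da` directly and drop the tuple conversion. A runnable sketch with a stand-in device class (illustrative, not `xc.Device`):

import numpy as np


class Dev:  # stand-in device; the real xc.Device is a C++ type
  def __init__(self, i: int):
    self.id = i


da = [Dev(i) for i in range(8)]  # anything supporting len() and [] works
dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da)))
assert dev.shape == (8,)
assert dev[3].id == 3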
@@ -2493,7 +2506,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
 @dataclasses.dataclass
 class UnloadedMeshExecutable:
   xla_executable: Any
-  device_assignment: _DeviceAssignment | Sequence[xc.Device]
+  device_assignment: _DeviceAssignment | Sequence[xc.Device]  # type: ignore
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
   input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
@@ -2550,7 +2563,7 @@ class UnloadedMeshExecutable:
       keepalive: Any,
       kept_var_idx: set[int],
       backend: xb.XlaBackend,
-      device_assignment: _DeviceAssignment | Sequence[xc.Device],
+      device_assignment: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
       committed: bool,
       pmap_nreps: int = 1,
       jaxpr_debug_info: core.JaxprDebugInfo | None = None,
@@ -2563,8 +2576,10 @@ class UnloadedMeshExecutable:
         compiler_options.keys()) if compiler_options is not None else None
     compiler_options_values = tuple(
         compiler_options.values()) if compiler_options is not None else None
-    da = device_assignment if isinstance(
-        device_assignment, _DeviceAssignment) else tuple(device_assignment)
+    if isinstance(device_assignment, _DeviceAssignment):
+      da = device_assignment
+    else:
+      da = _create_da_object(tuple(device_assignment))
     del device_assignment
     allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)

@@ -2605,8 +2620,9 @@ class UnloadedMeshExecutable:
     elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
           and pmap_nreps == 1):
       assert mesh is None
-      device_assignment = da.device_assignment if isinstance(  # type: ignore
-          da, _DeviceAssignment) else da
+      # TODO(yashkatariya): Make da directly usable in the downstream code
+      # without tuple conversion.
+      device_assignment = tuple(da)
       out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
           xla_executable, device_assignment, len(global_out_avals))  # type: ignore
       orig_out_shardings = out_shardings
@@ -2707,9 +2723,9 @@ class MeshExecutable(stages.XlaExecutable):
           backend, da_object, committed, kept_var_idx, 1)

     out_shardings = _out_shardings_for_trivial(
-        jaxpr, consts, in_shardings, da_object.device_assignment)
+        jaxpr, consts, in_shardings, da_object)
     indices = _get_input_indices(global_out_avals, out_shardings, da_object)
-    local_device_assignment = da_object.addressable_device_assignment
+    local_device_assignment = da_object.addressable_device_list
     handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
     handle_outs = global_avals_to_results_handler(
         global_out_avals, out_shardings, committed,
@@ -2889,7 +2905,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
     jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
     da_object, committed, kept_var_idx, pmap_nreps):
   out_shardings = _out_shardings_for_trivial(
-      jaxpr, consts, in_shardings, da_object.device_assignment)
+      jaxpr, consts, in_shardings, da_object)

   input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
   handle_outs = global_avals_to_results_handler(
@@ -2898,7 +2914,7 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
   # Use the standard out_handler.
   unsafe_call = backend.compile_replicated(
       is_trivial=True, jaxpr=jaxpr, consts=consts,
-      device_assignment=da_object.device_assignment, in_avals=global_in_avals,
+      device_assignment=da_object, in_avals=global_in_avals,
       in_indices=input_indices, in_shardings=in_shardings,
       kept_var_idx=kept_var_idx, out_handler=handle_outs,
       out_shardings=out_shardings, pmap_nreps=pmap_nreps)
|
Loading…
x
Reference in New Issue
Block a user