Replace _DeviceAssignment with xc.DeviceList

PiperOrigin-RevId: 599597226
This commit is contained in:
Yash Katariya 2024-01-18 12:54:54 -08:00 committed by jax authors
parent 72f00ebaec
commit 71c9be14b8

View File

@ -1849,13 +1849,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
nreps, tuple_args, lowering_result.shape_poly_state)
_DeviceAssignment = xc.DeviceList
@lru_cache(maxsize=2048)
def _create_da_object( # pytype: disable=invalid-annotation
device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment: # type: ignore
return _DeviceAssignment(device_assignment)
device_assignment: tuple[xc.Device, ...]) -> xc.DeviceList: # type: ignore
return xc.DeviceList(device_assignment)
def jaxpr_transfer_mem_kinds(
@ -2237,11 +2234,11 @@ if xla_extension_version < 229:
def _get_input_indices(
avals: Sequence[ShapedArray],
shardings: Sequence[sharding_impls.XLACompatibleSharding],
da_object: _DeviceAssignment | Sequence[xc.Device], # type: ignore
da_object: xc.DeviceList | Sequence[xc.Device], # type: ignore
) -> Sequence[tuple[Index | None, ...]]:
input_indices = []
if not isinstance(da_object, _DeviceAssignment):
if not isinstance(da_object, xc.DeviceList):
da_object = _create_da_object(tuple(da_object))
num_addressable_devices = len(da_object.addressable_device_list)
@ -2568,7 +2565,7 @@ def _get_shardings_from_executable(
@dataclasses.dataclass
class UnloadedMeshExecutable:
xla_executable: Any
device_assignment: _DeviceAssignment | Sequence[xc.Device] # type: ignore
device_assignment: xc.DeviceList | Sequence[xc.Device] # type: ignore
backend: xb.XlaBackend
input_avals: Sequence[ShapedArray]
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
@ -2631,7 +2628,7 @@ class UnloadedMeshExecutable:
keepalive: Any,
kept_var_idx: set[int],
backend: xb.XlaBackend,
device_assignment: _DeviceAssignment | Sequence[xc.Device], # type: ignore
device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore
committed: bool,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
@ -2647,7 +2644,7 @@ class UnloadedMeshExecutable:
compiler_options.keys()) if compiler_options is not None else None
compiler_options_values = tuple(
compiler_options.values()) if compiler_options is not None else None
if isinstance(device_assignment, _DeviceAssignment):
if isinstance(device_assignment, xc.DeviceList):
da = device_assignment
else:
da = _create_da_object(tuple(device_assignment))