From 71c9be14b8835694bb2613a0ccff5ce5b22edf93 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 18 Jan 2024 12:54:54 -0800 Subject: [PATCH] Replace _DeviceAssignment with xc.DeviceList PiperOrigin-RevId: 599597226 --- jax/_src/interpreters/pxla.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6d29dd8c6..321d5a4c8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1849,13 +1849,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, nreps, tuple_args, lowering_result.shape_poly_state) -_DeviceAssignment = xc.DeviceList - - @lru_cache(maxsize=2048) def _create_da_object( # pytype: disable=invalid-annotation - device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment: # type: ignore - return _DeviceAssignment(device_assignment) + device_assignment: tuple[xc.Device, ...]) -> xc.DeviceList: # type: ignore + return xc.DeviceList(device_assignment) def jaxpr_transfer_mem_kinds( @@ -2237,11 +2234,11 @@ if xla_extension_version < 229: def _get_input_indices( avals: Sequence[ShapedArray], shardings: Sequence[sharding_impls.XLACompatibleSharding], - da_object: _DeviceAssignment | Sequence[xc.Device], # type: ignore + da_object: xc.DeviceList | Sequence[xc.Device], # type: ignore ) -> Sequence[tuple[Index | None, ...]]: input_indices = [] - if not isinstance(da_object, _DeviceAssignment): + if not isinstance(da_object, xc.DeviceList): da_object = _create_da_object(tuple(da_object)) num_addressable_devices = len(da_object.addressable_device_list) @@ -2568,7 +2565,7 @@ def _get_shardings_from_executable( @dataclasses.dataclass class UnloadedMeshExecutable: xla_executable: Any - device_assignment: _DeviceAssignment | Sequence[xc.Device] # type: ignore + device_assignment: xc.DeviceList | Sequence[xc.Device] # type: ignore backend: xb.XlaBackend input_avals: Sequence[ShapedArray] input_shardings: Sequence[sharding_impls.XLACompatibleSharding] @@ -2631,7 +2628,7 @@ class UnloadedMeshExecutable: keepalive: Any, kept_var_idx: set[int], backend: xb.XlaBackend, - device_assignment: _DeviceAssignment | Sequence[xc.Device], # type: ignore + device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore committed: bool, in_layouts: MaybeLayout, out_layouts: MaybeLayout, @@ -2647,7 +2644,7 @@ class UnloadedMeshExecutable: compiler_options.keys()) if compiler_options is not None else None compiler_options_values = tuple( compiler_options.values()) if compiler_options is not None else None - if isinstance(device_assignment, _DeviceAssignment): + if isinstance(device_assignment, xc.DeviceList): da = device_assignment else: da = _create_da_object(tuple(device_assignment))