mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Replace _DeviceAssignment with xc.DeviceList
PiperOrigin-RevId: 599597226
This commit is contained in:
parent
72f00ebaec
commit
71c9be14b8
@ -1849,13 +1849,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
nreps, tuple_args, lowering_result.shape_poly_state)
|
||||
|
||||
|
||||
_DeviceAssignment = xc.DeviceList
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _create_da_object( # pytype: disable=invalid-annotation
|
||||
device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment: # type: ignore
|
||||
return _DeviceAssignment(device_assignment)
|
||||
device_assignment: tuple[xc.Device, ...]) -> xc.DeviceList: # type: ignore
|
||||
return xc.DeviceList(device_assignment)
|
||||
|
||||
|
||||
def jaxpr_transfer_mem_kinds(
|
||||
@ -2237,11 +2234,11 @@ if xla_extension_version < 229:
|
||||
def _get_input_indices(
|
||||
avals: Sequence[ShapedArray],
|
||||
shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
||||
da_object: _DeviceAssignment | Sequence[xc.Device], # type: ignore
|
||||
da_object: xc.DeviceList | Sequence[xc.Device], # type: ignore
|
||||
) -> Sequence[tuple[Index | None, ...]]:
|
||||
|
||||
input_indices = []
|
||||
if not isinstance(da_object, _DeviceAssignment):
|
||||
if not isinstance(da_object, xc.DeviceList):
|
||||
da_object = _create_da_object(tuple(da_object))
|
||||
num_addressable_devices = len(da_object.addressable_device_list)
|
||||
|
||||
@ -2568,7 +2565,7 @@ def _get_shardings_from_executable(
|
||||
@dataclasses.dataclass
|
||||
class UnloadedMeshExecutable:
|
||||
xla_executable: Any
|
||||
device_assignment: _DeviceAssignment | Sequence[xc.Device] # type: ignore
|
||||
device_assignment: xc.DeviceList | Sequence[xc.Device] # type: ignore
|
||||
backend: xb.XlaBackend
|
||||
input_avals: Sequence[ShapedArray]
|
||||
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
|
||||
@ -2631,7 +2628,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive: Any,
|
||||
kept_var_idx: set[int],
|
||||
backend: xb.XlaBackend,
|
||||
device_assignment: _DeviceAssignment | Sequence[xc.Device], # type: ignore
|
||||
device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore
|
||||
committed: bool,
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
@ -2647,7 +2644,7 @@ class UnloadedMeshExecutable:
|
||||
compiler_options.keys()) if compiler_options is not None else None
|
||||
compiler_options_values = tuple(
|
||||
compiler_options.values()) if compiler_options is not None else None
|
||||
if isinstance(device_assignment, _DeviceAssignment):
|
||||
if isinstance(device_assignment, xc.DeviceList):
|
||||
da = device_assignment
|
||||
else:
|
||||
da = _create_da_object(tuple(device_assignment))
|
||||
|
Loading…
x
Reference in New Issue
Block a user