Fix mypy error

This commit is contained in:
Jake VanderPlas 2023-08-29 13:25:12 -07:00
parent 6072d5993e
commit f1fc2adfbd

View File

@ -106,17 +106,17 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
nonzero_indices = np.flatnonzero(local_slices)
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
subcube_indices.append(slice(start, end + 1))
subcube_indices = tuple(subcube_indices)
subcube_indices_tuple = tuple(subcube_indices)
# We only end up with all conditions being true if the local devices formed a
# subcube of the full array. This is because we were biased towards taking a
# "hull" spanned by the devices, and in case the local devices don't form a
# subcube that hull will contain non-local devices.
if not is_local_device[subcube_indices].all():
if not is_local_device[subcube_indices_tuple].all():
raise ValueError(
"When passing host local inputs to pjit or xmap, devices "
"connected to a single host must form a contiguous subcube of the "
"global device mesh")
return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names)
return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)
_mesh_object_dict = {} # type: ignore