mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix mypy error
This commit is contained in:
parent
6072d5993e
commit
f1fc2adfbd
@ -106,17 +106,17 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
|
||||
nonzero_indices = np.flatnonzero(local_slices)
|
||||
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
|
||||
subcube_indices.append(slice(start, end + 1))
|
||||
subcube_indices = tuple(subcube_indices)
|
||||
subcube_indices_tuple = tuple(subcube_indices)
|
||||
# We only end up with all conditions being true if the local devices formed a
|
||||
# subcube of the full array. This is because we were biased towards taking a
|
||||
# "hull" spanned by the devices, and in case the local devices don't form a
|
||||
# subcube that hull will contain non-local devices.
|
||||
if not is_local_device[subcube_indices].all():
|
||||
if not is_local_device[subcube_indices_tuple].all():
|
||||
raise ValueError(
|
||||
"When passing host local inputs to pjit or xmap, devices "
|
||||
"connected to a single host must form a contiguous subcube of the "
|
||||
"global device mesh")
|
||||
return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names)
|
||||
return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)
|
||||
|
||||
|
||||
_mesh_object_dict = {} # type: ignore
|
||||
|
Loading…
x
Reference in New Issue
Block a user