Remove isinstance checks

PiperOrigin-RevId: 425745786
This commit is contained in:
Yash Katariya 2022-02-01 16:38:12 -08:00 committed by jax authors
parent dcca99b052
commit 3acbd44952

View File

@ -54,7 +54,7 @@ def _canonicalize_mesh_axes(mesh_axes):
return pspec
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
mesh_axes: MeshAxes) -> Tuple[Index, ...]:
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
@ -66,11 +66,7 @@ def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
sharding_spec = pxla.mesh_sharding_specs(
global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
indices = pxla.spec_to_indices(global_shape, sharding_spec)
for index in indices:
assert isinstance(index, tuple)
for idx in index:
assert isinstance(idx, slice)
return indices
return indices # type: ignore
@_convert_list_args_to_tuple