mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove isinstance checks
PiperOrigin-RevId: 425745786
This commit is contained in:
parent
dcca99b052
commit
3acbd44952
@ -54,7 +54,7 @@ def _canonicalize_mesh_axes(mesh_axes):
|
||||
return pspec
|
||||
|
||||
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
|
||||
mesh_axes: MeshAxes) -> Tuple[Index, ...]:
|
||||
# Import here to avoid cyclic import error when importing gda in pjit.py.
|
||||
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
|
||||
|
||||
@ -66,11 +66,7 @@ def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
sharding_spec = pxla.mesh_sharding_specs(
|
||||
global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
|
||||
indices = pxla.spec_to_indices(global_shape, sharding_spec)
|
||||
for index in indices:
|
||||
assert isinstance(index, tuple)
|
||||
for idx in index:
|
||||
assert isinstance(idx, slice)
|
||||
return indices
|
||||
return indices # type: ignore
|
||||
|
||||
|
||||
@_convert_list_args_to_tuple
|
||||
|
Loading…
x
Reference in New Issue
Block a user