Reshape sharding spec indices to the mesh shape to preserve the old semantics.

PiperOrigin-RevId: 466346873
This commit is contained in:
Yash Katariya 2022-08-09 06:58:42 -07:00 committed by jax authors
parent 870e8a2928
commit ce80a54805

View File

@ -282,7 +282,7 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
if not has_unstacked:
op_sharding_proto = sharding_spec_sharding_proto(self)
return _op_sharding_to_numpy_indices(
op_sharding_proto, shape, prod(self.mesh_shape))
op_sharding_proto, shape, prod(self.mesh_shape)).reshape(self.mesh_shape)
axis_indices: List[Sequence[Index]] = []
shard_indices_shape = []