From ce80a5480520dde98f0d21bb79064e91da2a36fd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 9 Aug 2022 06:58:42 -0700 Subject: [PATCH] Reshape sharding spec indices to the mesh shape to preserve the old semantics. PiperOrigin-RevId: 466346873 --- jax/interpreters/pxla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index b25a8b95d..a03585c31 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -282,7 +282,7 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray: if not has_unstacked: op_sharding_proto = sharding_spec_sharding_proto(self) return _op_sharding_to_numpy_indices( - op_sharding_proto, shape, prod(self.mesh_shape)) + op_sharding_proto, shape, prod(self.mesh_shape)).reshape(self.mesh_shape) axis_indices: List[Sequence[Index]] = [] shard_indices_shape = []