mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Reshape sharding spec indices to the mesh shape to preserve the old semantics.
PiperOrigin-RevId: 466346873
This commit is contained in:
parent
870e8a2928
commit
ce80a54805
@ -282,7 +282,7 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
|
||||
if not has_unstacked:
|
||||
op_sharding_proto = sharding_spec_sharding_proto(self)
|
||||
return _op_sharding_to_numpy_indices(
|
||||
op_sharding_proto, shape, prod(self.mesh_shape))
|
||||
op_sharding_proto, shape, prod(self.mesh_shape)).reshape(self.mesh_shape)
|
||||
|
||||
axis_indices: List[Sequence[Index]] = []
|
||||
shard_indices_shape = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user