Fix ndarray comparision in sharding_impls.py

This commit is contained in:
Skye Wanderman-Milne 2023-04-12 20:43:57 +00:00
parent 49e68dbe80
commit e2f1e7d28e

View File

@ -535,7 +535,7 @@ class PositionalSharding(XLACompatibleSharding):
return False
if id(self) == id(other):
return True
all_ids_equal = bool(np.all(self._ids == other._ids))
all_ids_equal = np.array_equal(self._ids,other._ids)
if id(self._devices) == id(other._devices) and all_ids_equal:
return True
return self._devices == other._devices and all_ids_equal