mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix ndarray comparision in sharding_impls.py
This commit is contained in:
parent
49e68dbe80
commit
e2f1e7d28e
@ -535,7 +535,7 @@ class PositionalSharding(XLACompatibleSharding):
|
||||
return False
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
all_ids_equal = bool(np.all(self._ids == other._ids))
|
||||
all_ids_equal = np.array_equal(self._ids,other._ids)
|
||||
if id(self._devices) == id(other._devices) and all_ids_equal:
|
||||
return True
|
||||
return self._devices == other._devices and all_ids_equal
|
||||
|
Loading…
x
Reference in New Issue
Block a user