mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Avoid re-constructing set. Expensive at scale.
PiperOrigin-RevId: 521310375
This commit is contained in:
parent
b8dfb97e57
commit
cf599c7d3e
@ -189,13 +189,13 @@ class ArrayImpl(basearray.Array):
|
||||
f"got {len(self._arrays)}")
|
||||
|
||||
array_device_ids = set(device_id_to_buffer.keys())
|
||||
addressable_device_ids = set(d.id for d in addressable_dev)
|
||||
addressable_device_ids = {d.id for d in addressable_dev}
|
||||
# Calculate a symmetric difference because the device ids between sharding
|
||||
# and _arrays should match.
|
||||
diff = set(array_device_ids) ^ set(addressable_device_ids)
|
||||
diff = array_device_ids ^ addressable_device_ids
|
||||
if diff:
|
||||
dev_in_sharding_not_in_arrays = set(addressable_device_ids) - set(array_device_ids)
|
||||
dev_in_arrays_not_in_sharding = set(array_device_ids) - set(addressable_device_ids)
|
||||
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
|
||||
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
|
||||
err_msg = (
|
||||
"Addressable devices and per-device arrays devices do not match.")
|
||||
if dev_in_sharding_not_in_arrays:
|
||||
|
Loading…
x
Reference in New Issue
Block a user