Avoid re-constructing set. Expensive at scale.

PiperOrigin-RevId: 521310375
This commit is contained in:
Qiao Zhang 2023-04-02 14:42:10 -07:00 committed by jax authors
parent b8dfb97e57
commit cf599c7d3e

View File

@ -189,13 +189,13 @@ class ArrayImpl(basearray.Array):
f"got {len(self._arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = set(d.id for d in addressable_dev)
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = set(array_device_ids) ^ set(addressable_device_ids)
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = set(addressable_device_ids) - set(array_device_ids)
dev_in_arrays_not_in_sharding = set(array_device_ids) - set(addressable_device_ids)
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays: