Optimize _create_copy_plan in array.py

* `_get_device` is called from many tight loops, so it's worth avoiding unnecessary work as much as possible.
* `_create_copy_plan` now uses sharding's `_internal_device_list` instead of querying the device of every shard in a loop.

PiperOrigin-RevId: 624288637
This commit is contained in:
Junwhan Ahn 2024-04-12 14:11:37 -07:00 committed by jax authors
parent 44e83d4e0a
commit 3245455900

View File

@@ -56,8 +56,9 @@ Index = tuple[slice, ...]
PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
def _get_device(a: ArrayImpl) -> Device:
assert len(a.devices()) == 1
return next(iter(a.devices()))
devices = a.devices()
assert len(devices) == 1
return next(iter(devices))
class Shard:
@@ -137,7 +138,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
di_map = _cached_index_calc(s, shape)
copy_plan = []
for a in arrays:
ind = di_map.get(_get_device(a), None)
ind = di_map.get(a.sharding._internal_device_list[0], None) # type:ignore
if ind is not None:
copy_plan.append((ind, a))
return copy_plan