mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Optimize _create_copy_plan
in array.py
* `_get_device` is called from many tight loops, so it's worth avoiding unnecessary work as much as possible. * `_create_copy_plan` now uses sharding's `_internal_device_list` instead of querying the device of every shard in a loop. PiperOrigin-RevId: 624288637
This commit is contained in:
parent
44e83d4e0a
commit
3245455900
@ -56,8 +56,9 @@ Index = tuple[slice, ...]
|
||||
PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
|
||||
|
||||
def _get_device(a: ArrayImpl) -> Device:
|
||||
assert len(a.devices()) == 1
|
||||
return next(iter(a.devices()))
|
||||
devices = a.devices()
|
||||
assert len(devices) == 1
|
||||
return next(iter(devices))
|
||||
|
||||
|
||||
class Shard:
|
||||
@ -137,7 +138,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
|
||||
di_map = _cached_index_calc(s, shape)
|
||||
copy_plan = []
|
||||
for a in arrays:
|
||||
ind = di_map.get(_get_device(a), None)
|
||||
ind = di_map.get(a.sharding._internal_device_list[0], None) # type:ignore
|
||||
if ind is not None:
|
||||
copy_plan.append((ind, a))
|
||||
return copy_plan
|
||||
|
Loading…
x
Reference in New Issue
Block a user