Optimize _create_copy_plan by iterating over only the shards that are needed for materialization

For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX.

The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.

PiperOrigin-RevId: 624969222
This commit is contained in:
Junwhan Ahn 2024-04-15 08:29:02 -07:00 committed by jax authors
parent 3a09404426
commit ac1a53d8e4

View File

@ -125,23 +125,13 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
def _cached_index_calc(s, shape):
map_ = s.addressable_devices_indices_map(shape)
seen_h_indices = set()
m = {}
for d, index in map_.items():
l = []
for array_index, index in enumerate(map_.values()):
h_index = hashed_index(index)
if h_index not in seen_h_indices:
seen_h_indices.add(h_index)
m[d] = index
return m
def _create_copy_plan(arrays, s: Sharding, shape: Shape):
di_map = _cached_index_calc(s, shape)
copy_plan = []
for a in arrays:
ind = di_map.get(a.sharding._internal_device_list[0], None) # type:ignore
if ind is not None:
copy_plan.append((ind, a))
return copy_plan
l.append((array_index, index))
return l
@functools.lru_cache(maxsize=4096)
@ -607,9 +597,8 @@ class ArrayImpl(basearray.Array):
if self.is_fully_replicated:
self._copy_single_device_array_to_host_async()
return
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
for _, arr in copy_plan:
arr._copy_single_device_array_to_host_async()
for i, _ in _cached_index_calc(self.sharding, self.shape):
self._arrays[i]._copy_single_device_array_to_host_async()
@property
@functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
@ -631,13 +620,12 @@ class ArrayImpl(basearray.Array):
"`jax.experimental.multihost_utils.process_allgather` "
"for this use case.")
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
for _, arr in copy_plan:
arr._copy_single_device_array_to_host_async()
for i, _ in _cached_index_calc(self.sharding, self.shape):
self._arrays[i]._copy_single_device_array_to_host_async()
npy_value = np.empty(self.shape, self.dtype)
for ind, arr in copy_plan:
npy_value[ind] = arr._single_device_array_to_np_array()
for i, ind in _cached_index_calc(self.sharding, self.shape):
npy_value[ind] = self._arrays[i]._single_device_array_to_np_array()
self._npy_value = npy_value # type: ignore
self._npy_value.flags.writeable = False
# https://docs.python.org/3/library/typing.html#typing.cast